const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const {
  handleDefaultStreamResponseV2,
} = require("../../helpers/chat/responses");
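
/**
 * LLM provider for a HuggingFace Inference Endpoint. The endpoint exposes an
 * OpenAI-compatible API, so the `openai` client is pointed at
 * `${HUGGING_FACE_LLM_ENDPOINT}/v1` and the model name is stubbed as "tgi".
 */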
class HuggingFaceLLM {
  constructor(embedder = null, _modelPreference = null) {
    if (!process.env.HUGGING_FACE_LLM_ENDPOINT)
      throw new Error("No HuggingFace Inference Endpoint was set.");
    if (!process.env.HUGGING_FACE_LLM_API_KEY)
      throw new Error("No HuggingFace Access Token was set.");
    const { OpenAI: OpenAIApi } = require("openai");

    this.openai = new OpenAIApi({
      baseURL: `${process.env.HUGGING_FACE_LLM_ENDPOINT}/v1`,
      apiKey: process.env.HUGGING_FACE_LLM_API_KEY,
    });
    // When using the HF inference server the model param is not required, so
    // we can stub it here. HF Endpoints can only run one model at a time.
    // We set it to "tgi" so the HF endpoint accepts the message format.
    this.model = "tgi";
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.2;
  }
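
  // Formats retrieved context snippets into a labeled block that is appended
  // to the system prompt content.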
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }

  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }
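
  // The context window comes from HUGGING_FACE_LLM_TOKEN_LIMIT, defaulting to
  // 4096 tokens when the variable is unset.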
  static promptWindowLimit(_modelName) {
    const limit = process.env.HUGGING_FACE_LLM_TOKEN_LIMIT || 4096;
    if (!limit || isNaN(Number(limit)))
      throw new Error("No HuggingFace token context limit was set.");
    return Number(limit);
  }

  promptWindowLimit() {
    const limit = process.env.HUGGING_FACE_LLM_TOKEN_LIMIT || 4096;
    if (!limit || isNaN(Number(limit)))
      throw new Error("No HuggingFace token context limit was set.");
    return Number(limit);
  }

  async isValidChatCompletionModel(_ = "") {
    return true;
  }
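
  // Builds the OpenAI-style message array passed to the chat completion
  // endpoint; the system prompt and context are folded into a leading
  // user/assistant exchange.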
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
  }) {
    // System prompt is not enabled for HF model chats
    const prompt = {
      role: "user",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    const assistantResponse = {
      role: "assistant",
      content: "Okay, I will follow those instructions",
    };

    return [
      prompt,
      assistantResponse,
      ...chatHistory,
      { role: "user", content: userPrompt },
    ];
  }
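
  // Sends a non-streaming chat completion request, wrapped in the performance
  // monitor so token usage and duration metrics are captured.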
  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.openai.chat.completions
        .create({
          model: this.model,
          messages,
          temperature,
        })
        .catch((e) => {
          throw new Error(e.message);
        })
    );

    if (
      !result.output.hasOwnProperty("choices") ||
      result.output.choices.length === 0
    )
      return null;

    return {
      textResponse: result.output.choices[0].message.content,
      metrics: {
        prompt_tokens: result.output.usage?.prompt_tokens || 0,
        completion_tokens: result.output.usage?.completion_tokens || 0,
        total_tokens: result.output.usage?.total_tokens || 0,
        outputTps:
          (result.output.usage?.completion_tokens || 0) / result.duration,
        duration: result.duration,
      },
    };
  }
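
  // Opens a streaming chat completion; LLMPerformanceMonitor.measureStream
  // returns the measured stream that handleStream() consumes.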
  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.openai.chat.completions.create({
        model: this.model,
        stream: true,
        messages,
        temperature,
      }),
      messages
    );
    return measuredStreamRequest;
  }

  handleStream(response, stream, responseProps) {
    return handleDefaultStreamResponseV2(response, stream, responseProps);
  }

  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }
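
  // Compresses the prompt and chat history so the combined messages fit
  // inside the configured context window before sending.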
  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  HuggingFaceLLM,
};
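
// Illustrative usage sketch (not part of this module). It assumes the
// HUGGING_FACE_LLM_ENDPOINT and HUGGING_FACE_LLM_API_KEY environment variables
// point at a running HuggingFace Inference Endpoint, and that this file is
// required directly:
//
//   const { HuggingFaceLLM } = require("./index");
//   const llm = new HuggingFaceLLM();
//   const messages = llm.constructPrompt({
//     systemPrompt: "You are a helpful assistant.",
//     contextTexts: [],
//     chatHistory: [],
//     userPrompt: "Hello!",
//   });
//   llm
//     .getChatCompletion(messages, { temperature: llm.defaultTemp })
//     .then((res) => console.log(res?.textResponse));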