const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const {
  handleDefaultStreamResponseV2,
  formatChatHistory,
} = require("../../helpers/chat/responses");
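
/**
 * Chat provider for a LiteLLM proxy server, reached through the
 * OpenAI-compatible SDK. Base path, API key, model preference, and token
 * limit are read from the LITE_LLM_* environment variables.
 */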
class LiteLLM {
  constructor(embedder = null, modelPreference = null) {
    const { OpenAI: OpenAIApi } = require("openai");
    if (!process.env.LITE_LLM_BASE_PATH)
      throw new Error(
        "LiteLLM must have a valid base path to use for the api."
      );
    this.basePath = process.env.LITE_LLM_BASE_PATH;
    this.openai = new OpenAIApi({
      baseURL: this.basePath,
      apiKey: process.env.LITE_LLM_API_KEY ?? null,
    });
    this.model = modelPreference ?? process.env.LITE_LLM_MODEL_PREF ?? null;
    this.maxTokens = process.env.LITE_LLM_MODEL_TOKEN_LIMIT ?? 1024;
    if (!this.model) throw new Error("LiteLLM must have a valid model set.");
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;
    this.log(`Inference API: ${this.basePath} Model: ${this.model}`);
  }

  log(text, ...args) {
    console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
  }
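
  // Wraps retrieved context snippets in [CONTEXT n]/[END CONTEXT n] markers
  // so they can be appended to the system prompt.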
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }
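
  // Streaming is supported whenever streamGetChatCompletion is implemented.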
  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }
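
  // Static variant of promptWindowLimit for when no instance exists;
  // reads the same LITE_LLM_MODEL_TOKEN_LIMIT env var.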
  static promptWindowLimit(_modelName) {
    const limit = process.env.LITE_LLM_MODEL_TOKEN_LIMIT || 4096;
    if (!limit || isNaN(Number(limit)))
      throw new Error("No token context limit was set.");
    return Number(limit);
  }

  // Ensure the user set a value for the token limit
  // and if undefined - assume a 4096-token window.
  promptWindowLimit() {
    const limit = process.env.LITE_LLM_MODEL_TOKEN_LIMIT || 4096;
    if (!limit || isNaN(Number(limit)))
      throw new Error("No token context limit was set.");
    return Number(limit);
  }

  // Short circuit since we have no idea if the model is valid or not
  // in pre-flight for generic endpoints
  isValidChatCompletionModel(_modelName = "") {
    return true;
  }

  /**
   * Generates appropriate content array for a message + attachments.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
   * @returns {string|object[]}
   */
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) {
      return userPrompt;
    }

    const content = [{ type: "text", text: userPrompt }];
    for (let attachment of attachments) {
      content.push({
        type: "image_url",
        image_url: {
          url: attachment.contentString,
        },
      });
    }
    return content.flat();
  }

  /**
   * Construct the chat message array for this model from the system prompt,
   * retrieved context, prior chat history, and the new user prompt.
   * @param {{systemPrompt: string, contextTexts: string[], chatHistory: object[], userPrompt: string, attachments: import("../../helpers").Attachment[]}} param0
   * @returns {object[]} OpenAI-format message array.
   */
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [],
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      ...formatChatHistory(chatHistory, this.#generateContent),
      {
        role: "user",
        content: this.#generateContent({ userPrompt, attachments }),
      },
    ];
  }
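
  /**
   * Sends the message array to the LiteLLM proxy's OpenAI-compatible
   * chat completions endpoint and returns the text response along with
   * token usage and timing metrics from LLMPerformanceMonitor.
   * @param {object[]} messages - Message array built by constructPrompt.
   * @param {{temperature: number}} param1
   * @returns {Promise<{textResponse: string, metrics: object}|null>}
   */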
  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.openai.chat.completions
        .create({
          model: this.model,
          messages,
          temperature,
          max_tokens: parseInt(this.maxTokens), // LiteLLM requires int
        })
        .catch((e) => {
          throw new Error(e.message);
        })
    );

    if (
      !result.output.hasOwnProperty("choices") ||
      result.output.choices.length === 0
    )
      return null;

    return {
      textResponse: result.output.choices[0].message.content,
      metrics: {
        prompt_tokens: result.output.usage?.prompt_tokens || 0,
        completion_tokens: result.output.usage?.completion_tokens || 0,
        total_tokens: result.output.usage?.total_tokens || 0,
        outputTps:
          (result.output.usage?.completion_tokens || 0) / result.duration,
        duration: result.duration,
      },
    };
  }
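
  /**
   * Opens a streaming chat completion request against the LiteLLM proxy and
   * wraps it with LLMPerformanceMonitor.measureStream so output metrics can
   * be captured as chunks arrive.
   * @param {object[]} messages - Message array built by constructPrompt.
   * @param {{temperature: number}} param1
   * @returns {Promise<object>} The measured stream request.
   */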
  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.openai.chat.completions.create({
        model: this.model,
        stream: true,
        messages,
        temperature,
        max_tokens: parseInt(this.maxTokens), // LiteLLM requires int
      }),
      messages
      // runPromptTokenCalculation: true - We manually count the tokens because they may or may not be provided in the stream
      // responses depending on the LLM connected. If they are provided, then we counted for nothing, but better than nothing.
    );
    return measuredStreamRequest;
  }
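
  // Forwards streamed chunks to the client using the shared
  // OpenAI-compatible stream handler.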
  handleStream(response, stream, responseProps) {
    return handleDefaultStreamResponseV2(response, stream, responseProps);
  }

  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }
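
  // Compresses the assembled prompt and raw history via the shared
  // messageArrayCompressor helper so the request fits the model's context window.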
  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  LiteLLM,
};