const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const {
  handleDefaultStreamResponseV2,
  formatChatHistory,
} = require("../../helpers/chat/responses");

class MistralLLM {
  constructor(embedder = null, modelPreference = null) {
    if (!process.env.MISTRAL_API_KEY)
      throw new Error("No Mistral API key was set.");

    const { OpenAI: OpenAIApi } = require("openai");
    this.openai = new OpenAIApi({
      baseURL: "https://api.mistral.ai/v1",
      apiKey: process.env.MISTRAL_API_KEY ?? null,
    });
    this.model =
      modelPreference || process.env.MISTRAL_MODEL_PREF || "mistral-tiny";
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.0;
    this.log("Initialized with model:", this.model);
  }

  log(text, ...args) {
    console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
  }

  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }

  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }

  static promptWindowLimit() {
    return 32000;
  }

  promptWindowLimit() {
    return 32000;
  }

  async isValidChatCompletionModel(modelName = "") {
    return true;
  }

  /**
   * Generates appropriate content array for a message + attachments.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
   * @returns {string|object[]}
   */
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) return userPrompt;

    const content = [{ type: "text", text: userPrompt }];
    for (let attachment of attachments) {
      content.push({
        type: "image_url",
        image_url: attachment.contentString,
      });
    }
    return content.flat();
  }

  /**
   * Construct the user prompt for this model.
   * @param {{attachments: import("../../helpers").Attachment[]}} param0
   * @returns
   */
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [], // This is the specific attachment for only this prompt
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      ...formatChatHistory(chatHistory, this.#generateContent),
      {
        role: "user",
        content: this.#generateContent({ userPrompt, attachments }),
      },
    ];
  }

  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `Mistral chat: ${this.model} is not valid for chat completion!`
      );

    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.openai.chat.completions
        .create({
          model: this.model,
          messages,
          temperature,
        })
        .catch((e) => {
          throw new Error(e.message);
        })
    );

    if (
      !result.output.hasOwnProperty("choices") ||
      result.output.choices.length === 0
    )
      return null;

    return {
      textResponse: result.output.choices[0].message.content,
      metrics: {
        prompt_tokens: result.output.usage.prompt_tokens || 0,
        completion_tokens: result.output.usage.completion_tokens || 0,
        total_tokens: result.output.usage.total_tokens || 0,
        outputTps: result.output.usage.completion_tokens / result.duration,
        duration: result.duration,
      },
    };
  }

  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `Mistral chat: ${this.model} is not valid for chat completion!`
      );

    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.openai.chat.completions.create({
        model: this.model,
        stream: true,
        messages,
        temperature,
      }),
      messages,
      false
    );
    return measuredStreamRequest;
  }

  handleStream(response, stream, responseProps) {
    return handleDefaultStreamResponseV2(response, stream, responseProps);
  }

  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }

  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  MistralLLM,
};
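// Usage sketch (illustrative only, not part of the upstream module): with
// MISTRAL_API_KEY set in the environment, a caller could exercise this class
// roughly as below. The relative require path, model name, and prompt values
// are assumptions for the example.
//
//   const { MistralLLM } = require("./index");
//   const llm = new MistralLLM(null, "mistral-tiny");
//   const messages = llm.constructPrompt({
//     systemPrompt: "You are a helpful assistant.",
//     contextTexts: [],
//     chatHistory: [],
//     userPrompt: "Summarize the provided context.",
//     attachments: [],
//   });
//   llm
//     .getChatCompletion(messages, { temperature: llm.defaultTemp })
//     .then((res) => console.log(res?.textResponse));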