const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const {
  handleDefaultStreamResponseV2,
  formatChatHistory,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");

class XAiLLM {
  constructor(embedder = null, modelPreference = null) {
    if (!process.env.XAI_LLM_API_KEY)
      throw new Error("No xAI API key was set.");

    const { OpenAI: OpenAIApi } = require("openai");
    this.openai = new OpenAIApi({
      baseURL: "https://api.x.ai/v1",
      apiKey: process.env.XAI_LLM_API_KEY,
    });
    this.model =
      modelPreference || process.env.XAI_LLM_MODEL_PREF || "grok-beta";
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;
    this.log("Initialized with model:", this.model);
  }
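
  // Console logger prefixed with the class name for easier log tracing.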
  log(text, ...args) {
    console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
  }
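
  // Wraps each retrieved context snippet in [CONTEXT n]...[END CONTEXT n]
  // markers and appends the block to the system prompt content.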
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }
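
  // Streaming is supported whenever streamGetChatCompletion is defined on this class.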
  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }
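
  // Context window for a given xAI model, falling back to 131,072 tokens when unknown.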
  static promptWindowLimit(modelName) {
    return MODEL_MAP.xai[modelName] ?? 131_072;
  }
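
  // Instance variant of the lookup above for the currently selected model.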
  promptWindowLimit() {
    return MODEL_MAP.xai[this.model] ?? 131_072;
  }
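
  // No model list is validated against; any model name is accepted as-is.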
  isValidChatCompletionModel(_modelName = "") {
    return true;
  }

  /**
   * Generates appropriate content array for a message + attachments.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
   * @returns {string|object[]}
   */
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) {
      return userPrompt;
    }

    const content = [{ type: "text", text: userPrompt }];
    for (let attachment of attachments) {
      content.push({
        type: "image_url",
        image_url: {
          url: attachment.contentString,
          detail: "high",
        },
      });
    }
    return content.flat();
  }

  /**
   * Construct the system + user prompt message array for this model.
   * @param {{
   *  systemPrompt: string,
   *  contextTexts: string[],
   *  chatHistory: object[],
   *  userPrompt: string,
   *  attachments: import("../../helpers").Attachment[],
   * }} param0
   * @returns {object[]} OpenAI-format message array.
   */
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [], // This is the specific attachment for only this prompt
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      ...formatChatHistory(chatHistory, this.#generateContent),
      {
        role: "user",
        content: this.#generateContent({ userPrompt, attachments }),
      },
    ];
  }
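
  /**
   * Sends a non-streaming chat completion request and reports token usage metrics.
   * @param {object[]|null} messages - OpenAI-format message array.
   * @param {{temperature?: number}} param1
   * @returns {Promise<{textResponse: string, metrics: object}|null>}
   */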
  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!this.isValidChatCompletionModel(this.model))
      throw new Error(
        `xAI chat: ${this.model} is not valid for chat completion!`
      );

    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.openai.chat.completions
        .create({
          model: this.model,
          messages,
          temperature,
        })
        .catch((e) => {
          throw new Error(e.message);
        })
    );

    if (
      !result.output.hasOwnProperty("choices") ||
      result.output.choices.length === 0
    )
      return null;

    return {
      textResponse: result.output.choices[0].message.content,
      metrics: {
        prompt_tokens: result.output.usage.prompt_tokens || 0,
        completion_tokens: result.output.usage.completion_tokens || 0,
        total_tokens: result.output.usage.total_tokens || 0,
        outputTps: result.output.usage.completion_tokens / result.duration,
        duration: result.duration,
      },
    };
  }
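
  // Streaming variant: returns a measured stream request from LLMPerformanceMonitor.measureStream.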
  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!this.isValidChatCompletionModel(this.model))
      throw new Error(
        `xAI chat: ${this.model} is not valid for chat completion!`
      );

    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.openai.chat.completions.create({
        model: this.model,
        stream: true,
        messages,
        temperature,
      }),
      messages,
      false
    );
    return measuredStreamRequest;
  }
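
  // Delegates stream consumption to the shared OpenAI-compatible stream handler.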
  handleStream(response, stream, responseProps) {
    return handleDefaultStreamResponseV2(response, stream, responseProps);
  }

  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }
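
  // Compresses the constructed prompt + raw history so the message array fits the model's context window.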
  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  XAiLLM,
};