const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  handleDefaultStreamResponseV2,
  formatChatHistory,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
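
/**
 * LLM provider for OpenAI chat completion models. Wraps the official
 * `openai` SDK and normalizes the chat, streaming, and embedding interface
 * shared by all LLM implementations.
 */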
class OpenAiLLM {
  constructor(embedder = null, modelPreference = null) {
    if (!process.env.OPEN_AI_KEY) throw new Error("No OpenAI API key was set.");
    const { OpenAI: OpenAIApi } = require("openai");
    this.openai = new OpenAIApi({
      apiKey: process.env.OPEN_AI_KEY,
    });
    this.model = modelPreference || process.env.OPEN_MODEL_PREF || "gpt-4o";
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };
    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;
  }

  /**
   * Check if the model is an o-type reasoning model (e.g. o1, o1-mini, o3-mini).
   * @returns {boolean}
   */
  get isOTypeModel() {
    return this.model.startsWith("o");
  }
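
  /**
   * Wraps retrieved context snippets in [CONTEXT n]...[END CONTEXT n] blocks
   * so they can be appended to the system prompt.
   * @param {string[]} contextTexts
   * @returns {string}
   */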
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }
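
  /**
   * Whether streaming responses are supported for the current model.
   * @returns {boolean}
   */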
  streamingEnabled() {
    // o3-mini is the only o-type model that supports streaming
    if (this.isOTypeModel && this.model !== "o3-mini") return false;
    return "streamGetChatCompletion" in this;
  }
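
  /**
   * Static lookup of the context window for a given model name,
   * falling back to 4,096 tokens when the model is unknown.
   * @param {string} modelName
   * @returns {number}
   */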
  static promptWindowLimit(modelName) {
    return MODEL_MAP.openai[modelName] ?? 4_096;
  }
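
  /**
   * Context window for the currently selected model.
   * @returns {number}
   */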
  promptWindowLimit() {
    return MODEL_MAP.openai[this.model] ?? 4_096;
  }

  // Short-circuit if the name contains 'gpt' or starts with 'o', since we now fetch
  // models from the OpenAI API via the user's API key, so the model must be relevant
  // and real. If somehow it is not, the chat will fail, but that is caught.
  // We don't want to hit the OpenAI API on every chat because it would get spammed
  // and introduce latency for no reason.
  async isValidChatCompletionModel(modelName = "") {
    const isPreset =
      modelName.toLowerCase().includes("gpt") ||
      modelName.toLowerCase().startsWith("o");
    if (isPreset) return true;

    const model = await this.openai.models
      .retrieve(modelName)
      .then((modelObj) => modelObj)
      .catch(() => null);
    return !!model;
  }

  /**
   * Generates appropriate content array for a message + attachments.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
   * @returns {string|object[]}
   */
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) {
      return userPrompt;
    }

    const content = [{ type: "text", text: userPrompt }];
    for (let attachment of attachments) {
      content.push({
        type: "image_url",
        image_url: {
          url: attachment.contentString,
          detail: "high",
        },
      });
    }
    return content.flat();
  }

  /**
   * Construct the user prompt for this model.
   * @param {{attachments: import("../../helpers").Attachment[]}} param0
   * @returns
   */
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [], // This is the specific attachment for only this prompt
  }) {
    // o1 models do not support the "system" role.
    // To work around this, we use the "user" role as a replacement for now.
    // https://community.openai.com/t/o1-models-do-not-support-system-role-in-chat-completion/953880
    const prompt = {
      role: this.isOTypeModel ? "user" : "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      ...formatChatHistory(chatHistory, this.#generateContent),
      {
        role: "user",
        content: this.#generateContent({ userPrompt, attachments }),
      },
    ];
  }
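
  /**
   * Runs a non-streaming chat completion and returns the text response
   * along with token usage and timing metrics.
   * @param {object[]|null} messages - Chat messages in OpenAI format.
   * @param {{temperature?: number}} options
   * @returns {Promise<{textResponse: string, metrics: object}|null>}
   */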
  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `OpenAI chat: ${this.model} is not valid for chat completion!`
      );

    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.openai.chat.completions
        .create({
          model: this.model,
          messages,
          temperature: this.isOTypeModel ? 1 : temperature, // o1 models only accept temperature 1
        })
        .catch((e) => {
          throw new Error(e.message);
        })
    );

    if (
      !result.output.hasOwnProperty("choices") ||
      result.output.choices.length === 0
    )
      return null;

    return {
      textResponse: result.output.choices[0].message.content,
      metrics: {
        prompt_tokens: result.output.usage.prompt_tokens || 0,
        completion_tokens: result.output.usage.completion_tokens || 0,
        total_tokens: result.output.usage.total_tokens || 0,
        outputTps: result.output.usage.completion_tokens / result.duration,
        duration: result.duration,
      },
    };
  }
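
  /**
   * Runs a streaming chat completion wrapped in the performance monitor
   * so token throughput can be measured as chunks arrive.
   * @param {object[]|null} messages - Chat messages in OpenAI format.
   * @param {{temperature?: number}} options
   */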
  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `OpenAI chat: ${this.model} is not valid for chat completion!`
      );

    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.openai.chat.completions.create({
        model: this.model,
        stream: true,
        messages,
        temperature: this.isOTypeModel ? 1 : temperature, // o1 models only accept temperature 1
      }),
      messages
      // runPromptTokenCalculation: true - We manually count the tokens because OpenAI does not provide them in the stream
      // since we are not using the OpenAI API version that supports this `stream_options` param.
    );
    return measuredStreamRequest;
  }
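
  /**
   * Pipes a streamed completion back to the client using the shared
   * default stream handler.
   */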
  handleStream(response, stream, responseProps) {
    return handleDefaultStreamResponseV2(response, stream, responseProps);
  }

  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }

  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }
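
  /**
   * Compresses the constructed message array so it fits within the
   * model's prompt window before sending it to the API.
   */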
  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  OpenAiLLM,
};