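// TogetherAI LLM provider. Talks to the Together AI OpenAI-compatible API
// (https://api.together.xyz/v1) and caches the available chat model list on disk.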
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  handleDefaultStreamResponseV2,
} = require("../../helpers/chat/responses");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const fs = require("fs");
const path = require("path");
const { safeJsonParse } = require("../../http");

const cacheFolder = path.resolve(
  process.env.STORAGE_DIR
    ? path.resolve(process.env.STORAGE_DIR, "models", "togetherAi")
    : path.resolve(__dirname, `../../../storage/models/togetherAi`)
);
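
// Fetch the list of chat-capable models from the Together AI API.
// Results are cached on disk and reused for up to one week; on API failure
// a stale cache (if present) is returned as a fallback.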
async function togetherAiModels(apiKey = null) {
  const cacheModelPath = path.resolve(cacheFolder, "models.json");
  const cacheAtPath = path.resolve(cacheFolder, ".cached_at");

  // If cache exists and is less than 1 week old, use it
  if (fs.existsSync(cacheModelPath) && fs.existsSync(cacheAtPath)) {
    const now = Number(new Date());
    const timestampMs = Number(fs.readFileSync(cacheAtPath));
    if (now - timestampMs <= 6.048e8) {
      // 1 Week in MS
      return safeJsonParse(
        fs.readFileSync(cacheModelPath, { encoding: "utf-8" }),
        []
      );
    }
  }

  try {
    const { OpenAI: OpenAIApi } = require("openai");
    const openai = new OpenAIApi({
      baseURL: "https://api.together.xyz/v1",
      apiKey: apiKey || process.env.TOGETHER_AI_API_KEY || null,
    });
    const response = await openai.models.list();

    // Filter and transform models into the expected format
    // Only include chat models
    const validModels = response.body
      .filter((model) => ["chat"].includes(model.type))
      .map((model) => ({
        id: model.id,
        name: model.display_name || model.id,
        organization: model.organization || "Unknown",
        type: model.type,
        maxLength: model.context_length || 4096,
      }));

    // Cache the results
    if (!fs.existsSync(cacheFolder))
      fs.mkdirSync(cacheFolder, { recursive: true });
    fs.writeFileSync(cacheModelPath, JSON.stringify(validModels), {
      encoding: "utf-8",
    });
    fs.writeFileSync(cacheAtPath, String(Number(new Date())), {
      encoding: "utf-8",
    });

    return validModels;
  } catch (error) {
    console.error("Error fetching Together AI models:", error);
    // If cache exists but is stale, still use it as fallback
    if (fs.existsSync(cacheModelPath)) {
      return safeJsonParse(
        fs.readFileSync(cacheModelPath, { encoding: "utf-8" }),
        []
      );
    }
    return [];
  }
}
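
// OpenAI-compatible chat provider for Together AI. Implements the common
// provider interface (prompt construction, completion, streaming, and
// embedding passthrough) shared by the other LLM implementations.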
class TogetherAiLLM {
  constructor(embedder = null, modelPreference = null) {
    if (!process.env.TOGETHER_AI_API_KEY)
      throw new Error("No TogetherAI API key was set.");

    const { OpenAI: OpenAIApi } = require("openai");
    this.openai = new OpenAIApi({
      baseURL: "https://api.together.xyz/v1",
      apiKey: process.env.TOGETHER_AI_API_KEY ?? null,
    });
    this.model = modelPreference || process.env.TOGETHER_AI_MODEL_PREF;
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = !embedder ? new NativeEmbedder() : embedder;
    this.defaultTemp = 0.7;
  }
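
  // Wrap retrieved context snippets in delimited [CONTEXT n] blocks and
  // append them to the system prompt.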
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }
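
  // Build the user message content: a plain string when there are no
  // attachments, otherwise a multi-part array of text and image_url entries.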
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) {
      return userPrompt;
    }

    const content = [{ type: "text", text: userPrompt }];
    for (let attachment of attachments) {
      content.push({
        type: "image_url",
        image_url: {
          url: attachment.contentString,
        },
      });
    }
    return content.flat();
  }
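
  // Return all known Together AI chat models keyed by model id.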
  async allModelInformation() {
    const models = await togetherAiModels();
    return models.reduce((acc, model) => {
      acc[model.id] = model;
      return acc;
    }, {});
  }
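
  // Streaming is enabled whenever a streamGetChatCompletion implementation exists.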
  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }
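
  // Static lookup of a model's context window for callers without a provider
  // instance. Falls back to 4096 when the model is unknown.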
  static async promptWindowLimit(modelName) {
    const models = await togetherAiModels();
    const model = models.find((m) => m.id === modelName);
    return model?.maxLength || 4096;
  }
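
  // Context window for the currently selected model.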
  async promptWindowLimit() {
    const models = await togetherAiModels();
    const model = models.find((m) => m.id === this.model);
    return model?.maxLength || 4096;
  }
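
  // A model is valid for chat completion only if Together AI lists it with type "chat".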
  async isValidChatCompletionModel(model = "") {
    const models = await togetherAiModels();
    const foundModel = models.find((m) => m.id === model);
    return foundModel && foundModel.type === "chat";
  }
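
  // Assemble the full message array: system prompt (with appended context),
  // prior chat history, then the latest user turn with any attachments.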
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [],
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      ...chatHistory,
      {
        role: "user",
        content: this.#generateContent({ userPrompt, attachments }),
      },
    ];
  }
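
  // Run a single (non-streaming) chat completion, measuring duration and
  // token usage via LLMPerformanceMonitor.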
  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `TogetherAI chat: ${this.model} is not valid for chat completion!`
      );

    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.openai.chat.completions
        .create({
          model: this.model,
          messages,
          temperature,
        })
        .catch((e) => {
          throw new Error(e.message);
        })
    );

    if (
      !result.output.hasOwnProperty("choices") ||
      result.output.choices.length === 0
    )
      return null;

    return {
      textResponse: result.output.choices[0].message.content,
      metrics: {
        prompt_tokens: result.output.usage?.prompt_tokens || 0,
        completion_tokens: result.output.usage?.completion_tokens || 0,
        total_tokens: result.output.usage?.total_tokens || 0,
        outputTps: result.output.usage?.completion_tokens / result.duration,
        duration: result.duration,
      },
    };
  }
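
  // Start a streaming chat completion; the returned stream is wrapped by
  // LLMPerformanceMonitor so metrics can be reported when it finishes.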
  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `TogetherAI chat: ${this.model} is not valid for chat completion!`
      );

    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.openai.chat.completions.create({
        model: this.model,
        stream: true,
        messages,
        temperature,
      }),
      messages,
      false
    );
    return measuredStreamRequest;
  }
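
  // Forward streamed chunks to the shared OpenAI-compatible stream handler.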
  handleStream(response, stream, responseProps) {
    return handleDefaultStreamResponseV2(response, stream, responseProps);
  }

  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }
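
  // Compress the constructed message array against raw history before sending,
  // via the shared messageArrayCompressor helper.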
  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  TogetherAiLLM,
  togetherAiModels,
};
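
// Illustrative usage (not part of the module; the require path and model id
// are placeholders, and TOGETHER_AI_API_KEY must be set in the environment):
//
//   const { TogetherAiLLM } = require("./path/to/this/file");
//   const llm = new TogetherAiLLM(null, "<together-chat-model-id>");
//   const messages = llm.constructPrompt({
//     systemPrompt: "You are a helpful assistant.",
//     userPrompt: "Hello!",
//   });
//   llm.getChatCompletion(messages, { temperature: llm.defaultTemp }).then(console.log);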