const { v4 } = require("uuid");
const { writeResponseChunk } = require("../../helpers/chat/responses");
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const { MODEL_MAP } = require("../modelMap");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
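
/**
 * Chat completion provider backed by Cohere's command-family models.
 * Wraps the cohere-ai CohereClient and exposes the same interface as the
 * other LLM implementations (prompt construction, completions, streaming,
 * and embedding passthrough).
 */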
class CohereLLM {
  constructor(embedder = null) {
    const { CohereClient } = require("cohere-ai");
    if (!process.env.COHERE_API_KEY)
      throw new Error("No Cohere API key was set.");

    const cohere = new CohereClient({
      token: process.env.COHERE_API_KEY,
    });

    this.cohere = cohere;
    this.model = process.env.COHERE_MODEL_PREF;

    // Rough split of the context window: 15% history, 15% system prompt, 70% user prompt.
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };
    this.embedder = embedder ?? new NativeEmbedder();
  }
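
  /**
   * Wraps retrieved context snippets in [CONTEXT n]...[END CONTEXT n] markers
   * under a "Context:" header so they can be appended to the system prompt.
   * @param {string[]} contextTexts
   * @returns {string}
   */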
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }
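
  /**
   * Converts generic {role, content} history entries into Cohere's chatHistory
   * format, mapping system/user/assistant roles to SYSTEM/USER/CHATBOT.
   * @param {Array<{role: string, content: string}>} chatHistory
   * @returns {Array<{role: string, message: string}>}
   */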
  #convertChatHistoryCohere(chatHistory = []) {
    let cohereHistory = [];
    chatHistory.forEach((message) => {
      switch (message.role) {
        case "system":
          cohereHistory.push({ role: "SYSTEM", message: message.content });
          break;
        case "user":
          cohereHistory.push({ role: "USER", message: message.content });
          break;
        case "assistant":
          cohereHistory.push({ role: "CHATBOT", message: message.content });
          break;
      }
    });

    return cohereHistory;
  }

  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }
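
  /**
   * Context window (in tokens) for a given Cohere model; falls back to 4,096
   * when the model is not listed in MODEL_MAP.
   * @param {string} modelName
   * @returns {number}
   */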
  static promptWindowLimit(modelName) {
    return MODEL_MAP.cohere[modelName] ?? 4_096;
  }

  promptWindowLimit() {
    return MODEL_MAP.cohere[this.model] ?? 4_096;
  }
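
  /**
   * Checks whether the given model name is one of the supported Cohere chat models.
   * @param {string} model
   * @returns {Promise<boolean>}
   */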
  async isValidChatCompletionModel(model = "") {
    const validModels = [
      "command-r",
      "command-r-plus",
      "command",
      "command-light",
      "command-nightly",
      "command-light-nightly",
    ];
    return validModels.includes(model);
  }
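
  /**
   * Builds the normalized message array: the system prompt (with appended context),
   * followed by prior chat history, ending with the new user prompt.
   * @returns {Array<{role: string, content: string}>}
   */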
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
  }
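
  /**
   * Sends a non-streaming chat request to Cohere. The last message becomes the
   * prompt and everything before it is passed as chatHistory; returns the text
   * plus token/latency metrics, or null when no text was produced.
   * @returns {Promise<{textResponse: string, metrics: Object}|null>}
   */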
  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `Cohere chat: ${this.model} is not valid for chat completion!`
      );

    const message = messages[messages.length - 1].content; // Get the last message
    const cohereHistory = this.#convertChatHistoryCohere(messages.slice(0, -1)); // Remove the last message and convert to Cohere

    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.cohere.chat({
        model: this.model,
        message: message,
        chatHistory: cohereHistory,
        temperature,
      })
    );

    if (
      !result.output.hasOwnProperty("text") ||
      result.output.text.length === 0
    )
      return null;

    const promptTokens = result.output.meta?.tokens?.inputTokens || 0;
    const completionTokens = result.output.meta?.tokens?.outputTokens || 0;
    return {
      textResponse: result.output.text,
      metrics: {
        prompt_tokens: promptTokens,
        completion_tokens: completionTokens,
        total_tokens: promptTokens + completionTokens,
        outputTps: completionTokens / result.duration,
        duration: result.duration,
      },
    };
  }
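
  /**
   * Opens a streaming chat request to Cohere, wrapped by LLMPerformanceMonitor
   * so token usage can be recorded when the stream ends.
   * @returns {Promise<import('../../helpers/chat/LLMPerformanceMonitor').MonitoredStream>}
   */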
  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `Cohere chat: ${this.model} is not valid for chat completion!`
      );

    const message = messages[messages.length - 1].content; // Get the last message
    const cohereHistory = this.#convertChatHistoryCohere(messages.slice(0, -1)); // Remove the last message and convert to Cohere

    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.cohere.chatStream({
        model: this.model,
        message: message,
        chatHistory: cohereHistory,
        temperature,
      }),
      messages,
      false
    );

    return measuredStreamRequest;
  }

  /**
   * Handles the stream response from the Cohere API.
   * @param {Object} response - the response object
   * @param {import('../../helpers/chat/LLMPerformanceMonitor').MonitoredStream} stream - the stream response from the Cohere API w/tracking
   * @param {Object} responseProps - the response properties
   * @returns {Promise<string>}
   */
  async handleStream(response, stream, responseProps) {
    return new Promise(async (resolve) => {
      const { uuid = v4(), sources = [] } = responseProps;
      let fullText = "";
      let usage = {
        prompt_tokens: 0,
        completion_tokens: 0,
      };

      // If the client disconnects, flush whatever text we have, stop measuring, and resolve.
      const handleAbort = () => {
        writeResponseChunk(response, {
          uuid,
          sources,
          type: "abort",
          textResponse: fullText,
          close: true,
          error: false,
        });
        response.removeListener("close", handleAbort);
        stream.endMeasurement(usage);
        resolve(fullText);
      };
      response.on("close", handleAbort);

      try {
        for await (const chat of stream) {
          // The final stream event carries the token usage metadata.
          if (chat.eventType === "stream-end") {
            const usageMetrics = chat?.response?.meta?.tokens || {};
            usage.prompt_tokens = usageMetrics.inputTokens || 0;
            usage.completion_tokens = usageMetrics.outputTokens || 0;
          }

          // Forward each generated text chunk to the client as it arrives.
          if (chat.eventType === "text-generation") {
            const text = chat.text;
            fullText += text;
            writeResponseChunk(response, {
              uuid,
              sources,
              type: "textResponseChunk",
              textResponse: text,
              close: false,
              error: false,
            });
          }
        }

        writeResponseChunk(response, {
          uuid,
          sources,
          type: "textResponseChunk",
          textResponse: "",
          close: true,
          error: false,
        });
        response.removeListener("close", handleAbort);
        stream.endMeasurement(usage);
        resolve(fullText);
      } catch (error) {
        writeResponseChunk(response, {
          uuid,
          sources,
          type: "abort",
          textResponse: null,
          close: true,
          error: error.message,
        });
        response.removeListener("close", handleAbort);
        stream.endMeasurement(usage);
        resolve(fullText);
      }
    });
  }

  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }
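
  /**
   * Runs the constructed prompt plus raw history through the shared
   * messageArrayCompressor helper so the request fits this model's token limits.
   */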
  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  CohereLLM,
};