const { v4: uuidv4 } = require("uuid");
const {
  writeResponseChunk,
  clientAbortedHandler,
  formatChatHistory,
} = require("../../helpers/chat/responses");
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const { Ollama } = require("ollama");

// Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md
class OllamaAILLM {
  constructor(embedder = null, modelPreference = null) {
    if (!process.env.OLLAMA_BASE_PATH)
      throw new Error("No Ollama Base Path was set.");

    this.basePath = process.env.OLLAMA_BASE_PATH;
    this.model = modelPreference || process.env.OLLAMA_MODEL_PREF;
    this.performanceMode = process.env.OLLAMA_PERFORMANCE_MODE || "base";
    this.keepAlive = process.env.OLLAMA_KEEP_ALIVE_TIMEOUT
      ? Number(process.env.OLLAMA_KEEP_ALIVE_TIMEOUT)
      : 300; // Default 5-minute timeout for Ollama model loading.
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.client = new Ollama({ host: this.basePath });
    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;
    this.#log(
      `OllamaAILLM initialized with\nmodel: ${this.model}\nperf: ${this.performanceMode}\nn_ctx: ${this.promptWindowLimit()}`
    );
  }

  #log(text, ...args) {
    console.log(`\x1b[32m[Ollama]\x1b[0m ${text}`, ...args);
  }

  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }

  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }

  static promptWindowLimit(_modelName) {
    const limit = process.env.OLLAMA_MODEL_TOKEN_LIMIT || 4096;
    if (!limit || isNaN(Number(limit)))
      throw new Error("No Ollama token context limit was set.");
    return Number(limit);
  }

  // Ensure the user set a value for the token limit
  // and if undefined - assume 4096 window.
  promptWindowLimit() {
    const limit = process.env.OLLAMA_MODEL_TOKEN_LIMIT || 4096;
    if (!limit || isNaN(Number(limit)))
      throw new Error("No Ollama token context limit was set.");
    return Number(limit);
  }

  async isValidChatCompletionModel(_ = "") {
    return true;
  }

  /**
   * Generates appropriate content array for a message + attachments.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
   * @returns {{content: string, images: string[]}}
   */
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) return { content: userPrompt };

    const images = attachments.map(
      (attachment) => attachment.contentString.split("base64,").slice(-1)[0]
    );
    return { content: userPrompt, images };
  }

  /**
   * Handles errors from the Ollama API to make them more user friendly.
   * @param {Error} e
   */
  #errorHandler(e) {
    switch (e.message) {
      case "fetch failed":
        throw new Error(
          "Your Ollama instance could not be reached or is not responding. Please make sure it is running the API server and your connection information is correct in AnythingLLM."
        );
      default:
        return e;
    }
  }

  /**
   * Construct the user prompt for this model.
   * @param {{attachments: import("../../helpers").Attachment[]}} param0
   * @returns
   */
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [],
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
      {
        role: "user",
        ...this.#generateContent({ userPrompt, attachments }),
      },
    ];
  }
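
  // Illustrative only: the message array returned by constructPrompt() above has
  // roughly this shape ("images" is present only when attachments exist; values
  // here are hypothetical examples, not real request data):
  // [
  //   { role: "system", content: "<systemPrompt>\nContext:\n[CONTEXT 0]:\n...\n[END CONTEXT 0]\n\n" },
  //   { role: "user", content: "Describe this image", images: ["<base64 payload>"] },
  //   { role: "assistant", content: "It appears to be a cat." },
  //   { role: "user", content: "<latest userPrompt>" },
  // ]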

  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.client
        .chat({
          model: this.model,
          stream: false,
          messages,
          keep_alive: this.keepAlive,
          options: {
            temperature,
            use_mlock: true,
            // There are currently only two performance settings, so if it's not "base" - it's max context.
            ...(this.performanceMode === "base"
              ? {}
              : { num_ctx: this.promptWindowLimit() }),
          },
        })
        .then((res) => {
          return {
            content: res.message.content,
            usage: {
              prompt_tokens: res.prompt_eval_count,
              completion_tokens: res.eval_count,
              total_tokens: res.prompt_eval_count + res.eval_count,
            },
          };
        })
        .catch((e) => {
          throw new Error(
            `Ollama::getChatCompletion failed to communicate with Ollama. ${this.#errorHandler(e).message}`
          );
        })
    );

    if (!result.output.content || !result.output.content.length)
      throw new Error(`Ollama::getChatCompletion text response was empty.`);

    return {
      textResponse: result.output.content,
      metrics: {
        prompt_tokens: result.output.usage.prompt_tokens,
        completion_tokens: result.output.usage.completion_tokens,
        total_tokens: result.output.usage.total_tokens,
        outputTps: result.output.usage.completion_tokens / result.duration,
        duration: result.duration,
      },
    };
  }

  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.client.chat({
        model: this.model,
        stream: true,
        messages,
        keep_alive: this.keepAlive,
        options: {
          temperature,
          use_mlock: true,
          // There are currently only two performance settings, so if it's not "base" - it's max context.
          ...(this.performanceMode === "base"
            ? {}
            : { num_ctx: this.promptWindowLimit() }),
        },
      }),
      messages,
      false
    ).catch((e) => {
      throw this.#errorHandler(e);
    });
    return measuredStreamRequest;
  }

  /**
   * Handles streaming responses from Ollama.
   * @param {import("express").Response} response
   * @param {import("../../helpers/chat/LLMPerformanceMonitor").MonitoredStream} stream
   * @param {{uuid?: string, sources?: object[]}} responseProps
   * @returns {Promise<string>}
   */
  handleStream(response, stream, responseProps) {
    const { uuid = uuidv4(), sources = [] } = responseProps;

    return new Promise(async (resolve) => {
      let fullText = "";
      let usage = {
        prompt_tokens: 0,
        completion_tokens: 0,
      };

      // Establish listener to early-abort a streaming response
      // in case things go sideways or the user does not like the response.
      // We preserve the generated text but continue as if chat was completed
      // to preserve previously generated content.
      const handleAbort = () => {
        stream?.endMeasurement(usage);
        clientAbortedHandler(resolve, fullText);
      };
      response.on("close", handleAbort);

      try {
        for await (const chunk of stream) {
          if (chunk === undefined)
            throw new Error(
              "Stream returned undefined chunk. Aborting reply - check model provider logs."
            );

          if (chunk.done) {
            usage.prompt_tokens = chunk.prompt_eval_count;
            usage.completion_tokens = chunk.eval_count;
            writeResponseChunk(response, {
              uuid,
              sources,
              type: "textResponseChunk",
              textResponse: "",
              close: true,
              error: false,
            });
            response.removeListener("close", handleAbort);
            stream?.endMeasurement(usage);
            resolve(fullText);
            break;
          }

          if (chunk.hasOwnProperty("message")) {
            const content = chunk.message.content;
            fullText += content;
            writeResponseChunk(response, {
              uuid,
              sources,
              type: "textResponseChunk",
              textResponse: content,
              close: false,
              error: false,
            });
          }
        }
      } catch (error) {
        writeResponseChunk(response, {
          uuid,
          sources: [],
          type: "textResponseChunk",
          textResponse: "",
          close: true,
          error: `Ollama:streaming - could not stream chat. ${
            error?.cause ?? error.message
          }`,
        });
        response.removeListener("close", handleAbort);
        stream?.endMeasurement(usage);
        resolve(fullText);
      }
    });
  }

  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }

  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  OllamaAILLM,
};
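
/*
 * Minimal usage sketch (illustrative only, not executed by this module). It assumes
 * a local Ollama server and the "llama3" model name; adjust the env values and the
 * require path for your setup.
 *
 *   process.env.OLLAMA_BASE_PATH = "http://127.0.0.1:11434";
 *   const { OllamaAILLM } = require("./index"); // assuming this file is the provider's index.js
 *
 *   const llm = new OllamaAILLM(null, "llama3");
 *   const messages = llm.constructPrompt({
 *     systemPrompt: "You are a helpful assistant.",
 *     contextTexts: [],
 *     chatHistory: [],
 *     userPrompt: "Why is the sky blue?",
 *   });
 *   llm
 *     .getChatCompletion(messages, { temperature: llm.defaultTemp })
 *     .then(({ textResponse, metrics }) => console.log(textResponse, metrics));
 */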