const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  clientAbortedHandler,
  writeResponseChunk,
  formatChatHistory,
} = require("../../helpers/chat/responses");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const { v4: uuidv4 } = require("uuid");
class KoboldCPPLLM {
  constructor(embedder = null, modelPreference = null) {
    const { OpenAI: OpenAIApi } = require("openai");
    if (!process.env.KOBOLD_CPP_BASE_PATH)
      throw new Error(
        "KoboldCPP must have a valid base path to use for the api."
      );
    this.basePath = process.env.KOBOLD_CPP_BASE_PATH;
    this.openai = new OpenAIApi({
      baseURL: this.basePath,
      apiKey: null,
    });
    this.model = modelPreference ?? process.env.KOBOLD_CPP_MODEL_PREF ?? null;
    if (!this.model) throw new Error("KoboldCPP must have a valid model set.");
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;
    this.log(`Inference API: ${this.basePath} Model: ${this.model}`);
  }

  log(text, ...args) {
    console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
  }
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }

  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }

  static promptWindowLimit(_modelName) {
    const limit = process.env.KOBOLD_CPP_MODEL_TOKEN_LIMIT || 4096;
    if (!limit || isNaN(Number(limit)))
      throw new Error("No token context limit was set.");
    return Number(limit);
  }

  // Ensure the user set a value for the token limit
  // and if undefined - assume a 4096 token window.
  promptWindowLimit() {
    const limit = process.env.KOBOLD_CPP_MODEL_TOKEN_LIMIT || 4096;
    if (!limit || isNaN(Number(limit)))
      throw new Error("No token context limit was set.");
    return Number(limit);
  }

  // Short circuit since we have no idea if the model is valid or not
  // in pre-flight for generic endpoints
  isValidChatCompletionModel(_modelName = "") {
    return true;
  }
  /**
   * Generates the appropriate content array for a message + attachments.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}} param0
   * @returns {string|object[]}
   */
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) {
      return userPrompt;
    }

    const content = [{ type: "text", text: userPrompt }];
    for (let attachment of attachments) {
      content.push({
        type: "image_url",
        image_url: {
          url: attachment.contentString,
        },
      });
    }
    return content.flat();
  }

  /**
   * Construct the user prompt for this model.
   * @param {{attachments: import("../../helpers").Attachment[]}} param0
   * @returns {object[]}
   */
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [],
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      ...formatChatHistory(chatHistory, this.#generateContent),
      {
        role: "user",
        content: this.#generateContent({ userPrompt, attachments }),
      },
    ];
  }
  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.openai.chat.completions
        .create({
          model: this.model,
          messages,
          temperature,
        })
        .catch((e) => {
          throw new Error(e.message);
        })
    );

    if (
      !result.output.hasOwnProperty("choices") ||
      result.output.choices.length === 0
    )
      return null;

    const promptTokens = LLMPerformanceMonitor.countTokens(messages);
    const completionTokens = LLMPerformanceMonitor.countTokens([
      { content: result.output.choices[0].message.content },
    ]);

    return {
      textResponse: result.output.choices[0].message.content,
      metrics: {
        prompt_tokens: promptTokens,
        completion_tokens: completionTokens,
        total_tokens: promptTokens + completionTokens,
        outputTps: completionTokens / result.duration,
        duration: result.duration,
      },
    };
  }
  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.openai.chat.completions.create({
        model: this.model,
        stream: true,
        messages,
        temperature,
      }),
      messages
    );
    return measuredStreamRequest;
  }
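
  // Relays the streamed completion to the client as textResponseChunk events,
  // accumulating the full text and finalizing token-usage metrics when the
  // stream finishes ("length"/"stop") or the client closes the connection.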
  handleStream(response, stream, responseProps) {
    const { uuid = uuidv4(), sources = [] } = responseProps;

    return new Promise(async (resolve) => {
      let fullText = "";
      let usage = {
        prompt_tokens: LLMPerformanceMonitor.countTokens(stream.messages || []),
        completion_tokens: 0,
      };

      const handleAbort = () => {
        usage.completion_tokens = LLMPerformanceMonitor.countTokens([
          { content: fullText },
        ]);
        stream?.endMeasurement(usage);
        clientAbortedHandler(resolve, fullText);
      };
      response.on("close", handleAbort);

      for await (const chunk of stream) {
        const message = chunk?.choices?.[0];
        const token = message?.delta?.content;

        if (token) {
          fullText += token;
          writeResponseChunk(response, {
            uuid,
            sources: [],
            type: "textResponseChunk",
            textResponse: token,
            close: false,
            error: false,
          });
        }

        // KoboldCPP finishes with "length" or "stop"
        if (
          message?.finish_reason === "length" ||
          message?.finish_reason === "stop"
        ) {
          writeResponseChunk(response, {
            uuid,
            sources,
            type: "textResponseChunk",
            textResponse: "",
            close: true,
            error: false,
          });
          response.removeListener("close", handleAbort);
          usage.completion_tokens = LLMPerformanceMonitor.countTokens([
            { content: fullText },
          ]);
          stream?.endMeasurement(usage);
          resolve(fullText);
        }
      }
    });
  }
  // Simple wrapper for the dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }

  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  KoboldCPPLLM,
};
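
/*
  Usage sketch (illustrative only, not part of this module): a minimal example
  of driving the provider for a non-streaming completion. It assumes
  KOBOLD_CPP_BASE_PATH points at a running KoboldCpp OpenAI-compatible
  endpoint (e.g. http://localhost:5001/v1) and KOBOLD_CPP_MODEL_PREF names a
  loaded model; the require path and prompt values below are placeholders, and
  the Express `response` plumbing that handleStream() expects is omitted.

  const { KoboldCPPLLM } = require("./index");

  async function example() {
    const llm = new KoboldCPPLLM();
    const messages = llm.constructPrompt({
      systemPrompt: "You are a helpful assistant.",
      contextTexts: [],
      chatHistory: [],
      userPrompt: "Hello!",
      attachments: [],
    });
    const result = await llm.getChatCompletion(messages, {
      temperature: llm.defaultTemp,
    });
    if (result) console.log(result.textResponse, result.metrics);
  }
*/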