const fs = require("fs");
const path = require("path");
const { v4: uuidv4 } = require("uuid");
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const {
  writeResponseChunk,
  clientAbortedHandler,
  formatChatHistory,
} = require("../../helpers/chat/responses");
const { MODEL_MAP } = require("../modelMap");
const { defaultGeminiModels, v1BetaModels } = require("./defaultModels");
const { safeJsonParse } = require("../../http");

const cacheFolder = path.resolve(
  process.env.STORAGE_DIR
    ? path.resolve(process.env.STORAGE_DIR, "models", "gemini")
    : path.resolve(__dirname, `../../../storage/models/gemini`)
);

class GeminiLLM {
  constructor(embedder = null, modelPreference = null) {
    if (!process.env.GEMINI_API_KEY)
      throw new Error("No Gemini API key was set.");

    // Docs: https://ai.google.dev/tutorials/node_quickstart
    const { GoogleGenerativeAI } = require("@google/generative-ai");
    const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY);
    this.model =
      modelPreference || process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
    this.gemini = genAI.getGenerativeModel(
      { model: this.model },
      {
        apiVersion:
          /**
           * Some models are only available in the v1beta API and others only
           * in the v1 API. Generally, v1beta models have `exp` in the name,
           * but not always, so we also check against a static list.
           * @see {v1BetaModels}
           */
          this.model.includes("exp") || v1BetaModels.includes(this.model)
            ? "v1beta"
            : "v1",
      }
    );
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7; // not used for Gemini
    this.safetyThreshold = this.#fetchSafetyThreshold();

    if (!fs.existsSync(cacheFolder))
      fs.mkdirSync(cacheFolder, { recursive: true });
    this.cacheModelPath = path.resolve(cacheFolder, "models.json");
    this.cacheAtPath = path.resolve(cacheFolder, ".cached_at");

    this.#log(
      `Initialized with model: ${this.model} (${this.promptWindowLimit()})`
    );
  }
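
  /**
   * Logs a message to the console with a green [GeminiLLM] prefix.
   * @param {string} text
   * @param {...any} args
   */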
  #log(text, ...args) {
    console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
  }

  // Checks if the .cached_at file holds a timestamp more than one week (in ms)
  // older than the current date. If so, we refetch the models from the API so
  // the cached list stays up to date.
  static cacheIsStale() {
    const MAX_STALE = 6.048e8; // 1 week in ms
    if (!fs.existsSync(path.resolve(cacheFolder, ".cached_at"))) return true;
    const now = Number(new Date());
    const timestampMs = Number(
      fs.readFileSync(path.resolve(cacheFolder, ".cached_at"))
    );
    return now - timestampMs > MAX_STALE;
  }
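
  /**
   * Appends the provided context texts to the system prompt as
   * [CONTEXT n] ... [END CONTEXT n] blocks.
   * @param {string[]} contextTexts
   * @returns {string}
   */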
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }

  // BLOCK_NONE is a restricted setting for some safety attributes:
  // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes#how_to_remove_automated_response_blocking_for_select_safety_attributes
  // so if you are wondering why BLOCK_NONE still failed, the link above explains why.
  #fetchSafetyThreshold() {
    const threshold =
      process.env.GEMINI_SAFETY_SETTING ?? "BLOCK_MEDIUM_AND_ABOVE";
    const safetyThresholds = [
      "BLOCK_NONE",
      "BLOCK_ONLY_HIGH",
      "BLOCK_MEDIUM_AND_ABOVE",
      "BLOCK_LOW_AND_ABOVE",
    ];
    return safetyThresholds.includes(threshold)
      ? threshold
      : "BLOCK_MEDIUM_AND_ABOVE";
  }
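
  /**
   * Builds the safety settings array applied to every chat request,
   * using the configured threshold for each harm category.
   */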
  #safetySettings() {
    return [
      {
        category: "HARM_CATEGORY_HATE_SPEECH",
        threshold: this.safetyThreshold,
      },
      {
        category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        threshold: this.safetyThreshold,
      },
      { category: "HARM_CATEGORY_HARASSMENT", threshold: this.safetyThreshold },
      {
        category: "HARM_CATEGORY_DANGEROUS_CONTENT",
        threshold: this.safetyThreshold,
      },
    ];
  }
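
  /**
   * Whether this provider supports streaming chat completions.
   * @returns {boolean}
   */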
  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }
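
  /**
   * Returns the context window (in tokens) for a given model, preferring the
   * cached models.json and falling back to MODEL_MAP or 30,720 tokens.
   * @param {string} modelName
   * @returns {number}
   */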
  static promptWindowLimit(modelName) {
    try {
      const cacheModelPath = path.resolve(cacheFolder, "models.json");
      if (!fs.existsSync(cacheModelPath))
        return MODEL_MAP.gemini[modelName] ?? 30_720;
      const models = safeJsonParse(fs.readFileSync(cacheModelPath));
      const model = models.find((model) => model.id === modelName);
      if (!model)
        throw new Error(
          "Model not found in cache - falling back to default model."
        );
      return model.contextWindow;
    } catch (e) {
      console.error(`GeminiLLM:promptWindowLimit`, e.message);
      return MODEL_MAP.gemini[modelName] ?? 30_720;
    }
  }
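
  /**
   * Returns the context window (in tokens) for the currently selected model,
   * preferring the cached models.json and falling back to MODEL_MAP or 30,720 tokens.
   * @returns {number}
   */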
  promptWindowLimit() {
    try {
      if (!fs.existsSync(this.cacheModelPath))
        return MODEL_MAP.gemini[this.model] ?? 30_720;
      const models = safeJsonParse(fs.readFileSync(this.cacheModelPath));
      const model = models.find((model) => model.id === this.model);
      if (!model)
        throw new Error(
          "Model not found in cache - falling back to default model."
        );
      return model.contextWindow;
    } catch (e) {
      console.error(`GeminiLLM:promptWindowLimit`, e.message);
      return MODEL_MAP.gemini[this.model] ?? 30_720;
    }
  }

  /**
   * Fetches Gemini models from the Google Generative AI API
   * @param {string} apiKey - The API key to use for the request
   * @param {number} limit - The maximum number of models to fetch
   * @param {string} pageToken - The page token to use for pagination
   * @returns {Promise<[{id: string, name: string, contextWindow: number, experimental: boolean}]>} A promise that resolves to an array of Gemini models
   */
  static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
    if (!apiKey) return [];
    if (fs.existsSync(cacheFolder) && !this.cacheIsStale()) {
      console.log(
        `\x1b[32m[GeminiLLM]\x1b[0m Using cached models API response.`
      );
      return safeJsonParse(
        fs.readFileSync(path.resolve(cacheFolder, "models.json"))
      );
    }

    const url = new URL(
      "https://generativelanguage.googleapis.com/v1beta/models"
    );
    url.searchParams.set("pageSize", limit);
    url.searchParams.set("key", apiKey);
    if (pageToken) url.searchParams.set("pageToken", pageToken);

    let success = false;
    const models = await fetch(url.toString(), {
      method: "GET",
      headers: { "Content-Type": "application/json" },
    })
      .then((res) => res.json())
      .then((data) => {
        if (data.error) throw new Error(data.error.message);
        return data.models ?? [];
      })
      .then((models) => {
        success = true;
        return models
          .filter(
            (model) => !model.displayName.toLowerCase().includes("tuning")
          )
          .filter((model) =>
            model.supportedGenerationMethods.includes("generateContent")
          ) // Only generateContent is supported
          .map((model) => {
            return {
              id: model.name.split("/").pop(),
              name: model.displayName,
              contextWindow: model.inputTokenLimit,
              experimental: model.name.includes("exp"),
            };
          });
      })
      .catch((e) => {
        console.error(`Gemini:getGeminiModels`, e.message);
        success = false;
        return defaultGeminiModels;
      });

    if (success) {
      console.log(
        `\x1b[32m[GeminiLLM]\x1b[0m Writing cached models API response to disk.`
      );
      if (!fs.existsSync(cacheFolder))
        fs.mkdirSync(cacheFolder, { recursive: true });
      fs.writeFileSync(
        path.resolve(cacheFolder, "models.json"),
        JSON.stringify(models)
      );
      fs.writeFileSync(
        path.resolve(cacheFolder, ".cached_at"),
        new Date().getTime().toString()
      );
    }
    return models;
  }

  /**
   * Checks if a model is valid for chat completion (unused)
   * @deprecated
   * @param {string} modelName - The name of the model to check
   * @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
   */
  async isValidChatCompletionModel(modelName = "") {
    // fetchModels is a static method, so call it on the class rather than the instance.
    const models = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);
    return models.some((model) => model.id === modelName);
  }

  /**
   * Generates appropriate content array for a message + attachments.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
   * @returns {string|object[]}
   */
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) {
      return userPrompt;
    }

    const content = [{ text: userPrompt }];
    for (let attachment of attachments) {
      content.push({
        inlineData: {
          data: attachment.contentString.split("base64,")[1],
          mimeType: attachment.mime,
        },
      });
    }
    return content.flat();
  }
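
  /**
   * Builds the full prompt array: system prompt + context, a stub assistant
   * acknowledgement, the formatted chat history, and the final user prompt.
   */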
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [],
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [
      prompt,
      { role: "assistant", content: "Okay." },
      ...formatChatHistory(chatHistory, this.#generateContent),
      {
        role: "USER_PROMPT",
        content: this.#generateContent({ userPrompt, attachments }),
      },
    ];
  }

  // Takes an OpenAI-format message array and plucks only the valid roles from it.
  formatMessages(messages = []) {
    // Gemini roles are either "user" or "model",
    // and all "content" is relabeled to "parts".
    const allMessages = messages
      .map((message) => {
        if (message.role === "system")
          return { role: "user", parts: [{ text: message.content }] };
        if (message.role === "user") {
          // If the content is an array, the context was already formatted, so return it directly.
          if (Array.isArray(message.content))
            return { role: "user", parts: message.content };
          // Otherwise, this was a regular user message with no attachments,
          // so we need to format it for Gemini.
          return { role: "user", parts: [{ text: message.content }] };
        }
        if (message.role === "assistant")
          return { role: "model", parts: [{ text: message.content }] };
        return null;
      })
      .filter((msg) => !!msg);

    // Google rejects a history whose last message is from the user with no
    // model reply. So if the last item is from the user, the exchange was not
    // completed - pop it off the history.
    if (
      allMessages.length > 0 &&
      allMessages[allMessages.length - 1].role === "user"
    )
      allMessages.pop();

    // Validate that every user message is followed by a model message.
    // When compressing messages to retain as much context as possible, the
    // ordering Gemini expects can be disrupted, leaving two user prompts next
    // to each other in the message array. This check works around that edge
    // case by inserting a placeholder model reply between them.
    for (let i = 0; i < allMessages.length; i++) {
      if (
        allMessages[i].role === "user" &&
        i < allMessages.length - 1 &&
        allMessages[i + 1].role !== "model"
      ) {
        allMessages.splice(i + 1, 0, {
          role: "model",
          parts: [{ text: "Okay." }],
        });
      }
    }
    return allMessages;
  }
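
  /**
   * Sends the formatted messages to Gemini and returns the text response
   * along with token and timing metrics.
   */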
  async getChatCompletion(messages = [], _opts = {}) {
    const prompt = messages.find(
      (chat) => chat.role === "USER_PROMPT"
    )?.content;
    const chatThread = this.gemini.startChat({
      history: this.formatMessages(messages),
      safetySettings: this.#safetySettings(),
    });
    const { output: result, duration } =
      await LLMPerformanceMonitor.measureAsyncFunction(
        chatThread.sendMessage(prompt)
      );
    const responseText = result.response.text();
    if (!responseText) throw new Error("Gemini: No response could be parsed.");

    const promptTokens = LLMPerformanceMonitor.countTokens(messages);
    const completionTokens = LLMPerformanceMonitor.countTokens([
      { content: responseText },
    ]);
    return {
      textResponse: responseText,
      metrics: {
        prompt_tokens: promptTokens,
        completion_tokens: completionTokens,
        total_tokens: promptTokens + completionTokens,
        outputTps: (promptTokens + completionTokens) / duration,
        duration,
      },
    };
  }
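
  /**
   * Sends the formatted messages to Gemini and returns a measured response
   * stream for handleStream to consume.
   */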
  async streamGetChatCompletion(messages = [], _opts = {}) {
    const prompt = messages.find(
      (chat) => chat.role === "USER_PROMPT"
    )?.content;
    const chatThread = this.gemini.startChat({
      history: this.formatMessages(messages),
      safetySettings: this.#safetySettings(),
    });
    const responseStream = await LLMPerformanceMonitor.measureStream(
      (await chatThread.sendMessageStream(prompt)).stream,
      messages
    );
    if (!responseStream)
      throw new Error("Could not stream response from Gemini.");
    return responseStream;
  }
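
  /**
   * Compresses the prompt + history so the final message array fits within
   * the model's context window limits.
   */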
  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
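
  /**
   * Streams Gemini response chunks back to the client, handling aborts and
   * safety blocks, and resolves with the full generated text.
   */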
  handleStream(response, stream, responseProps) {
    const { uuid = uuidv4(), sources = [] } = responseProps;

    // Usage is not reported for Gemini streams, so we count completion tokens
    // manually. One chunk != one token in Gemini responses because many tokens
    // are buffered before being sent to the client as a single "chunk".
    return new Promise(async (resolve) => {
      let fullText = "";

      // Establish a listener to early-abort a streaming response in case
      // things go sideways or the user does not like the response. We keep the
      // generated text and continue as if the chat was completed so that
      // previously generated content is preserved.
      const handleAbort = () => {
        stream?.endMeasurement({
          completion_tokens: LLMPerformanceMonitor.countTokens([
            { content: fullText },
          ]),
        });
        clientAbortedHandler(resolve, fullText);
      };
      response.on("close", handleAbort);

      for await (const chunk of stream) {
        let chunkText;
        try {
          // Due to content sensitivity we cannot always call .text() on a chunk:
          // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-attributes#gemini-TASK-samples-nodejs
          // and it is not possible to unblock or disable this safety protocol
          // without being allowlisted by Google.
          chunkText = chunk.text();
        } catch (e) {
          chunkText = e.message;
          writeResponseChunk(response, {
            uuid,
            sources: [],
            type: "abort",
            textResponse: null,
            close: true,
            error: e.message,
          });
          stream?.endMeasurement({ completion_tokens: 0 });
          resolve(e.message);
          return;
        }

        fullText += chunkText;
        writeResponseChunk(response, {
          uuid,
          sources: [],
          type: "textResponseChunk",
          textResponse: chunkText,
          close: false,
          error: false,
        });
      }

      writeResponseChunk(response, {
        uuid,
        sources,
        type: "textResponseChunk",
        textResponse: "",
        close: true,
        error: false,
      });
      response.removeListener("close", handleAbort);
      stream?.endMeasurement({
        completion_tokens: LLMPerformanceMonitor.countTokens([
          { content: fullText },
        ]),
      });
      resolve(fullText);
    });
  }

  // Simple wrappers for the dynamic embedder to normalize the interface across all LLM implementations.
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }
  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }
}

module.exports = {
  GeminiLLM,
};