const { v4: uuidv4 } = require("uuid");
const { getVectorDbClass, getLLMProvider } = require("../helpers");
const { chatPrompt, sourceIdentifier } = require("./index");
const { EmbedChats } = require("../../models/embedChats");
const {
  convertToPromptHistory,
  writeResponseChunk,
} = require("../helpers/chat/responses");
const { DocumentManager } = require("../DocumentManager");
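
/**
 * Streams a chat response back to an embed widget session.
 * Applies any request overrides the embed config permits (prompt, model,
 * temperature), gathers context from pinned documents and vector search,
 * then streams the LLM completion (or sends it as a single chunk when the
 * connector cannot stream) and records the exchange via EmbedChats.
 * Note: `response` is assumed to be an Express-style HTTP response object,
 * since it is read via `response.locals` and written to by writeResponseChunk.
 */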
async function streamChatWithForEmbed(
  response,
  /** @type {import("@prisma/client").embed_configs & {workspace?: import("@prisma/client").workspaces}} */
  embed,
  /** @type {String} */
  message,
  /** @type {String} */
  sessionId,
  { promptOverride, modelOverride, temperatureOverride, username }
) {
  const chatMode = embed.chat_mode;
  const chatModel = embed.allow_model_override ? modelOverride : null;

  // If there are overrides in request & they are permitted, override the default workspace ref information.
  if (embed.allow_prompt_override)
    embed.workspace.openAiPrompt = promptOverride;
  if (embed.allow_temperature_override)
    embed.workspace.openAiTemp = parseFloat(temperatureOverride);

  const uuid = uuidv4();
  const LLMConnector = getLLMProvider({
    provider: embed?.workspace?.chatProvider,
    model: chatModel ?? embed.workspace?.chatModel,
  });
  const VectorDb = getVectorDbClass();
  const messageLimit = 20;
  const hasVectorizedSpace = await VectorDb.hasNamespace(embed.workspace.slug);
  const embeddingsCount = await VectorDb.namespaceCount(embed.workspace.slug);

  // User is trying to query-mode chat a workspace that has no data in it - so
  // we should exit early as no information can be found under these conditions.
  if ((!hasVectorizedSpace || embeddingsCount === 0) && chatMode === "query") {
    writeResponseChunk(response, {
      id: uuid,
      type: "textResponse",
      textResponse:
        "I do not have enough information to answer that. Try another question.",
      sources: [],
      close: true,
      error: null,
    });
    return;
  }

  let completeText;
  let metrics = {};
  let contextTexts = [];
  let sources = [];
  let pinnedDocIdentifiers = [];
  const { rawHistory, chatHistory } = await recentEmbedChatHistory(
    sessionId,
    embed,
    messageLimit
  );

  // See stream.js comment for more information on this implementation.
  await new DocumentManager({
    workspace: embed.workspace,
    maxTokens: LLMConnector.promptWindowLimit(),
  })
    .pinnedDocs()
    .then((pinnedDocs) => {
      pinnedDocs.forEach((doc) => {
        const { pageContent, ...metadata } = doc;
        pinnedDocIdentifiers.push(sourceIdentifier(doc));
        contextTexts.push(doc.pageContent);
        sources.push({
          text:
            pageContent.slice(0, 1_000) +
            "...continued on in source document...",
          ...metadata,
        });
      });
    });

  const vectorSearchResults =
    embeddingsCount !== 0
      ? await VectorDb.performSimilaritySearch({
          namespace: embed.workspace.slug,
          input: message,
          LLMConnector,
          similarityThreshold: embed.workspace?.similarityThreshold,
          topN: embed.workspace?.topN,
          filterIdentifiers: pinnedDocIdentifiers,
          rerank: embed.workspace?.vectorSearchMode === "rerank",
        })
      : {
          contextTexts: [],
          sources: [],
          message: null,
        };

  // Abort if the similarity search was run and failed.
  if (!!vectorSearchResults.message) {
    writeResponseChunk(response, {
      id: uuid,
      type: "abort",
      textResponse: null,
      sources: [],
      close: true,
      error: "Failed to connect to vector database provider.",
    });
    return;
  }

  const { fillSourceWindow } = require("../helpers/chat");
  const filledSources = fillSourceWindow({
    nDocs: embed.workspace?.topN || 4,
    searchResults: vectorSearchResults.sources,
    history: rawHistory,
    filterIdentifiers: pinnedDocIdentifiers,
  });

  // Why does contextTexts get all the info, but sources only get the current search?
  // This gives the LLM the ability to "comprehend" a contextual response without
  // populating the Citations under a response with documents the user "thinks" are irrelevant
  // due to how we manage backfilling of the context to keep chats with the LLM more correct in responses.
  // If a past citation was used to answer the question, that is visible in the history, so it logically makes sense
  // and does not appear to the user that a new response used information that is otherwise irrelevant for a given prompt.
  // TLDR: reduces GitHub issues for "LLM citing document that has no answer in it" while keeping answers highly accurate.
  contextTexts = [...contextTexts, ...filledSources.contextTexts];
  sources = [...sources, ...vectorSearchResults.sources];

  // If in query mode and no sources are found in the current search or backfilled from history,
  // do not let the LLM hallucinate a response or fall back to general knowledge.
  if (chatMode === "query" && contextTexts.length === 0) {
    writeResponseChunk(response, {
      id: uuid,
      type: "textResponse",
      textResponse:
        embed.workspace?.queryRefusalResponse ??
        "There is no relevant information in this workspace to answer your query.",
      sources: [],
      close: true,
      error: null,
    });
    return;
  }

  // Compress message to ensure prompt passes token limit with room for response
  // and build system messages based on inputs and history.
  const messages = await LLMConnector.compressMessages(
    {
      systemPrompt: chatPrompt(embed.workspace),
      userPrompt: message,
      contextTexts,
      chatHistory,
    },
    rawHistory
  );

  // If streaming is not explicitly enabled for the connector,
  // wait for the full response and send it as a single chunk.
  if (LLMConnector.streamingEnabled() !== true) {
    console.log(
      `\x1b[31m[STREAMING DISABLED]\x1b[0m Streaming is not available for ${LLMConnector.constructor.name}. Will use regular chat method.`
    );
    const { textResponse, metrics: performanceMetrics } =
      await LLMConnector.getChatCompletion(messages, {
        temperature: embed.workspace?.openAiTemp ?? LLMConnector.defaultTemp,
      });
    completeText = textResponse;
    metrics = performanceMetrics;
    writeResponseChunk(response, {
      uuid,
      sources: [],
      type: "textResponseChunk",
      textResponse: completeText,
      close: true,
      error: false,
    });
  } else {
    const stream = await LLMConnector.streamGetChatCompletion(messages, {
      temperature: embed.workspace?.openAiTemp ?? LLMConnector.defaultTemp,
    });
    completeText = await LLMConnector.handleStream(response, stream, {
      uuid,
      sources: [],
    });
    metrics = stream.metrics;
  }

  await EmbedChats.new({
    embedId: embed.id,
    prompt: message,
    response: { text: completeText, type: chatMode, sources, metrics },
    connection_information: response.locals.connection
      ? {
          ...response.locals.connection,
          username: !!username ? String(username) : null,
        }
      : { username: !!username ? String(username) : null },
    sessionId,
  });
  return;
}

/**
 * @param {string} sessionId the session id of the user from the embed widget
 * @param {Object} embed the embed config object
 * @param {Number} messageLimit the number of messages to return
 * @returns {Promise<{rawHistory: import("@prisma/client").embed_chats[], chatHistory: {role: string, content: string, attachments?: Object[]}[]}>}
 */
async function recentEmbedChatHistory(sessionId, embed, messageLimit = 20) {
  const rawHistory = (
    await EmbedChats.forEmbedByUser(embed.id, sessionId, messageLimit, {
      id: "desc",
    })
  ).reverse();
  return { rawHistory, chatHistory: convertToPromptHistory(rawHistory) };
}

module.exports = {
  streamChatWithForEmbed,
};
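
/*
 * Usage sketch (illustrative only): the route path, require path, headers, and
 * the `loadEmbedWithWorkspace` loader below are assumptions for demonstration,
 * not part of this module. It shows the shape of a caller: an HTTP handler that
 * resolves the embed config (with its workspace attached), then hands its
 * response object to streamChatWithForEmbed so chunks can be written to it.
 *
 *   const { streamChatWithForEmbed } = require("./streamEmbed"); // hypothetical path
 *
 *   app.post("/embed/:embedId/stream-chat", async (request, response) => {
 *     const embed = await loadEmbedWithWorkspace(request.params.embedId); // hypothetical loader
 *     const { message, sessionId, ...overrides } = request.body;
 *     response.setHeader("Content-Type", "text/event-stream");
 *     await streamChatWithForEmbed(response, embed, message, sessionId, overrides);
 *     response.end();
 *   });
 */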