const { StringOutputParser } = require("@langchain/core/output_parsers");
const {
  writeResponseChunk,
  clientAbortedHandler,
  formatChatHistory,
} = require("../../helpers/chat/responses");
const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const { v4: uuidv4 } = require("uuid"); // Used for the fallback uuid in handleStream.
// Docs: https://js.langchain.com/v0.2/docs/integrations/chat/bedrock_converse
class AWSBedrockLLM {
  /**
   * These models do not support system prompts.
   * This is not explicitly documented, but in practice they ignore the system prompt
   * in their responses and will crash when a system prompt is provided.
   * We can add more models to this list as we discover them or as new models are released.
   * We may want to extend this list or make it user-configurable for custom Bedrock models.
   */
  noSystemPromptModels = [
    "amazon.titan-text-express-v1",
    "amazon.titan-text-lite-v1",
    "cohere.command-text-v14",
    "cohere.command-light-text-v14",
  ];
  constructor(embedder = null, modelPreference = null) {
    if (!process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID)
      throw new Error("No AWS Bedrock LLM profile id was set.");
    if (!process.env.AWS_BEDROCK_LLM_ACCESS_KEY)
      throw new Error("No AWS Bedrock LLM access key was set.");
    if (!process.env.AWS_BEDROCK_LLM_REGION)
      throw new Error("No AWS Bedrock LLM region was set.");
    if (
      process.env.AWS_BEDROCK_LLM_CONNECTION_METHOD === "sessionToken" &&
      !process.env.AWS_BEDROCK_LLM_SESSION_TOKEN
    )
      throw new Error(
        "No AWS Bedrock LLM session token was set while using session token as the authentication method."
      );

    this.model =
      modelPreference || process.env.AWS_BEDROCK_LLM_MODEL_PREFERENCE;
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = embedder ?? new NativeEmbedder();
    this.defaultTemp = 0.7;
    this.#log(
      `Loaded with model: ${this.model}. Will communicate with AWS Bedrock using ${this.authMethod} authentication.`
    );
  }
  /**
   * Get the authentication method for the AWS Bedrock LLM.
   * There are only two valid values for this setting - anything else will default to "iam".
   * @returns {"iam"|"sessionToken"}
   */
  get authMethod() {
    const method = process.env.AWS_BEDROCK_LLM_CONNECTION_METHOD || "iam";
    if (!["iam", "sessionToken"].includes(method)) return "iam";
    return method;
  }
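  /**
   * Builds a LangChain ChatBedrockConverse client for the configured model,
   * region, and credentials. The session token is only attached when the
   * "sessionToken" authentication method is in use.
   * @param {{temperature?: number}} param0
   * @returns {import("@langchain/aws").ChatBedrockConverse}
   */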
  #bedrockClient({ temperature = 0.7 }) {
    const { ChatBedrockConverse } = require("@langchain/aws");
    return new ChatBedrockConverse({
      model: this.model,
      region: process.env.AWS_BEDROCK_LLM_REGION,
      credentials: {
        accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
        secretAccessKey: process.env.AWS_BEDROCK_LLM_ACCESS_KEY,
        ...(this.authMethod === "sessionToken"
          ? { sessionToken: process.env.AWS_BEDROCK_LLM_SESSION_TOKEN }
          : {}),
      },
      temperature,
    });
  }
  // For streaming we use Langchain's wrapper to handle weird chunks
  // or otherwise absorb headaches that can arise from Bedrock models
  #convertToLangchainPrototypes(chats = []) {
    const {
      HumanMessage,
      SystemMessage,
      AIMessage,
    } = require("@langchain/core/messages");
    const langchainChats = [];
    const roleToMessageMap = {
      system: SystemMessage,
      user: HumanMessage,
      assistant: AIMessage,
    };

    for (const chat of chats) {
      if (!roleToMessageMap.hasOwnProperty(chat.role)) continue;

      // When a model does not support system prompts, we need to handle it.
      // We will add a new message that simulates the system prompt via a user message and AI response.
      // This will allow the model to respond without crashing but we can still inject context.
      if (
        this.noSystemPromptModels.includes(this.model) &&
        chat.role === "system"
      ) {
        this.#log(
          `Model does not support system prompts! Simulating system prompt via Human/AI message pairs.`
        );
        langchainChats.push(new HumanMessage({ content: chat.content }));
        langchainChats.push(new AIMessage({ content: "Okay." }));
        continue;
      }

      const MessageClass = roleToMessageMap[chat.role];
      langchainChats.push(new MessageClass({ content: chat.content }));
    }

    return langchainChats;
  }
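  /**
   * Formats retrieved context snippets into a single block that is appended
   * to the system prompt. Returns an empty string when there is no context.
   * @param {string[]} contextTexts
   * @returns {string}
   */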
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }
  #log(text, ...args) {
    console.log(`\x1b[32m[AWSBedrock]\x1b[0m ${text}`, ...args);
  }
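  /**
   * Whether this provider supports streaming chat completions.
   * @returns {boolean}
   */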
  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }
  static promptWindowLimit(_modelName) {
    const limit = process.env.AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT || 8191;
    if (!limit || isNaN(Number(limit)))
      throw new Error("No valid token context limit was set.");
    return Number(limit);
  }

  // Ensure the user set a value for the token limit
  // and if undefined - assume an 8191 token window.
  promptWindowLimit() {
    const limit = process.env.AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT || 8191;
    if (!limit || isNaN(Number(limit)))
      throw new Error("No valid token context limit was set.");
    return Number(limit);
  }
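  // Bedrock model identifiers are user-supplied, so no remote validation is
  // performed here - any configured model is assumed to be valid.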
  async isValidChatCompletionModel(_ = "") {
    return true;
  }
  /**
   * Generates appropriate content array for a message + attachments.
   * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
   * @returns {string|object[]}
   */
  #generateContent({ userPrompt, attachments = [] }) {
    if (!attachments.length) {
      return { content: userPrompt };
    }

    const content = [{ type: "text", text: userPrompt }];
    for (let attachment of attachments) {
      content.push({
        type: "image_url",
        image_url: attachment.contentString,
      });
    }
    return { content: content.flat() };
  }
  /**
   * Construct the user prompt for this model.
   * @param {{systemPrompt: string, contextTexts: string[], chatHistory: object[], userPrompt: string, attachments: import("../../helpers").Attachment[]}} param0
   * @returns {object[]} The OpenAI-style message array for this chat.
   */
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
    attachments = [],
  }) {
    // AWS Mistral models do not support system prompts
    if (this.model.startsWith("mistral"))
      return [
        ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
        {
          role: "user",
          ...this.#generateContent({ userPrompt, attachments }),
        },
      ];

    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };

    return [
      prompt,
      ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
      {
        role: "user",
        ...this.#generateContent({ userPrompt, attachments }),
      },
    ];
  }
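  /**
   * Runs a non-streaming chat completion against Bedrock via LangChain and
   * returns the text response with estimated token usage metrics.
   * @param {object[]|null} messages - OpenAI-style message array from constructPrompt.
   * @param {{temperature?: number}} param1
   * @returns {Promise<{textResponse: string, metrics: object}|null>}
   */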
  async getChatCompletion(messages = null, { temperature = 0.7 }) {
    const model = this.#bedrockClient({ temperature });
    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      model
        .pipe(new StringOutputParser())
        .invoke(this.#convertToLangchainPrototypes(messages))
        .catch((e) => {
          throw new Error(
            `AWSBedrock::getChatCompletion failed to communicate with Bedrock client. ${e.message}`
          );
        })
    );

    if (!result.output || result.output.length === 0) return null;

    // Langchain does not return the usage metrics in the response so we estimate them
    const promptTokens = LLMPerformanceMonitor.countTokens(messages);
    const completionTokens = LLMPerformanceMonitor.countTokens([
      { content: result.output },
    ]);
    return {
      textResponse: result.output,
      metrics: {
        prompt_tokens: promptTokens,
        completion_tokens: completionTokens,
        total_tokens: promptTokens + completionTokens,
        outputTps: completionTokens / result.duration,
        duration: result.duration,
      },
    };
  }
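  /**
   * Opens a streaming chat completion against Bedrock via LangChain and wraps
   * it with the performance monitor so token throughput can be measured.
   * @param {object[]|null} messages - OpenAI-style message array from constructPrompt.
   * @param {{temperature?: number}} param1
   * @returns {Promise<import('../../helpers/chat/LLMPerformanceMonitor').MonitoredStream>}
   */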
  async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
    const model = this.#bedrockClient({ temperature });
    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      model
        .pipe(new StringOutputParser())
        .stream(this.#convertToLangchainPrototypes(messages)),
      messages
    );
    return measuredStreamRequest;
  }
  /**
   * Handles the stream response from the AWS Bedrock API.
   * Bedrock does not support usage metrics in the stream response so we need to estimate them.
   * @param {Object} response - the response object
   * @param {import('../../helpers/chat/LLMPerformanceMonitor').MonitoredStream} stream - the stream response from the AWS Bedrock API w/tracking
   * @param {Object} responseProps - the response properties
   * @returns {Promise<string>}
   */
  handleStream(response, stream, responseProps) {
    const { uuid = uuidv4(), sources = [] } = responseProps;

    return new Promise(async (resolve) => {
      let fullText = "";
      let usage = {
        completion_tokens: 0,
      };

      // Establish listener to early-abort a streaming response
      // in case things go sideways or the user does not like the response.
      // We preserve the generated text but continue as if chat was completed
      // to preserve previously generated content.
      const handleAbort = () => {
        stream?.endMeasurement(usage);
        clientAbortedHandler(resolve, fullText);
      };
      response.on("close", handleAbort);

      try {
        for await (const chunk of stream) {
          if (chunk === undefined)
            throw new Error(
              "Stream returned undefined chunk. Aborting reply - check model provider logs."
            );

          const content = chunk.hasOwnProperty("content")
            ? chunk.content
            : chunk;
          fullText += content;
          if (!!content) usage.completion_tokens++; // Don't count empty chunks

          writeResponseChunk(response, {
            uuid,
            sources: [],
            type: "textResponseChunk",
            textResponse: content,
            close: false,
            error: false,
          });
        }

        writeResponseChunk(response, {
          uuid,
          sources,
          type: "textResponseChunk",
          textResponse: "",
          close: true,
          error: false,
        });
        response.removeListener("close", handleAbort);
        stream?.endMeasurement(usage);
        resolve(fullText);
      } catch (error) {
        writeResponseChunk(response, {
          uuid,
          sources: [],
          type: "textResponseChunk",
          textResponse: "",
          close: true,
          error: `AWSBedrock:streaming - could not stream chat. ${
            error?.cause ?? error.message
          }`,
        });
        response.removeListener("close", handleAbort);
        stream?.endMeasurement(usage);
      }
    });
  }
  // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }

  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }
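  // Compress the constructed message array so it fits within the model's
  // context window, taking the prior raw chat history into account.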
  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  AWSBedrockLLM,
};
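
// Minimal usage sketch (illustrative only; assumes the AWS_BEDROCK_LLM_* env
// vars above are set, runs inside an async function, and uses a hypothetical
// file path and example Bedrock model id):
//
//   const { AWSBedrockLLM } = require("./index");
//   const llm = new AWSBedrockLLM(null, "anthropic.claude-3-sonnet-20240229-v1:0");
//   const messages = llm.constructPrompt({
//     systemPrompt: "You are a helpful assistant.",
//     contextTexts: [],
//     chatHistory: [],
//     userPrompt: "Hello!",
//     attachments: [],
//   });
//   const { textResponse, metrics } = await llm.getChatCompletion(messages, {
//     temperature: llm.defaultTemp,
//   });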