const { NativeEmbedder } = require("../../EmbeddingEngines/native");
const {
  LLMPerformanceMonitor,
} = require("../../helpers/chat/LLMPerformanceMonitor");
const {
  handleDefaultStreamResponseV2,
} = require("../../helpers/chat/responses");
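
// Returns the static FireworksAI model catalog bundled in ./models.js,
// keyed by model id, or an empty object if the catalog is missing.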
function fireworksAiModels() {
  const { MODELS } = require("./models.js");
  return MODELS || {};
}

class FireworksAiLLM {
  constructor(embedder = null, modelPreference = null) {
    if (!process.env.FIREWORKS_AI_LLM_API_KEY)
      throw new Error("No FireworksAI API key was set.");

    const { OpenAI: OpenAIApi } = require("openai");
    this.openai = new OpenAIApi({
      baseURL: "https://api.fireworks.ai/inference/v1",
      apiKey: process.env.FIREWORKS_AI_LLM_API_KEY ?? null,
    });
    this.model = modelPreference || process.env.FIREWORKS_AI_LLM_MODEL_PREF;
    // Reserve ~15% of the context window for chat history, ~15% for the
    // system prompt, and ~70% for the user prompt.
    this.limits = {
      history: this.promptWindowLimit() * 0.15,
      system: this.promptWindowLimit() * 0.15,
      user: this.promptWindowLimit() * 0.7,
    };

    this.embedder = !embedder ? new NativeEmbedder() : embedder;
    this.defaultTemp = 0.7;
  }
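
  // Formats retrieved context snippets into numbered [CONTEXT n] blocks
  // appended after the system prompt.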
  #appendContext(contextTexts = []) {
    if (!contextTexts || !contextTexts.length) return "";
    return (
      "\nContext:\n" +
      contextTexts
        .map((text, i) => {
          return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`;
        })
        .join("")
    );
  }

  allModelInformation() {
    return fireworksAiModels();
  }

  streamingEnabled() {
    return "streamGetChatCompletion" in this;
  }
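
  // Static variant so callers can look up a model's context window without
  // constructing the provider (and therefore without an API key).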
  static promptWindowLimit(modelName) {
    const availableModels = fireworksAiModels();
    return availableModels[modelName]?.maxLength || 4096;
  }

  // Ensure the selected model has a known token limit;
  // if it is unknown, assume a 4096-token window.
  promptWindowLimit() {
    const availableModels = this.allModelInformation();
    return availableModels[this.model]?.maxLength || 4096;
  }

  async isValidChatCompletionModel(model = "") {
    const availableModels = this.allModelInformation();
    return availableModels.hasOwnProperty(model);
  }
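
  // Builds the OpenAI-style message array: system prompt (with appended
  // context), then prior chat history, then the current user prompt.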
  constructPrompt({
    systemPrompt = "",
    contextTexts = [],
    chatHistory = [],
    userPrompt = "",
  }) {
    const prompt = {
      role: "system",
      content: `${systemPrompt}${this.#appendContext(contextTexts)}`,
    };
    return [prompt, ...chatHistory, { role: "user", content: userPrompt }];
  }

  async getChatCompletion(messages = null, { temperature = 0.7 } = {}) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `FireworksAI chat: ${this.model} is not valid for chat completion!`
      );

    const result = await LLMPerformanceMonitor.measureAsyncFunction(
      this.openai.chat.completions.create({
        model: this.model,
        messages,
        temperature,
      })
    );

    if (
      !result.output.hasOwnProperty("choices") ||
      result.output.choices.length === 0
    )
      return null;

    // Guard usage with optional chaining in case the API omits it.
    return {
      textResponse: result.output.choices[0].message.content,
      metrics: {
        prompt_tokens: result.output.usage?.prompt_tokens || 0,
        completion_tokens: result.output.usage?.completion_tokens || 0,
        total_tokens: result.output.usage?.total_tokens || 0,
        outputTps:
          (result.output.usage?.completion_tokens || 0) / result.duration,
        duration: result.duration,
      },
    };
  }
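
  // Streaming variant: wraps the completion stream with
  // LLMPerformanceMonitor.measureStream so timing metrics can be attached.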
  async streamGetChatCompletion(messages = null, { temperature = 0.7 } = {}) {
    if (!(await this.isValidChatCompletionModel(this.model)))
      throw new Error(
        `FireworksAI chat: ${this.model} is not valid for chat completion!`
      );

    const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
      this.openai.chat.completions.create({
        model: this.model,
        stream: true,
        messages,
        temperature,
      }),
      messages,
      false
    );
    return measuredStreamRequest;
  }

  handleStream(response, stream, responseProps) {
    return handleDefaultStreamResponseV2(response, stream, responseProps);
  }

  // Simple wrappers around the dynamic embedder to normalize the embedding
  // interface across all LLM implementations.
  async embedTextInput(textInput) {
    return await this.embedder.embedTextInput(textInput);
  }

  async embedChunks(textChunks = []) {
    return await this.embedder.embedChunks(textChunks);
  }
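
  // Compresses the assembled message array via messageArrayCompressor so the
  // prompt fits within this model's context window.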
  async compressMessages(promptArgs = {}, rawHistory = []) {
    const { messageArrayCompressor } = require("../../helpers/chat");
    const messageArray = this.constructPrompt(promptArgs);
    return await messageArrayCompressor(this, messageArray, rawHistory);
  }
}

module.exports = {
  FireworksAiLLM,
  fireworksAiModels,
};
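
// Illustrative usage sketch (not part of this module). Assumes
// FIREWORKS_AI_LLM_API_KEY and FIREWORKS_AI_LLM_MODEL_PREF are set in the
// environment, the model id exists in ./models.js, and the require path
// matches where this file lives (path is an assumption):
//
//   const { FireworksAiLLM } = require("./index.js");
//   const llm = new FireworksAiLLM();
//   const messages = llm.constructPrompt({
//     systemPrompt: "You are a helpful assistant.",
//     contextTexts: [],
//     chatHistory: [],
//     userPrompt: "Hello!",
//   });
//   llm
//     .getChatCompletion(messages, { temperature: llm.defaultTemp })
//     .then((res) => console.log(res?.textResponse));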