const { toChunks, maximumChunkLength } = require("../../helpers");

class LiteLLMEmbedder {
  constructor() {
    const { OpenAI: OpenAIApi } = require("openai");
    if (!process.env.LITE_LLM_BASE_PATH)
      throw new Error(
        "LiteLLM must have a valid base path to use for the API."
      );
    this.basePath = process.env.LITE_LLM_BASE_PATH;
    this.openai = new OpenAIApi({
      baseURL: this.basePath,
      apiKey: process.env.LITE_LLM_API_KEY ?? null,
    });
    this.model = process.env.EMBEDDING_MODEL_PREF || "text-embedding-ada-002";

    // Limit of how many strings we can process in a single pass to stay
    // within resource or network limits.
    this.maxConcurrentChunks = 500;
    this.embeddingMaxChunkLength = maximumChunkLength();
  }
  // Embed a single string (or pass an array through) and return the first
  // resulting embedding vector.
  async embedTextInput(textInput) {
    const result = await this.embedChunks(
      Array.isArray(textInput) ? textInput : [textInput]
    );
    return result?.[0] || [];
  }
  async embedChunks(textChunks = []) {
    // Because there is a hard POST limit on how many chunks can be sent at
    // once to LiteLLM (~8MB), we split the text chunks into batches and
    // execute the batches concurrently. See maxConcurrentChunks in the
    // constructor for the batch size.
    const embeddingRequests = [];
    for (const chunk of toChunks(textChunks, this.maxConcurrentChunks)) {
      embeddingRequests.push(
        new Promise((resolve) => {
          this.openai.embeddings
            .create({
              model: this.model,
              input: chunk,
            })
            .then((result) => {
              resolve({ data: result?.data, error: null });
            })
            .catch((e) => {
              e.type =
                e?.response?.data?.error?.code ||
                e?.response?.status ||
                "failed_to_embed";
              e.message = e?.response?.data?.error?.message || e.message;
              resolve({ data: [], error: e });
            });
        })
      );
    }
    const { data = [], error = null } = await Promise.all(
      embeddingRequests
    ).then((results) => {
      // If any errors were returned from LiteLLM, abort the entire sequence
      // because the embeddings will be incomplete.
      const errors = results
        .filter((res) => !!res.error)
        .map((res) => res.error)
        .flat();
      if (errors.length > 0) {
        const uniqueErrors = new Set();
        errors.forEach((error) =>
          uniqueErrors.add(`[${error.type}]: ${error.message}`)
        );
        return {
          data: [],
          error: Array.from(uniqueErrors).join(", "),
        };
      }
      return {
        data: results.map((res) => res?.data || []).flat(),
        error: null,
      };
    });

    if (!!error) throw new Error(`LiteLLM failed to embed: ${error}`);

    return data.length > 0 &&
      data.every((embd) => embd.hasOwnProperty("embedding"))
      ? data.map((embd) => embd.embedding)
      : null;
  }
}

module.exports = {
  LiteLLMEmbedder,
};
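
// Usage sketch (illustrative only, not part of the module). It assumes a
// LiteLLM proxy is reachable at LITE_LLM_BASE_PATH and that the require path
// below matches where this file lives in your project; the URL, key, and
// path are hypothetical placeholders.
//
//   process.env.LITE_LLM_BASE_PATH = "http://localhost:4000"; // hypothetical proxy URL
//   process.env.LITE_LLM_API_KEY = "sk-litellm-key";          // hypothetical key, optional
//   const { LiteLLMEmbedder } = require("./liteLLM");         // path is an assumption
//
//   (async () => {
//     const embedder = new LiteLLMEmbedder();
//     // A single string resolves to one embedding vector (array of floats).
//     const vector = await embedder.embedTextInput("Hello, world!");
//     console.log(vector.length);
//     // Many strings resolve to an array of vectors, batched 500 at a time internally.
//     const vectors = await embedder.embedChunks(["first text", "second text"]);
//     console.log(vectors?.length);
//   })();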