You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

86 lines
2.5 KiB

11 months ago
  1. const { toChunks } = require("../../helpers");
  2. class CohereEmbedder {
  3. constructor() {
  4. if (!process.env.COHERE_API_KEY)
  5. throw new Error("No Cohere API key was set.");
  6. const { CohereClient } = require("cohere-ai");
  7. const cohere = new CohereClient({
  8. token: process.env.COHERE_API_KEY,
  9. });
  10. this.cohere = cohere;
  11. this.model = process.env.EMBEDDING_MODEL_PREF || "embed-english-v3.0";
  12. this.inputType = "search_document";
  13. // Limit of how many strings we can process in a single pass to stay with resource or network limits
  14. this.maxConcurrentChunks = 96; // Cohere's limit per request is 96
  15. this.embeddingMaxChunkLength = 1945; // https://docs.cohere.com/docs/embed-2 - assume a token is roughly 4 letters with some padding
  16. }
  17. async embedTextInput(textInput) {
  18. this.inputType = "search_query";
  19. const result = await this.embedChunks([textInput]);
  20. return result?.[0] || [];
  21. }
  22. async embedChunks(textChunks = []) {
  23. const embeddingRequests = [];
  24. this.inputType = "search_document";
  25. for (const chunk of toChunks(textChunks, this.maxConcurrentChunks)) {
  26. embeddingRequests.push(
  27. new Promise((resolve) => {
  28. this.cohere
  29. .embed({
  30. texts: chunk,
  31. model: this.model,
  32. inputType: this.inputType,
  33. })
  34. .then((res) => {
  35. resolve({ data: res.embeddings, error: null });
  36. })
  37. .catch((e) => {
  38. e.type =
  39. e?.response?.data?.error?.code ||
  40. e?.response?.status ||
  41. "failed_to_embed";
  42. e.message = e?.response?.data?.error?.message || e.message;
  43. resolve({ data: [], error: e });
  44. });
  45. })
  46. );
  47. }
  48. const { data = [], error = null } = await Promise.all(
  49. embeddingRequests
  50. ).then((results) => {
  51. const errors = results
  52. .filter((res) => !!res.error)
  53. .map((res) => res.error)
  54. .flat();
  55. if (errors.length > 0) {
  56. let uniqueErrors = new Set();
  57. errors.map((error) =>
  58. uniqueErrors.add(`[${error.type}]: ${error.message}`)
  59. );
  60. return { data: [], error: Array.from(uniqueErrors).join(", ") };
  61. }
  62. return {
  63. data: results.map((res) => res?.data || []).flat(),
  64. error: null,
  65. };
  66. });
  67. if (!!error) throw new Error(`Cohere Failed to embed: ${error}`);
  68. return data.length > 0 ? data : null;
  69. }
  70. }
  71. module.exports = {
  72. CohereEmbedder,
  73. };