You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

241 lines
8.4 KiB

11 months ago
  1. const path = require("path");
  2. const fs = require("fs");
  3. class NativeEmbeddingReranker {
  4. static #model = null;
  5. static #tokenizer = null;
  6. static #transformers = null;
  7. // This is a folder that Mintplex Labs hosts for those who cannot capture the HF model download
  8. // endpoint for various reasons. This endpoint is not guaranteed to be active or maintained
  9. // and may go offline at any time at Mintplex Labs's discretion.
  10. #fallbackHost = "https://cdn.anythingllm.com/support/models/";
  11. constructor() {
  12. // An alternative model to the mixedbread-ai/mxbai-rerank-xsmall-v1 model (speed on CPU is much slower for this model @ 18docs = 6s)
  13. // Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2 (speed on CPU is much faster @ 18docs = 1.6s)
  14. this.model = "Xenova/ms-marco-MiniLM-L-6-v2";
  15. this.cacheDir = path.resolve(
  16. process.env.STORAGE_DIR
  17. ? path.resolve(process.env.STORAGE_DIR, `models`)
  18. : path.resolve(__dirname, `../../../storage/models`)
  19. );
  20. this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));
  21. // Make directory when it does not exist in existing installations
  22. if (!fs.existsSync(this.cacheDir)) fs.mkdirSync(this.cacheDir);
  23. this.modelDownloaded = fs.existsSync(
  24. path.resolve(this.cacheDir, this.model)
  25. );
  26. this.log("Initialized");
  27. }
  28. log(text, ...args) {
  29. console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args);
  30. }
  31. /**
  32. * This function will return the host of the current reranker suite.
  33. * If the reranker suite is not initialized, it will return the default HF host.
  34. * @returns {string} The host of the current reranker suite.
  35. */
  36. get host() {
  37. if (!NativeEmbeddingReranker.#transformers) return "https://huggingface.co";
  38. try {
  39. return new URL(NativeEmbeddingReranker.#transformers.env.remoteHost).host;
  40. } catch (e) {
  41. return this.#fallbackHost;
  42. }
  43. }
  44. /**
  45. * This function will preload the reranker suite and tokenizer.
  46. * This is useful for reducing the latency of the first rerank call and pre-downloading the models and such
  47. * to avoid having to wait for the models to download on the first rerank call.
  48. */
  49. async preload() {
  50. try {
  51. this.log(`Preloading reranker suite...`);
  52. await this.initClient();
  53. this.log(
  54. `Preloaded reranker suite. Reranking is available as a service now.`
  55. );
  56. return;
  57. } catch (e) {
  58. console.error(e);
  59. this.log(
  60. `Failed to preload reranker suite. Reranking will be available on the first rerank call.`
  61. );
  62. return;
  63. }
  64. }
  65. async initClient() {
  66. if (NativeEmbeddingReranker.#transformers) {
  67. this.log(`Reranker suite already initialized - reusing.`);
  68. return;
  69. }
  70. await import("@xenova/transformers").then(
  71. async ({ AutoModelForSequenceClassification, AutoTokenizer, env }) => {
  72. this.log(`Loading reranker suite...`);
  73. NativeEmbeddingReranker.#transformers = {
  74. AutoModelForSequenceClassification,
  75. AutoTokenizer,
  76. env,
  77. };
  78. // Attempt to load the model and tokenizer in this order:
  79. // 1. From local file system cache
  80. // 2. Download and cache from remote host (hf.co)
  81. // 3. Download and cache from fallback host (cdn.anythingllm.com)
  82. await this.#getPreTrainedModel();
  83. await this.#getPreTrainedTokenizer();
  84. }
  85. );
  86. return;
  87. }
  88. /**
  89. * This function will load the model from the local file system cache, or download and cache it from the remote host.
  90. * If the model is not found in the local file system cache, it will download and cache it from the remote host.
  91. * If the model is not found in the remote host, it will download and cache it from the fallback host.
  92. * @returns {Promise<any>} The loaded model.
  93. */
  94. async #getPreTrainedModel() {
  95. if (NativeEmbeddingReranker.#model) {
  96. this.log(`Loading model from singleton...`);
  97. return NativeEmbeddingReranker.#model;
  98. }
  99. try {
  100. const model =
  101. await NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification.from_pretrained(
  102. this.model,
  103. {
  104. progress_callback: (p) => {
  105. if (!this.modelDownloaded && p.status === "progress") {
  106. this.log(
  107. `[${this.host}] Loading model ${this.model}... ${p?.progress}%`
  108. );
  109. }
  110. },
  111. cache_dir: this.cacheDir,
  112. }
  113. );
  114. this.log(`Loaded model ${this.model}`);
  115. NativeEmbeddingReranker.#model = model;
  116. return model;
  117. } catch (e) {
  118. this.log(
  119. `Failed to load model ${this.model} from ${this.host}.`,
  120. e.message,
  121. e.stack
  122. );
  123. if (
  124. NativeEmbeddingReranker.#transformers.env.remoteHost ===
  125. this.#fallbackHost
  126. ) {
  127. this.log(`Failed to load model ${this.model} from fallback host.`);
  128. throw e;
  129. }
  130. this.log(`Falling back to fallback host. ${this.#fallbackHost}`);
  131. NativeEmbeddingReranker.#transformers.env.remoteHost = this.#fallbackHost;
  132. NativeEmbeddingReranker.#transformers.env.remotePathTemplate = "{model}/";
  133. return await this.#getPreTrainedModel();
  134. }
  135. }
  136. /**
  137. * This function will load the tokenizer from the local file system cache, or download and cache it from the remote host.
  138. * If the tokenizer is not found in the local file system cache, it will download and cache it from the remote host.
  139. * If the tokenizer is not found in the remote host, it will download and cache it from the fallback host.
  140. * @returns {Promise<any>} The loaded tokenizer.
  141. */
  142. async #getPreTrainedTokenizer() {
  143. if (NativeEmbeddingReranker.#tokenizer) {
  144. this.log(`Loading tokenizer from singleton...`);
  145. return NativeEmbeddingReranker.#tokenizer;
  146. }
  147. try {
  148. const tokenizer =
  149. await NativeEmbeddingReranker.#transformers.AutoTokenizer.from_pretrained(
  150. this.model,
  151. {
  152. progress_callback: (p) => {
  153. if (!this.modelDownloaded && p.status === "progress") {
  154. this.log(
  155. `[${this.host}] Loading tokenizer ${this.model}... ${p?.progress}%`
  156. );
  157. }
  158. },
  159. cache_dir: this.cacheDir,
  160. }
  161. );
  162. this.log(`Loaded tokenizer ${this.model}`);
  163. NativeEmbeddingReranker.#tokenizer = tokenizer;
  164. return tokenizer;
  165. } catch (e) {
  166. this.log(
  167. `Failed to load tokenizer ${this.model} from ${this.host}.`,
  168. e.message,
  169. e.stack
  170. );
  171. if (
  172. NativeEmbeddingReranker.#transformers.env.remoteHost ===
  173. this.#fallbackHost
  174. ) {
  175. this.log(`Failed to load tokenizer ${this.model} from fallback host.`);
  176. throw e;
  177. }
  178. this.log(`Falling back to fallback host. ${this.#fallbackHost}`);
  179. NativeEmbeddingReranker.#transformers.env.remoteHost = this.#fallbackHost;
  180. NativeEmbeddingReranker.#transformers.env.remotePathTemplate = "{model}/";
  181. return await this.#getPreTrainedTokenizer();
  182. }
  183. }
  184. /**
  185. * Reranks a list of documents based on the query.
  186. * @param {string} query - The query to rerank the documents against.
  187. * @param {{text: string}[]} documents - The list of document text snippets to rerank. Should be output from a vector search.
  188. * @param {Object} options - The options for the reranking.
  189. * @param {number} options.topK - The number of top documents to return.
  190. * @returns {Promise<any[]>} - The reranked list of documents.
  191. */
  192. async rerank(query, documents, options = { topK: 4 }) {
  193. await this.initClient();
  194. const model = NativeEmbeddingReranker.#model;
  195. const tokenizer = NativeEmbeddingReranker.#tokenizer;
  196. const start = Date.now();
  197. this.log(`Reranking ${documents.length} documents...`);
  198. const inputs = tokenizer(new Array(documents.length).fill(query), {
  199. text_pair: documents.map((doc) => doc.text),
  200. padding: true,
  201. truncation: true,
  202. });
  203. const { logits } = await model(inputs);
  204. const reranked = logits
  205. .sigmoid()
  206. .tolist()
  207. .map(([score], i) => ({
  208. rerank_corpus_id: i,
  209. rerank_score: score,
  210. ...documents[i],
  211. }))
  212. .sort((a, b) => b.rerank_score - a.rerank_score)
  213. .slice(0, options.topK);
  214. this.log(
  215. `Reranking ${documents.length} documents to top ${options.topK} took ${Date.now() - start}ms`
  216. );
  217. return reranked;
  218. }
  219. }
// CommonJS export — consumers `require(...)` and destructure { NativeEmbeddingReranker }.
module.exports = {
  NativeEmbeddingReranker,
};