You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

478 lines
16 KiB

11 months ago
  1. const lancedb = require("@lancedb/lancedb");
  2. const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
  3. const { TextSplitter } = require("../../TextSplitter");
  4. const { SystemSettings } = require("../../../models/systemSettings");
  5. const { storeVectorResult, cachedVectorInformation } = require("../../files");
  6. const { v4: uuidv4 } = require("uuid");
  7. const { sourceIdentifier } = require("../../chats");
  8. const { NativeEmbeddingReranker } = require("../../EmbeddingRerankers/native");
  9. /**
  10. * LancedDB Client connection object
  11. * @typedef {import('@lancedb/lancedb').Connection} LanceClient
  12. */
  13. const LanceDb = {
  14. uri: `${
  15. !!process.env.STORAGE_DIR ? `${process.env.STORAGE_DIR}/` : "./storage/"
  16. }lancedb`,
  17. name: "LanceDb",
  18. /** @returns {Promise<{client: LanceClient}>} */
  19. connect: async function () {
  20. if (process.env.VECTOR_DB !== "lancedb")
  21. throw new Error("LanceDB::Invalid ENV settings");
  22. const client = await lancedb.connect(this.uri);
  23. return { client };
  24. },
  25. distanceToSimilarity: function (distance = null) {
  26. if (distance === null || typeof distance !== "number") return 0.0;
  27. if (distance >= 1.0) return 1;
  28. if (distance < 0) return 1 - Math.abs(distance);
  29. return 1 - distance;
  30. },
  31. heartbeat: async function () {
  32. await this.connect();
  33. return { heartbeat: Number(new Date()) };
  34. },
  35. tables: async function () {
  36. const { client } = await this.connect();
  37. return await client.tableNames();
  38. },
  39. totalVectors: async function () {
  40. const { client } = await this.connect();
  41. const tables = await client.tableNames();
  42. let count = 0;
  43. for (const tableName of tables) {
  44. const table = await client.openTable(tableName);
  45. count += await table.countRows();
  46. }
  47. return count;
  48. },
  49. namespaceCount: async function (_namespace = null) {
  50. const { client } = await this.connect();
  51. const exists = await this.namespaceExists(client, _namespace);
  52. if (!exists) return 0;
  53. const table = await client.openTable(_namespace);
  54. return (await table.countRows()) || 0;
  55. },
  56. /**
  57. * Performs a SimilaritySearch + Reranking on a namespace.
  58. * @param {Object} params - The parameters for the rerankedSimilarityResponse.
  59. * @param {Object} params.client - The vectorDB client.
  60. * @param {string} params.namespace - The namespace to search in.
  61. * @param {string} params.query - The query to search for (plain text).
  62. * @param {number[]} params.queryVector - The vector of the query.
  63. * @param {number} params.similarityThreshold - The threshold for similarity.
  64. * @param {number} params.topN - the number of results to return from this process.
  65. * @param {string[]} params.filterIdentifiers - The identifiers of the documents to filter out.
  66. * @returns
  67. */
  68. rerankedSimilarityResponse: async function ({
  69. client,
  70. namespace,
  71. query,
  72. queryVector,
  73. topN = 4,
  74. similarityThreshold = 0.25,
  75. filterIdentifiers = [],
  76. }) {
  77. const reranker = new NativeEmbeddingReranker();
  78. const collection = await client.openTable(namespace);
  79. const totalEmbeddings = await this.namespaceCount(namespace);
  80. const result = {
  81. contextTexts: [],
  82. sourceDocuments: [],
  83. scores: [],
  84. };
  85. /**
  86. * For reranking, we want to work with a larger number of results than the topN.
  87. * This is because the reranker can only rerank the results it it given and we dont auto-expand the results.
  88. * We want to give the reranker a larger number of results to work with.
  89. *
  90. * However, we cannot make this boundless as reranking is expensive and time consuming.
  91. * So we limit the number of results to a maximum of 50 and a minimum of 10.
  92. * This is a good balance between the number of results to rerank and the cost of reranking
  93. * and ensures workspaces with 10K embeddings will still rerank within a reasonable timeframe on base level hardware.
  94. *
  95. * Benchmarks:
  96. * On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec
  97. */
  98. const searchLimit = Math.max(
  99. 10,
  100. Math.min(50, Math.ceil(totalEmbeddings * 0.1))
  101. );
  102. const vectorSearchResults = await collection
  103. .vectorSearch(queryVector)
  104. .distanceType("cosine")
  105. .limit(searchLimit)
  106. .toArray();
  107. await reranker
  108. .rerank(query, vectorSearchResults, { topK: topN })
  109. .then((rerankResults) => {
  110. rerankResults.forEach((item) => {
  111. if (this.distanceToSimilarity(item._distance) < similarityThreshold)
  112. return;
  113. const { vector: _, ...rest } = item;
  114. if (filterIdentifiers.includes(sourceIdentifier(rest))) {
  115. console.log(
  116. "LanceDB: A source was filtered from context as it's parent document is pinned."
  117. );
  118. return;
  119. }
  120. const score =
  121. item?.rerank_score || this.distanceToSimilarity(item._distance);
  122. result.contextTexts.push(rest.text);
  123. result.sourceDocuments.push({
  124. ...rest,
  125. score,
  126. });
  127. result.scores.push(score);
  128. });
  129. })
  130. .catch((e) => {
  131. console.error(e);
  132. console.error("LanceDB::rerankedSimilarityResponse", e.message);
  133. });
  134. return result;
  135. },
  136. /**
  137. * Performs a SimilaritySearch on a give LanceDB namespace.
  138. * @param {Object} params
  139. * @param {LanceClient} params.client
  140. * @param {string} params.namespace
  141. * @param {number[]} params.queryVector
  142. * @param {number} params.similarityThreshold
  143. * @param {number} params.topN
  144. * @param {string[]} params.filterIdentifiers
  145. * @returns
  146. */
  147. similarityResponse: async function ({
  148. client,
  149. namespace,
  150. queryVector,
  151. similarityThreshold = 0.25,
  152. topN = 4,
  153. filterIdentifiers = [],
  154. }) {
  155. const collection = await client.openTable(namespace);
  156. const result = {
  157. contextTexts: [],
  158. sourceDocuments: [],
  159. scores: [],
  160. };
  161. const response = await collection
  162. .vectorSearch(queryVector)
  163. .distanceType("cosine")
  164. .limit(topN)
  165. .toArray();
  166. response.forEach((item) => {
  167. if (this.distanceToSimilarity(item._distance) < similarityThreshold)
  168. return;
  169. const { vector: _, ...rest } = item;
  170. if (filterIdentifiers.includes(sourceIdentifier(rest))) {
  171. console.log(
  172. "LanceDB: A source was filtered from context as it's parent document is pinned."
  173. );
  174. return;
  175. }
  176. result.contextTexts.push(rest.text);
  177. result.sourceDocuments.push({
  178. ...rest,
  179. score: this.distanceToSimilarity(item._distance),
  180. });
  181. result.scores.push(this.distanceToSimilarity(item._distance));
  182. });
  183. return result;
  184. },
  185. /**
  186. *
  187. * @param {LanceClient} client
  188. * @param {string} namespace
  189. * @returns
  190. */
  191. namespace: async function (client, namespace = null) {
  192. if (!namespace) throw new Error("No namespace value provided.");
  193. const collection = await client.openTable(namespace).catch(() => false);
  194. if (!collection) return null;
  195. return {
  196. ...collection,
  197. };
  198. },
  199. /**
  200. *
  201. * @param {LanceClient} client
  202. * @param {number[]} data
  203. * @param {string} namespace
  204. * @returns
  205. */
  206. updateOrCreateCollection: async function (client, data = [], namespace) {
  207. const hasNamespace = await this.hasNamespace(namespace);
  208. if (hasNamespace) {
  209. const collection = await client.openTable(namespace);
  210. await collection.add(data);
  211. return true;
  212. }
  213. await client.createTable(namespace, data);
  214. return true;
  215. },
  216. hasNamespace: async function (namespace = null) {
  217. if (!namespace) return false;
  218. const { client } = await this.connect();
  219. const exists = await this.namespaceExists(client, namespace);
  220. return exists;
  221. },
  222. /**
  223. *
  224. * @param {LanceClient} client
  225. * @param {string} namespace
  226. * @returns
  227. */
  228. namespaceExists: async function (client, namespace = null) {
  229. if (!namespace) throw new Error("No namespace value provided.");
  230. const collections = await client.tableNames();
  231. return collections.includes(namespace);
  232. },
  233. /**
  234. *
  235. * @param {LanceClient} client
  236. * @param {string} namespace
  237. * @returns
  238. */
  239. deleteVectorsInNamespace: async function (client, namespace = null) {
  240. await client.dropTable(namespace);
  241. return true;
  242. },
  243. deleteDocumentFromNamespace: async function (namespace, docId) {
  244. const { client } = await this.connect();
  245. const exists = await this.namespaceExists(client, namespace);
  246. if (!exists) {
  247. console.error(
  248. `LanceDB:deleteDocumentFromNamespace - namespace ${namespace} does not exist.`
  249. );
  250. return;
  251. }
  252. const { DocumentVectors } = require("../../../models/vectors");
  253. const table = await client.openTable(namespace);
  254. const vectorIds = (await DocumentVectors.where({ docId })).map(
  255. (record) => record.vectorId
  256. );
  257. if (vectorIds.length === 0) return;
  258. await table.delete(`id IN (${vectorIds.map((v) => `'${v}'`).join(",")})`);
  259. return true;
  260. },
  261. addDocumentToNamespace: async function (
  262. namespace,
  263. documentData = {},
  264. fullFilePath = null,
  265. skipCache = false
  266. ) {
  267. const { DocumentVectors } = require("../../../models/vectors");
  268. try {
  269. const { pageContent, docId, ...metadata } = documentData;
  270. if (!pageContent || pageContent.length == 0) return false;
  271. console.log("Adding new vectorized document into namespace", namespace);
  272. if (!skipCache) {
  273. const cacheResult = await cachedVectorInformation(fullFilePath);
  274. if (cacheResult.exists) {
  275. const { client } = await this.connect();
  276. const { chunks } = cacheResult;
  277. const documentVectors = [];
  278. const submissions = [];
  279. for (const chunk of chunks) {
  280. chunk.forEach((chunk) => {
  281. const id = uuidv4();
  282. const { id: _id, ...metadata } = chunk.metadata;
  283. documentVectors.push({ docId, vectorId: id });
  284. submissions.push({ id: id, vector: chunk.values, ...metadata });
  285. });
  286. }
  287. await this.updateOrCreateCollection(client, submissions, namespace);
  288. await DocumentVectors.bulkInsert(documentVectors);
  289. return { vectorized: true, error: null };
  290. }
  291. }
  292. // If we are here then we are going to embed and store a novel document.
  293. // We have to do this manually as opposed to using LangChains `xyz.fromDocuments`
  294. // because we then cannot atomically control our namespace to granularly find/remove documents
  295. // from vectordb.
  296. const EmbedderEngine = getEmbeddingEngineSelection();
  297. const textSplitter = new TextSplitter({
  298. chunkSize: TextSplitter.determineMaxChunkSize(
  299. await SystemSettings.getValueOrFallback({
  300. label: "text_splitter_chunk_size",
  301. }),
  302. EmbedderEngine?.embeddingMaxChunkLength
  303. ),
  304. chunkOverlap: await SystemSettings.getValueOrFallback(
  305. { label: "text_splitter_chunk_overlap" },
  306. 20
  307. ),
  308. chunkHeaderMeta: TextSplitter.buildHeaderMeta(metadata),
  309. });
  310. const textChunks = await textSplitter.splitText(pageContent);
  311. console.log("Chunks created from document:", textChunks.length);
  312. const documentVectors = [];
  313. const vectors = [];
  314. const submissions = [];
  315. const vectorValues = await EmbedderEngine.embedChunks(textChunks);
  316. if (!!vectorValues && vectorValues.length > 0) {
  317. for (const [i, vector] of vectorValues.entries()) {
  318. const vectorRecord = {
  319. id: uuidv4(),
  320. values: vector,
  321. // [DO NOT REMOVE]
  322. // LangChain will be unable to find your text if you embed manually and dont include the `text` key.
  323. // https://github.com/hwchase17/langchainjs/blob/2def486af734c0ca87285a48f1a04c057ab74bdf/langchain/src/vectorstores/pinecone.ts#L64
  324. metadata: { ...metadata, text: textChunks[i] },
  325. };
  326. vectors.push(vectorRecord);
  327. submissions.push({
  328. ...vectorRecord.metadata,
  329. id: vectorRecord.id,
  330. vector: vectorRecord.values,
  331. });
  332. documentVectors.push({ docId, vectorId: vectorRecord.id });
  333. }
  334. } else {
  335. throw new Error(
  336. "Could not embed document chunks! This document will not be recorded."
  337. );
  338. }
  339. if (vectors.length > 0) {
  340. const chunks = [];
  341. for (const chunk of toChunks(vectors, 500)) chunks.push(chunk);
  342. console.log("Inserting vectorized chunks into LanceDB collection.");
  343. const { client } = await this.connect();
  344. await this.updateOrCreateCollection(client, submissions, namespace);
  345. await storeVectorResult(chunks, fullFilePath);
  346. }
  347. await DocumentVectors.bulkInsert(documentVectors);
  348. return { vectorized: true, error: null };
  349. } catch (e) {
  350. console.error("addDocumentToNamespace", e.message);
  351. return { vectorized: false, error: e.message };
  352. }
  353. },
  354. performSimilaritySearch: async function ({
  355. namespace = null,
  356. input = "",
  357. LLMConnector = null,
  358. similarityThreshold = 0.25,
  359. topN = 4,
  360. filterIdentifiers = [],
  361. rerank = false,
  362. }) {
  363. if (!namespace || !input || !LLMConnector)
  364. throw new Error("Invalid request to performSimilaritySearch.");
  365. const { client } = await this.connect();
  366. if (!(await this.namespaceExists(client, namespace))) {
  367. return {
  368. contextTexts: [],
  369. sources: [],
  370. message: "Invalid query - no documents found for workspace!",
  371. };
  372. }
  373. const queryVector = await LLMConnector.embedTextInput(input);
  374. const result = rerank
  375. ? await this.rerankedSimilarityResponse({
  376. client,
  377. namespace,
  378. query: input,
  379. queryVector,
  380. similarityThreshold,
  381. topN,
  382. filterIdentifiers,
  383. })
  384. : await this.similarityResponse({
  385. client,
  386. namespace,
  387. queryVector,
  388. similarityThreshold,
  389. topN,
  390. filterIdentifiers,
  391. });
  392. const { contextTexts, sourceDocuments } = result;
  393. const sources = sourceDocuments.map((metadata, i) => {
  394. return { metadata: { ...metadata, text: contextTexts[i] } };
  395. });
  396. return {
  397. contextTexts,
  398. sources: this.curateSources(sources),
  399. message: false,
  400. };
  401. },
  402. "namespace-stats": async function (reqBody = {}) {
  403. const { namespace = null } = reqBody;
  404. if (!namespace) throw new Error("namespace required");
  405. const { client } = await this.connect();
  406. if (!(await this.namespaceExists(client, namespace)))
  407. throw new Error("Namespace by that name does not exist.");
  408. const stats = await this.namespace(client, namespace);
  409. return stats
  410. ? stats
  411. : { message: "No stats were able to be fetched from DB for namespace" };
  412. },
  413. "delete-namespace": async function (reqBody = {}) {
  414. const { namespace = null } = reqBody;
  415. const { client } = await this.connect();
  416. if (!(await this.namespaceExists(client, namespace)))
  417. throw new Error("Namespace by that name does not exist.");
  418. await this.deleteVectorsInNamespace(client, namespace);
  419. return {
  420. message: `Namespace ${namespace} was deleted.`,
  421. };
  422. },
  423. reset: async function () {
  424. const { client } = await this.connect();
  425. const fs = require("fs");
  426. fs.rm(`${client.uri}`, { recursive: true }, () => null);
  427. return { reset: true };
  428. },
  429. curateSources: function (sources = []) {
  430. const documents = [];
  431. for (const source of sources) {
  432. const { text, vector: _v, _distance: _d, ...rest } = source;
  433. const metadata = rest.hasOwnProperty("metadata") ? rest.metadata : rest;
  434. if (Object.keys(metadata).length > 0) {
  435. documents.push({
  436. ...metadata,
  437. ...(text ? { text } : {}),
  438. });
  439. }
  440. }
  441. return documents;
  442. },
  443. };
  444. module.exports.LanceDb = LanceDb;