You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

483 lines
16 KiB

11 months ago
  1. const { default: weaviate } = require("weaviate-ts-client");
  2. const { TextSplitter } = require("../../TextSplitter");
  3. const { SystemSettings } = require("../../../models/systemSettings");
  4. const { storeVectorResult, cachedVectorInformation } = require("../../files");
  5. const { v4: uuidv4 } = require("uuid");
  6. const { toChunks, getEmbeddingEngineSelection } = require("../../helpers");
  7. const { camelCase } = require("../../helpers/camelcase");
  8. const { sourceIdentifier } = require("../../chats");
  9. const Weaviate = {
  10. name: "Weaviate",
  11. connect: async function () {
  12. if (process.env.VECTOR_DB !== "weaviate")
  13. throw new Error("Weaviate::Invalid ENV settings");
  14. const weaviateUrl = new URL(process.env.WEAVIATE_ENDPOINT);
  15. const options = {
  16. scheme: weaviateUrl.protocol?.replace(":", "") || "http",
  17. host: weaviateUrl?.host,
  18. ...(process.env?.WEAVIATE_API_KEY?.length > 0
  19. ? { apiKey: new weaviate.ApiKey(process.env?.WEAVIATE_API_KEY) }
  20. : {}),
  21. };
  22. const client = weaviate.client(options);
  23. const isAlive = await await client.misc.liveChecker().do();
  24. if (!isAlive)
  25. throw new Error(
  26. "Weaviate::Invalid Alive signal received - is the service online?"
  27. );
  28. return { client };
  29. },
  30. heartbeat: async function () {
  31. await this.connect();
  32. return { heartbeat: Number(new Date()) };
  33. },
  34. totalVectors: async function () {
  35. const { client } = await this.connect();
  36. const collectionNames = await this.allNamespaces(client);
  37. var totalVectors = 0;
  38. for (const name of collectionNames) {
  39. totalVectors += await this.namespaceCountWithClient(client, name);
  40. }
  41. return totalVectors;
  42. },
  43. namespaceCountWithClient: async function (client, namespace) {
  44. try {
  45. const response = await client.graphql
  46. .aggregate()
  47. .withClassName(camelCase(namespace))
  48. .withFields("meta { count }")
  49. .do();
  50. return (
  51. response?.data?.Aggregate?.[camelCase(namespace)]?.[0]?.meta?.count || 0
  52. );
  53. } catch (e) {
  54. console.error(`Weaviate:namespaceCountWithClient`, e.message);
  55. return 0;
  56. }
  57. },
  58. namespaceCount: async function (namespace = null) {
  59. try {
  60. const { client } = await this.connect();
  61. const response = await client.graphql
  62. .aggregate()
  63. .withClassName(camelCase(namespace))
  64. .withFields("meta { count }")
  65. .do();
  66. return (
  67. response?.data?.Aggregate?.[camelCase(namespace)]?.[0]?.meta?.count || 0
  68. );
  69. } catch (e) {
  70. console.error(`Weaviate:namespaceCountWithClient`, e.message);
  71. return 0;
  72. }
  73. },
  74. similarityResponse: async function ({
  75. client,
  76. namespace,
  77. queryVector,
  78. similarityThreshold = 0.25,
  79. topN = 4,
  80. filterIdentifiers = [],
  81. }) {
  82. const result = {
  83. contextTexts: [],
  84. sourceDocuments: [],
  85. scores: [],
  86. };
  87. const weaviateClass = await this.namespace(client, namespace);
  88. const fields =
  89. weaviateClass.properties?.map((prop) => prop.name)?.join(" ") ?? "";
  90. const queryResponse = await client.graphql
  91. .get()
  92. .withClassName(camelCase(namespace))
  93. .withFields(`${fields} _additional { id certainty }`)
  94. .withNearVector({ vector: queryVector })
  95. .withLimit(topN)
  96. .do();
  97. const responses = queryResponse?.data?.Get?.[camelCase(namespace)];
  98. responses.forEach((response) => {
  99. // In Weaviate we have to pluck id from _additional and spread it into the rest
  100. // of the properties.
  101. const {
  102. _additional: { id, certainty },
  103. ...rest
  104. } = response;
  105. if (certainty < similarityThreshold) return;
  106. if (filterIdentifiers.includes(sourceIdentifier(rest))) {
  107. console.log(
  108. "Weaviate: A source was filtered from context as it's parent document is pinned."
  109. );
  110. return;
  111. }
  112. result.contextTexts.push(rest.text);
  113. result.sourceDocuments.push({ ...rest, id });
  114. result.scores.push(certainty);
  115. });
  116. return result;
  117. },
  118. allNamespaces: async function (client) {
  119. try {
  120. const { classes = [] } = await client.schema.getter().do();
  121. return classes.map((classObj) => classObj.class);
  122. } catch (e) {
  123. console.error("Weaviate::AllNamespace", e);
  124. return [];
  125. }
  126. },
  127. namespace: async function (client, namespace = null) {
  128. if (!namespace) throw new Error("No namespace value provided.");
  129. if (!(await this.namespaceExists(client, namespace))) return null;
  130. const weaviateClass = await client.schema
  131. .classGetter()
  132. .withClassName(camelCase(namespace))
  133. .do();
  134. return {
  135. ...weaviateClass,
  136. vectorCount: await this.namespaceCount(namespace),
  137. };
  138. },
  139. addVectors: async function (client, vectors = []) {
  140. const response = { success: true, errors: new Set([]) };
  141. const results = await client.batch
  142. .objectsBatcher()
  143. .withObjects(...vectors)
  144. .do();
  145. results.forEach((res) => {
  146. const { status, errors = [] } = res.result;
  147. if (status === "SUCCESS" || errors.length === 0) return;
  148. response.success = false;
  149. response.errors.add(errors.error?.[0]?.message || null);
  150. });
  151. response.errors = [...response.errors];
  152. return response;
  153. },
  154. hasNamespace: async function (namespace = null) {
  155. if (!namespace) return false;
  156. const { client } = await this.connect();
  157. const weaviateClasses = await this.allNamespaces(client);
  158. return weaviateClasses.includes(camelCase(namespace));
  159. },
  160. namespaceExists: async function (client, namespace = null) {
  161. if (!namespace) throw new Error("No namespace value provided.");
  162. const weaviateClasses = await this.allNamespaces(client);
  163. return weaviateClasses.includes(camelCase(namespace));
  164. },
  165. deleteVectorsInNamespace: async function (client, namespace = null) {
  166. await client.schema.classDeleter().withClassName(camelCase(namespace)).do();
  167. return true;
  168. },
  169. addDocumentToNamespace: async function (
  170. namespace,
  171. documentData = {},
  172. fullFilePath = null,
  173. skipCache = false
  174. ) {
  175. const { DocumentVectors } = require("../../../models/vectors");
  176. try {
  177. const {
  178. pageContent,
  179. docId,
  180. id: _id, // Weaviate will abort if `id` is present in properties
  181. ...metadata
  182. } = documentData;
  183. if (!pageContent || pageContent.length == 0) return false;
  184. console.log("Adding new vectorized document into namespace", namespace);
  185. if (skipCache) {
  186. const cacheResult = await cachedVectorInformation(fullFilePath);
  187. if (cacheResult.exists) {
  188. const { client } = await this.connect();
  189. const weaviateClassExits = await this.hasNamespace(namespace);
  190. if (!weaviateClassExits) {
  191. await client.schema
  192. .classCreator()
  193. .withClass({
  194. class: camelCase(namespace),
  195. description: `Class created by AnythingLLM named ${camelCase(
  196. namespace
  197. )}`,
  198. vectorizer: "none",
  199. })
  200. .do();
  201. }
  202. const { chunks } = cacheResult;
  203. const documentVectors = [];
  204. const vectors = [];
  205. for (const chunk of chunks) {
  206. // Before sending to Weaviate and saving the records to our db
  207. // we need to assign the id of each chunk that is stored in the cached file.
  208. chunk.forEach((chunk) => {
  209. const id = uuidv4();
  210. const flattenedMetadata = this.flattenObjectForWeaviate(
  211. chunk.properties ?? chunk.metadata
  212. );
  213. documentVectors.push({ docId, vectorId: id });
  214. const vectorRecord = {
  215. id,
  216. class: camelCase(namespace),
  217. vector: chunk.vector || chunk.values || [],
  218. properties: { ...flattenedMetadata },
  219. };
  220. vectors.push(vectorRecord);
  221. });
  222. const { success: additionResult, errors = [] } =
  223. await this.addVectors(client, vectors);
  224. if (!additionResult) {
  225. console.error("Weaviate::addVectors failed to insert", errors);
  226. throw new Error("Error embedding into Weaviate");
  227. }
  228. }
  229. await DocumentVectors.bulkInsert(documentVectors);
  230. return { vectorized: true, error: null };
  231. }
  232. }
  233. // If we are here then we are going to embed and store a novel document.
  234. // We have to do this manually as opposed to using LangChains `Chroma.fromDocuments`
  235. // because we then cannot atomically control our namespace to granularly find/remove documents
  236. // from vectordb.
  237. const EmbedderEngine = getEmbeddingEngineSelection();
  238. const textSplitter = new TextSplitter({
  239. chunkSize: TextSplitter.determineMaxChunkSize(
  240. await SystemSettings.getValueOrFallback({
  241. label: "text_splitter_chunk_size",
  242. }),
  243. EmbedderEngine?.embeddingMaxChunkLength
  244. ),
  245. chunkOverlap: await SystemSettings.getValueOrFallback(
  246. { label: "text_splitter_chunk_overlap" },
  247. 20
  248. ),
  249. chunkHeaderMeta: TextSplitter.buildHeaderMeta(metadata),
  250. });
  251. const textChunks = await textSplitter.splitText(pageContent);
  252. console.log("Chunks created from document:", textChunks.length);
  253. const documentVectors = [];
  254. const vectors = [];
  255. const vectorValues = await EmbedderEngine.embedChunks(textChunks);
  256. const submission = {
  257. ids: [],
  258. vectors: [],
  259. properties: [],
  260. };
  261. if (!!vectorValues && vectorValues.length > 0) {
  262. for (const [i, vector] of vectorValues.entries()) {
  263. const flattenedMetadata = this.flattenObjectForWeaviate(metadata);
  264. const vectorRecord = {
  265. class: camelCase(namespace),
  266. id: uuidv4(),
  267. vector: vector,
  268. // [DO NOT REMOVE]
  269. // LangChain will be unable to find your text if you embed manually and dont include the `text` key.
  270. // https://github.com/hwchase17/langchainjs/blob/5485c4af50c063e257ad54f4393fa79e0aff6462/langchain/src/vectorstores/weaviate.ts#L133
  271. properties: { ...flattenedMetadata, text: textChunks[i] },
  272. };
  273. submission.ids.push(vectorRecord.id);
  274. submission.vectors.push(vectorRecord.values);
  275. submission.properties.push(metadata);
  276. vectors.push(vectorRecord);
  277. documentVectors.push({ docId, vectorId: vectorRecord.id });
  278. }
  279. } else {
  280. throw new Error(
  281. "Could not embed document chunks! This document will not be recorded."
  282. );
  283. }
  284. const { client } = await this.connect();
  285. const weaviateClassExits = await this.hasNamespace(namespace);
  286. if (!weaviateClassExits) {
  287. await client.schema
  288. .classCreator()
  289. .withClass({
  290. class: camelCase(namespace),
  291. description: `Class created by AnythingLLM named ${camelCase(
  292. namespace
  293. )}`,
  294. vectorizer: "none",
  295. })
  296. .do();
  297. }
  298. if (vectors.length > 0) {
  299. const chunks = [];
  300. for (const chunk of toChunks(vectors, 500)) chunks.push(chunk);
  301. console.log("Inserting vectorized chunks into Weaviate collection.");
  302. const { success: additionResult, errors = [] } = await this.addVectors(
  303. client,
  304. vectors
  305. );
  306. if (!additionResult) {
  307. console.error("Weaviate::addVectors failed to insert", errors);
  308. throw new Error("Error embedding into Weaviate");
  309. }
  310. await storeVectorResult(chunks, fullFilePath);
  311. }
  312. await DocumentVectors.bulkInsert(documentVectors);
  313. return { vectorized: true, error: null };
  314. } catch (e) {
  315. console.error("addDocumentToNamespace", e.message);
  316. return { vectorized: false, error: e.message };
  317. }
  318. },
  319. deleteDocumentFromNamespace: async function (namespace, docId) {
  320. const { DocumentVectors } = require("../../../models/vectors");
  321. const { client } = await this.connect();
  322. if (!(await this.namespaceExists(client, namespace))) return;
  323. const knownDocuments = await DocumentVectors.where({ docId });
  324. if (knownDocuments.length === 0) return;
  325. for (const doc of knownDocuments) {
  326. await client.data
  327. .deleter()
  328. .withClassName(camelCase(namespace))
  329. .withId(doc.vectorId)
  330. .do();
  331. }
  332. const indexes = knownDocuments.map((doc) => doc.id);
  333. await DocumentVectors.deleteIds(indexes);
  334. return true;
  335. },
  336. performSimilaritySearch: async function ({
  337. namespace = null,
  338. input = "",
  339. LLMConnector = null,
  340. similarityThreshold = 0.25,
  341. topN = 4,
  342. filterIdentifiers = [],
  343. }) {
  344. if (!namespace || !input || !LLMConnector)
  345. throw new Error("Invalid request to performSimilaritySearch.");
  346. const { client } = await this.connect();
  347. if (!(await this.namespaceExists(client, namespace))) {
  348. return {
  349. contextTexts: [],
  350. sources: [],
  351. message: "Invalid query - no documents found for workspace!",
  352. };
  353. }
  354. const queryVector = await LLMConnector.embedTextInput(input);
  355. const { contextTexts, sourceDocuments } = await this.similarityResponse({
  356. client,
  357. namespace,
  358. queryVector,
  359. similarityThreshold,
  360. topN,
  361. filterIdentifiers,
  362. });
  363. const sources = sourceDocuments.map((metadata, i) => {
  364. return { ...metadata, text: contextTexts[i] };
  365. });
  366. return {
  367. contextTexts,
  368. sources: this.curateSources(sources),
  369. message: false,
  370. };
  371. },
  372. "namespace-stats": async function (reqBody = {}) {
  373. const { namespace = null } = reqBody;
  374. if (!namespace) throw new Error("namespace required");
  375. const { client } = await this.connect();
  376. const stats = await this.namespace(client, namespace);
  377. return stats
  378. ? stats
  379. : { message: "No stats were able to be fetched from DB for namespace" };
  380. },
  381. "delete-namespace": async function (reqBody = {}) {
  382. const { namespace = null } = reqBody;
  383. const { client } = await this.connect();
  384. const details = await this.namespace(client, namespace);
  385. await this.deleteVectorsInNamespace(client, namespace);
  386. return {
  387. message: `Namespace ${camelCase(namespace)} was deleted along with ${
  388. details?.vectorCount
  389. } vectors.`,
  390. };
  391. },
  392. reset: async function () {
  393. const { client } = await this.connect();
  394. const weaviateClasses = await this.allNamespaces(client);
  395. for (const weaviateClass of weaviateClasses) {
  396. await client.schema.classDeleter().withClassName(weaviateClass).do();
  397. }
  398. return { reset: true };
  399. },
  400. curateSources: function (sources = []) {
  401. const documents = [];
  402. for (const source of sources) {
  403. if (Object.keys(source).length > 0) {
  404. const metadata = source.hasOwnProperty("metadata")
  405. ? source.metadata
  406. : source;
  407. documents.push({ ...metadata });
  408. }
  409. }
  410. return documents;
  411. },
  412. flattenObjectForWeaviate: function (obj = {}) {
  413. // Note this function is not generic, it is designed specifically for Weaviate
  414. // https://weaviate.io/developers/weaviate/config-refs/datatypes#introduction
  415. // Credit to LangchainJS
  416. // https://github.com/hwchase17/langchainjs/blob/5485c4af50c063e257ad54f4393fa79e0aff6462/langchain/src/vectorstores/weaviate.ts#L11C1-L50C3
  417. const flattenedObject = {};
  418. for (const key in obj) {
  419. if (!Object.hasOwn(obj, key) || key === "id") {
  420. continue;
  421. }
  422. const value = obj[key];
  423. if (typeof obj[key] === "object" && !Array.isArray(value)) {
  424. const recursiveResult = this.flattenObjectForWeaviate(value);
  425. for (const deepKey in recursiveResult) {
  426. if (Object.hasOwn(obj, key)) {
  427. flattenedObject[`${key}_${deepKey}`] = recursiveResult[deepKey];
  428. }
  429. }
  430. } else if (Array.isArray(value)) {
  431. if (
  432. value.length > 0 &&
  433. typeof value[0] !== "object" &&
  434. // eslint-disable-next-line @typescript-eslint/no-explicit-any
  435. value.every((el) => typeof el === typeof value[0])
  436. ) {
  437. // Weaviate only supports arrays of primitive types,
  438. // where all elements are of the same type
  439. flattenedObject[key] = value;
  440. }
  441. } else {
  442. flattenedObject[key] = value;
  443. }
  444. }
  445. return flattenedObject;
  446. },
  447. };
  448. module.exports.Weaviate = Weaviate;