You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

817 lines
23 KiB

11 months ago
  1. const { EventEmitter } = require("events");
  2. const { APIError } = require("./error.js");
  3. const Providers = require("./providers/index.js");
  4. const { Telemetry } = require("../../../models/telemetry.js");
  /**
   * AIbitat manages a conversation between agents, optionally grouped in channels.
   * It is designed to solve a task with an LLM by guiding the chat through a
   * graph of agents.
   *
   * Every state change is announced through `this.emitter`; subscribe with the
   * `on*` methods (`onMessage`, `onError`, `onInterrupt`, `onTerminate`,
   * `onStart`, `onAbort`).
   */
  class AIbitat {
    /**
     * Provider identifier -> which class on `Providers` to instantiate and
     * whether the constructor receives `{ model }` (otherwise it receives `{}`).
     */
    static #providerClasses = new Map([
      ["openai", { className: "OpenAIProvider", usesModel: true }],
      ["anthropic", { className: "AnthropicProvider", usesModel: true }],
      ["lmstudio", { className: "LMStudioProvider", usesModel: true }],
      ["ollama", { className: "OllamaProvider", usesModel: true }],
      ["groq", { className: "GroqProvider", usesModel: true }],
      ["togetherai", { className: "TogetherAIProvider", usesModel: true }],
      ["azure", { className: "AzureOpenAiProvider", usesModel: true }],
      ["koboldcpp", { className: "KoboldCPPProvider", usesModel: false }],
      ["localai", { className: "LocalAIProvider", usesModel: true }],
      ["openrouter", { className: "OpenRouterProvider", usesModel: true }],
      ["mistral", { className: "MistralProvider", usesModel: true }],
      ["generic-openai", { className: "GenericOpenAiProvider", usesModel: true }],
      ["perplexity", { className: "PerplexityProvider", usesModel: true }],
      ["textgenwebui", { className: "TextWebGenUiProvider", usesModel: false }],
      ["bedrock", { className: "AWSBedrockProvider", usesModel: false }],
      ["fireworksai", { className: "FireworksAIProvider", usesModel: true }],
      ["nvidia-nim", { className: "NvidiaNimProvider", usesModel: true }],
      ["deepseek", { className: "DeepSeekProvider", usesModel: true }],
      ["litellm", { className: "LiteLLMProvider", usesModel: true }],
      ["apipie", { className: "ApiPieProvider", usesModel: true }],
      ["xai", { className: "XAIProvider", usesModel: true }],
      ["novita", { className: "NovitaProvider", usesModel: true }],
    ]);

    emitter = new EventEmitter();
    provider = null;
    defaultProvider = null;
    defaultInterrupt;
    maxRounds;
    _chats;
    agents = new Map();
    channels = new Map();
    functions = new Map();

    /**
     * @param {object} [props]
     * @param {Array<object>} [props.chats] Initial chat history (used as-is, not copied).
     * @param {string} [props.interrupt] Default interrupt mode; "ALWAYS" makes every agent interrupt.
     * @param {number} [props.maxRounds] Maximum number of successful messages between two nodes.
     * @param {string|object} [props.provider] Provider identifier, or a provider instance.
     * @param {object} [props.handlerProps] Inherited props made available on the instance.
     * Any other property (e.g. `model`) is stored on `defaultProvider`.
     */
    constructor(props = {}) {
      const {
        chats = [],
        interrupt = "NEVER",
        maxRounds = 100,
        provider = "openai",
        handlerProps = {},
        ...rest
      } = props;

      this._chats = chats;
      this.defaultInterrupt = interrupt;
      this.maxRounds = maxRounds;
      this.handlerProps = handlerProps;
      this.defaultProvider = { provider, ...rest };
      this.provider = this.defaultProvider.provider;
      this.model = this.defaultProvider.model;
    }

    /**
     * Get the chat history between agents and channels.
     */
    get chats() {
      return this._chats;
    }

    /**
     * Install a plugin by calling its `setup` method with this instance.
     *
     * @param {{ setup: Function }} plugin
     * @returns {AIbitat} this, for chaining.
     */
    use(plugin) {
      plugin.setup(this);
      return this;
    }

    /**
     * Add (or replace) an agent.
     *
     * @param {string} name
     * @param {object} config
     * @returns {AIbitat} this, for chaining.
     */
    agent(name = "", config = {}) {
      this.agents.set(name, config);
      return this;
    }

    /**
     * Add (or replace) a channel.
     *
     * @param {string} name
     * @param {string[]} members Names of the nodes that belong to the channel.
     * @param {object} config Extra channel settings; a `members` key here overrides the argument.
     * @returns {AIbitat} this, for chaining.
     */
    channel(name = "", members = [""], config = {}) {
      this.channels.set(name, { members, ...config });
      return this;
    }

    /**
     * Get the configuration of an agent, with the default role filled in.
     *
     * @param {string} agent The name of the agent.
     * @throws {Error} When the agent configuration is not found.
     * @returns {object} The agent configuration.
     */
    getAgentConfig(agent = "") {
      const config = this.agents.get(agent);
      if (!config) {
        throw new Error(`Agent configuration "${agent}" not found`);
      }
      return {
        role: "You are a helpful AI assistant.",
        ...config,
      };
    }

    /**
     * Get the configuration of a channel, with default `maxRounds` and `role` filled in.
     *
     * @param {string} channel The name of the channel.
     * @throws {Error} When the channel configuration is not found.
     * @returns {object} The channel configuration.
     */
    getChannelConfig(channel = "") {
      const config = this.channels.get(channel);
      if (!config) {
        throw new Error(`Channel configuration "${channel}" not found`);
      }
      return {
        maxRounds: 10,
        role: "",
        ...config,
      };
    }

    /**
     * Get the members of a channel.
     *
     * @param {string} node The name of the channel.
     * @throws {Error} When the channel configuration is not found.
     * @returns {string[]} The members of the channel.
     */
    getGroupMembers(node = "") {
      return this.getChannelConfig(node).members;
    }

    /**
     * Subscribe to the "abort" event. Listeners receive `(null, aibitat)`.
     *
     * @param {Function} listener
     * @returns {AIbitat} this, for chaining.
     */
    onAbort(listener = () => null) {
      this.emitter.on("abort", listener);
      return this;
    }

    /**
     * Abort the running of any plugins that may still be pending (Langchain summarize).
     * This only emits the "abort" event; reacting to it is up to the listeners.
     */
    abort() {
      this.emitter.emit("abort", null, this);
    }

    /**
     * Subscribe to the "terminate" event. Listeners receive `(node, aibitat)`.
     *
     * @param {Function} listener
     * @returns {AIbitat} this, for chaining.
     */
    onTerminate(listener = () => null) {
      this.emitter.on("terminate", listener);
      return this;
    }

    /**
     * Announce that the chat is terminated by emitting the "terminate" event.
     *
     * @param {string} node Last node to chat with.
     */
    terminate(node = "") {
      this.emitter.emit("terminate", node, this);
    }

    /**
     * Subscribe to the "interrupt" event. Listeners receive `(route, aibitat)`.
     *
     * @param {Function} listener
     * @returns {AIbitat} this, for chaining.
     */
    onInterrupt(listener = () => null) {
      this.emitter.on("interrupt", listener);
      return this;
    }

    /**
     * Interrupt the chat: record the route in the history with state
     * "interrupt" and emit the "interrupt" event.
     *
     * @param {{ from: string, to: string }} route The nodes that participated in the interruption.
     */
    interrupt(route) {
      this._chats.push({ ...route, state: "interrupt" });
      this.emitter.emit("interrupt", route, this);
    }

    /**
     * Subscribe to the "message" event, emitted whenever a successful message
     * (the first one or a reply) is added to the history.
     * Listeners receive `(chat, aibitat)`.
     *
     * @param {Function} listener
     * @returns {AIbitat} this, for chaining.
     */
    onMessage(listener = () => null) {
      this.emitter.on("message", listener);
      return this;
    }

    /**
     * Register a new successful message in the chat history and emit "message".
     *
     * @param {object} message
     */
    newMessage(message) {
      const chat = { ...message, state: "success" };
      this._chats.push(chat);
      this.emitter.emit("message", chat, this);
    }

    /**
     * Subscribe to errors raised while replying. Listeners receive
     * `(error, chat)`, where `chat` is the history entry recorded for the error.
     *
     * The event is named "replyError" rather than "error" because Node's
     * EventEmitter throws when "error" is emitted without a listener.
     *
     * @param {Function} listener
     * @returns {AIbitat} this, for chaining.
     */
    onError(listener = () => null) {
      this.emitter.on("replyError", listener);
      return this;
    }

    /**
     * Register an error in the chat history (state "error") and emit "replyError".
     *
     * @param {{ from: string, to: string }} route
     * @param {unknown} error An Error (its message is stored) or any value (stringified).
     */
    newError(route, error) {
      const chat = {
        ...route,
        content: error instanceof Error ? error.message : String(error),
        state: "error",
      };
      this._chats.push(chat);
      this.emitter.emit("replyError", error, chat);
    }

    /**
     * Subscribe to the "start" event, emitted when a new chat is started.
     * Listeners receive `(message, aibitat)`.
     *
     * @param {Function} listener
     * @returns {AIbitat} this, for chaining.
     */
    onStart(listener = () => null) {
      this.emitter.on("start", listener);
      return this;
    }

    /**
     * Start a new chat: record the opening message, emit "start", then ask the
     * recipient to reply.
     *
     * @param {{ from: string, to: string, content: string }} message The message to start the chat.
     * @returns {Promise<AIbitat>} this, once the conversation stops.
     */
    async start(message) {
      this.newMessage(message);
      this.emitter.emit("start", message, this);
      // The recipient of the opening message is the node that replies next.
      await this.chat({ to: message.from, from: message.to });
      return this;
    }

    /**
     * Recursively chat between two nodes.
     *
     * When `route.from` is a channel, the next speaker is selected from the
     * channel and asked to speak. Otherwise `route.from` replies to `route.to`
     * and, unless the chat terminates or is interrupted, the roles are swapped
     * and the chat continues.
     *
     * `APIError`s are recorded through `newError`; any other error is rethrown.
     *
     * @param {{ from: string, to: string }} route
     * @param {boolean} keepAlive Whether to keep the chat alive after a direct reply.
     */
    async chat(route, keepAlive = true) {
      if (this.channels.get(route.from)) {
        let nextNode;
        try {
          nextNode = await this.selectNext(route.from);
        } catch (error) {
          if (error instanceof APIError) {
            return this.newError({ from: route.from, to: route.to }, error);
          }
          throw error;
        }

        if (!nextNode) {
          // TODO: should it throw an error or keep the chat alive when there is no node to chat with in the group?
          // maybe it should wrap up the chat and reply to the original node
          // For now, it will terminate the chat
          this.terminate(route.from);
          return;
        }

        const nextChat = { from: nextNode, to: route.from };
        if (this.shouldAgentInterrupt(nextNode)) {
          this.interrupt(nextChat);
          return;
        }

        // Count only the messages sent to the channel by its own members.
        const group = this.getGroupMembers(route.from);
        const rounds = this.getHistory({ to: route.from }).filter((chat) =>
          group.includes(chat.from)
        ).length;
        const { maxRounds } = this.getChannelConfig(route.from);
        if (rounds >= maxRounds) {
          this.terminate(route.to);
          return;
        }

        await this.chat(nextChat);
        return;
      }

      // Direct message: let `route.from` reply.
      let reply = "";
      try {
        reply = await this.reply(route);
      } catch (error) {
        if (error instanceof APIError) {
          return this.newError({ from: route.from, to: route.to }, error);
        }
        throw error;
      }

      if (
        reply === "TERMINATE" ||
        this.hasReachedMaximumRounds(route.from, route.to)
      ) {
        this.terminate(route.to);
        return;
      }

      const newChat = { to: route.from, from: route.to };
      if (
        reply === "INTERRUPT" ||
        (this.agents.get(route.to) && this.shouldAgentInterrupt(route.to))
      ) {
        this.interrupt(newChat);
        return;
      }

      if (keepAlive) {
        // keep the chat alive by replying to the other node
        await this.chat(newChat, true);
      }
    }

    /**
     * Check if the agent should interrupt the chat based on its configuration
     * or on the instance-wide default.
     *
     * @param {string} agent
     * @throws {Error} When the agent configuration is not found.
     * @returns {boolean} Whether the agent should interrupt the chat.
     */
    shouldAgentInterrupt(agent = "") {
      const config = this.getAgentConfig(agent);
      return this.defaultInterrupt === "ALWAYS" || config.interrupt === "ALWAYS";
    }

    /**
     * Select the next node to chat with from a channel.
     *
     * Candidates are the channel members that have not reached the maximum
     * number of rounds, minus whoever sent the most recent message to the
     * channel. The provider is asked to choose among them; if its answer is not
     * one of the candidates, a random candidate is returned instead.
     *
     * @param {string} channel The name of the channel.
     * @returns {Promise<string|undefined>} The chosen node, or undefined when there is no candidate.
     */
    async selectNext(channel = "") {
      const nodes = this.getGroupMembers(channel);
      const channelConfig = this.getChannelConfig(channel);

      // TODO: move this to when the group is created
      if (nodes.length < 3) {
        console.warn(
          `- Group (${channel}) is underpopulated with ${nodes.length} agents. Direct communication would be more efficient.`
        );
      }

      const availableNodes = nodes.filter(
        (node) => !this.hasReachedMaximumRounds(channel, node)
      );

      // The last node that wrote to the channel must not speak twice in a row.
      const lastChat = this._chats.filter((c) => c.to === channel).at(-1);
      if (lastChat) {
        const index = availableNodes.indexOf(lastChat.from);
        if (index > -1) {
          availableNodes.splice(index, 1);
        }
      }

      // TODO: what should it do when there is no node to chat with?
      if (!availableNodes.length) return;

      // Channel settings win over the instance defaults; "gpt-4" is only the
      // fallback model when neither of them specifies one.
      const provider = this.getProviderForConfig({
        model: "gpt-4",
        ...this.defaultProvider,
        ...channelConfig,
      });

      const roles = availableNodes
        .map((node) => `@${node}: ${this.getAgentConfig(node).role}`)
        .join("\n");
      const transcript = this.getHistory({ to: channel })
        .map((c) => `@${c.from}: ${c.content}`)
        .join("\n");
      const messages = [
        {
          role: "system",
          content: channelConfig.role,
        },
        {
          role: "user",
          content: [
            "You are in a role play game. The following roles are available:",
            `${roles}.`,
            "Read the following conversation.",
            "CHAT HISTORY",
            transcript,
            "Then select the next role from that is going to speak next.",
            "Only return the role.",
            "",
          ].join("\n"),
        },
      ];

      const { result } = await provider.complete(messages);
      // Tolerate surrounding whitespace and a leading "@" in the answer.
      const name =
        typeof result === "string" ? result.trim().replace(/^@/, "") : "";

      // Only accept a candidate: any other answer (unknown name, the previous
      // speaker, an exhausted or non-member agent) falls back to a random pick.
      if (availableNodes.includes(name)) {
        return name;
      }
      return availableNodes[Math.floor(Math.random() * availableNodes.length)];
    }

    /**
     * Resolve the registered function name for a plugin reference.
     *
     * - `parent#child` resolves to `child` (the text after the first `#`).
     * - `@@custom-plugin-name` resolves to `custom-plugin-name`.
     * - Anything else is returned unchanged.
     *
     * @param {string} pluginName The name of the plugin being called.
     * @returns {string} The name under which the function is registered.
     */
    #parseFunctionName(pluginName = "") {
      if (pluginName.startsWith("@@")) return pluginName.replace("@@", "");
      if (!pluginName.includes("#")) return pluginName;
      return pluginName.split("#")[1];
    }

    /**
     * Check if the chat between two nodes has reached the maximum number of rounds.
     *
     * @param {string} from
     * @param {string} to
     * @returns {boolean}
     */
    hasReachedMaximumRounds(from = "", to = "") {
      return this.getHistory({ from, to }).length >= this.maxRounds;
    }

    /**
     * Ask the AI provider to generate a reply and record it in the history.
     *
     * @param {object} route
     * @param {string} route.to The node that sent the chat.
     * @param {string} route.from The node that will reply to the chat.
     * @returns {Promise<string>} The content of the reply.
     */
    async reply(route) {
      const fromConfig = this.getAgentConfig(route.from);

      let chatHistory;
      if (this.channels.get(route.to)) {
        // Replying to a channel: send the whole channel transcript as one prompt.
        const transcript = this.getHistory({ to: route.to })
          .map((c) => `@${c.from}: ${c.content}`)
          .join("\n");
        chatHistory = [
          {
            role: "user",
            content: [
              "You are in a whatsapp group. Read the following conversation and then reply.",
              "Do not add introduction or conclusion to your reply because this will be a continuous conversation. Don't introduce yourself.",
              "CHAT HISTORY",
              transcript,
              `@${route.from}:`,
            ].join("\n"),
          },
        ];
      } else {
        // Direct chat: messages from the other node are "user" turns.
        chatHistory = this.getHistory(route).map((c) => ({
          content: c.content,
          role: c.from === route.to ? "user" : "assistant",
        }));
      }

      const messages = [
        {
          content: fromConfig.role,
          role: "system",
        },
        ...chatHistory,
      ];

      // Functions the replying node may call; unknown names are dropped.
      const functions = fromConfig.functions
        ?.map((name) => this.functions.get(this.#parseFunctionName(name)))
        .filter((fn) => !!fn);

      const provider = this.getProviderForConfig({
        ...this.defaultProvider,
        ...fromConfig,
      });

      const content = await this.handleExecution(
        provider,
        messages,
        functions,
        route.from
      );
      this.newMessage({ ...route, content });
      return content;
    }

    /**
     * Run a completion and resolve any function call the provider requests.
     *
     * While the completion carries a `functionCall`, the function is executed
     * (or a "not found" notice is produced), its output is appended to the
     * messages as a "function" message, and the provider is asked again.
     * There is no cap on the number of such rounds.
     *
     * @param {object} provider Provider exposing `complete(messages, functions)`.
     * @param {Array<object>} messages
     * @param {Array<object>} functions
     * @param {string|null} byAgent Name of the calling agent, stored on the function as `caller`.
     * @returns {Promise<*>} The `result` of the first completion without a function call.
     */
    async handleExecution(
      provider,
      messages = [],
      functions = [],
      byAgent = null
    ) {
      const completion = await provider.complete(messages, functions);
      if (!completion.functionCall) {
        return completion?.result;
      }

      const { name, arguments: args } = completion.functionCall;
      const fn = this.functions.get(name);

      // The provider hallucinated a function name: tell it and ask again.
      if (!fn) {
        return await this.handleExecution(
          provider,
          [
            ...messages,
            {
              name,
              role: "function",
              content: `Function "${name}" not found. Try again.`,
            },
          ],
          functions,
          byAgent
        );
      }

      fn.caller = byAgent || "agent";

      // For OSS LLMs we really need to keep tabs on what they are calling
      // so we can log it here. Both sinks are optional: `introspect` is not
      // defined by this class and `handlerProps` defaults to `{}`.
      if (provider?.verbose) {
        const notice = `[debug]: ${fn.caller} is attempting to call \`${name}\` tool`;
        this.introspect?.(notice);
        this.handlerProps?.log?.(notice);
      }

      const result = await fn.handler(args);
      Telemetry.sendTelemetry("agent_tool_call", { tool: name }, null, true);

      return await this.handleExecution(
        provider,
        [
          ...messages,
          {
            name,
            role: "function",
            content: result,
          },
        ],
        functions,
        byAgent
      );
    }

    /**
     * Continue the chat from the last interruption.
     * The interruption entry is removed from the history before continuing.
     *
     * @param {string} [feedback] Feedback to record as a message at the point of interruption.
     * @throws {Error} When the last chat is not an interruption, or the maximum rounds were reached.
     * @returns {Promise<AIbitat>} this, once the conversation stops.
     */
    async continue(feedback) {
      const lastChat = this._chats.at(-1);
      if (!lastChat || lastChat.state !== "interrupt") {
        throw new Error("No chat to continue");
      }

      // remove the last chat that was interrupted
      this._chats.pop();

      const { from, to } = lastChat;
      if (this.hasReachedMaximumRounds(from, to)) {
        throw new Error("Maximum rounds reached");
      }

      if (feedback) {
        // Record the feedback as a message, then let its recipient reply.
        this.newMessage({ from, to, content: feedback });
        await this.chat({ to: from, from: to });
      } else {
        await this.chat({ from, to });
      }
      return this;
    }

    /**
     * Retry the last chat that ended in an error.
     * The error entry is removed from the history before retrying.
     *
     * @throws {Error} When the last chat is not an error.
     * @returns {Promise<AIbitat>} this, once the conversation stops.
     */
    async retry() {
      const lastChat = this._chats.at(-1);
      if (!lastChat || lastChat.state !== "error") {
        throw new Error("No chat to retry");
      }

      // remove the last chat that threw an error
      const { from, to } = this._chats.pop();
      await this.chat({ from, to });
      return this;
    }

    /**
     * Get the successful chats between two nodes, or all successful chats
     * to a node (`from` omitted) or from a node (`to` omitted).
     *
     * @param {{ from?: string, to?: string }} route
     * @returns {Array<object>} The matching history entries, in order.
     */
    getHistory({ from, to }) {
      return this._chats.filter((chat) => {
        if (chat.state !== "success") return false;
        if (!from) return chat.to === to;
        if (!to) return chat.from === from;
        // Between the two nodes, in either direction.
        const hasSent = chat.from === from && chat.to === to;
        const hasReceived = chat.from === to && chat.to === from;
        return hasSent || hasReceived;
      });
    }

    /**
     * Get a provider based on a configuration.
     * A provider instance (non-null object) is returned as-is; a string
     * identifier creates a new provider of the matching class.
     *
     * @param {{ provider: string|object, model?: string }} config The provider configuration.
     * @throws {Error} When the provider identifier is unknown (including null/undefined).
     * @returns {object} The provider.
     */
    getProviderForConfig(config) {
      // `typeof null` is "object", so null must be excluded explicitly.
      if (config.provider !== null && typeof config.provider === "object") {
        return config.provider;
      }

      const entry = AIbitat.#providerClasses.get(config.provider);
      if (!entry) {
        throw new Error(
          `Unknown provider: ${config.provider}. Please use "openai"`
        );
      }

      const ProviderClass = Providers[entry.className];
      return new ProviderClass(entry.usesModel ? { model: config.model } : {});
    }

    /**
     * Register a new function to be called by the AIbitat agents, keyed by
     * `functionConfig.name`. Agents opt in by listing the name in their
     * `functions` configuration.
     *
     * @param {{ name: string, handler: Function }} functionConfig The function configuration.
     * @returns {AIbitat} this, for chaining.
     */
    function(functionConfig) {
      this.functions.set(functionConfig.name, functionConfig);
      return this;
    }
  }
  // CommonJS export: the AIbitat class is this module's only export.
  module.exports = AIbitat;