You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

117 lines
3.2 KiB

11 months ago
  1. const OpenAI = require("openai");
  2. const Provider = require("./ai-provider.js");
  3. const InheritMultiple = require("./helpers/classes.js");
  4. const UnTooled = require("./helpers/untooled.js");
  5. /**
  6. * The agent provider for the Nvidia NIM provider.
  7. * We wrap Nvidia NIM in UnTooled because its tool-calling may not be supported for specific models and this normalizes that.
  8. */
  9. class NvidiaNimProvider extends InheritMultiple([Provider, UnTooled]) {
  10. model;
  11. constructor(config = {}) {
  12. const { model } = config;
  13. super();
  14. const client = new OpenAI({
  15. baseURL: process.env.NVIDIA_NIM_LLM_BASE_PATH,
  16. apiKey: null,
  17. maxRetries: 0,
  18. });
  19. this._client = client;
  20. this.model = model;
  21. this.verbose = true;
  22. }
  23. get client() {
  24. return this._client;
  25. }
  26. async #handleFunctionCallChat({ messages = [] }) {
  27. return await this.client.chat.completions
  28. .create({
  29. model: this.model,
  30. temperature: 0,
  31. messages,
  32. })
  33. .then((result) => {
  34. if (!result.hasOwnProperty("choices"))
  35. throw new Error("NVIDIA NIM chat: No results!");
  36. if (result.choices.length === 0)
  37. throw new Error("NVIDIA NIM chat: No results length!");
  38. return result.choices[0].message.content;
  39. })
  40. .catch((_) => {
  41. return null;
  42. });
  43. }
  44. /**
  45. * Create a completion based on the received messages.
  46. *
  47. * @param messages A list of messages to send to the API.
  48. * @param functions
  49. * @returns The completion.
  50. */
  51. async complete(messages, functions = null) {
  52. try {
  53. let completion;
  54. if (functions.length > 0) {
  55. const { toolCall, text } = await this.functionCall(
  56. messages,
  57. functions,
  58. this.#handleFunctionCallChat.bind(this)
  59. );
  60. if (toolCall !== null) {
  61. this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
  62. this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
  63. return {
  64. result: null,
  65. functionCall: {
  66. name: toolCall.name,
  67. arguments: toolCall.arguments,
  68. },
  69. cost: 0,
  70. };
  71. }
  72. completion = { content: text };
  73. }
  74. if (!completion?.content) {
  75. this.providerLog(
  76. "Will assume chat completion without tool call inputs."
  77. );
  78. const response = await this.client.chat.completions.create({
  79. model: this.model,
  80. messages: this.cleanMsgs(messages),
  81. });
  82. completion = response.choices[0].message;
  83. }
  84. // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
  85. // from calling the exact same function over and over in a loop within a single chat exchange
  86. // _but_ we should enable it to call previously used tools in a new chat interaction.
  87. this.deduplicator.reset("runs");
  88. return {
  89. result: completion.content,
  90. cost: 0,
  91. };
  92. } catch (error) {
  93. throw error;
  94. }
  95. }
  96. /**
  97. * Get the cost of the completion.
  98. *
  99. * @param _usage The completion to get the cost for.
  100. * @returns The cost of the completion.
  101. */
  102. getCost(_usage) {
  103. return 0;
  104. }
  105. }
  106. module.exports = NvidiaNimProvider;