const OpenAI = require("openai");
const Provider = require("./ai-provider.js");
const InheritMultiple = require("./helpers/classes.js");
const UnTooled = require("./helpers/untooled.js");
const {
  parseLMStudioBasePath,
} = require("../../../AiProviders/lmStudio/index.js");

/**
 * The agent provider for LMStudio.
 */
class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
  model;

  /**
   * @param {{model?: string}} config
   */
  constructor(config = {}) {
    super();
    const model =
      config?.model || process.env.LMSTUDIO_MODEL_PREF || "Loaded from Chat UI";
    const client = new OpenAI({
      baseURL: parseLMStudioBasePath(process.env.LMSTUDIO_BASE_PATH),
      apiKey: null,
      maxRetries: 3,
    });
    this._client = client;
    this.model = model;
    this.verbose = true;
  }

  get client() {
    return this._client;
  }

  /**
   * Run a single deterministic chat completion used for tool selection and
   * return the raw text content of the first choice, or null on any error.
   */
  async #handleFunctionCallChat({ messages = [] }) {
    return await this.client.chat.completions
      .create({
        model: this.model,
        temperature: 0,
        messages,
      })
      .then((result) => {
        if (!result.hasOwnProperty("choices"))
          throw new Error("LMStudio chat: No results!");
        if (result.choices.length === 0)
          throw new Error("LMStudio chat: No results length!");
        return result.choices[0].message.content;
      })
      .catch((_) => {
        return null;
      });
  }

  /**
   * Create a completion based on the received messages.
   *
   * @param messages A list of messages to send to the API.
   * @param functions A list of functions the model may choose to call.
   * @returns The completion.
   */
  async complete(messages, functions = null) {
    try {
      let completion;
      // Guard against a null or empty functions list before attempting a tool call.
      if (functions?.length > 0) {
        const { toolCall, text } = await this.functionCall(
          messages,
          functions,
          this.#handleFunctionCallChat.bind(this)
        );

        if (toolCall !== null) {
          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
          return {
            result: null,
            functionCall: {
              name: toolCall.name,
              arguments: toolCall.arguments,
            },
            cost: 0,
          };
        }
        completion = { content: text };
      }

      if (!completion?.content) {
        this.providerLog(
          "Will assume chat completion without tool call inputs."
        );
        const response = await this.client.chat.completions.create({
          model: this.model,
          messages: this.cleanMsgs(messages),
        });
        completion = response.choices[0].message;
      }

      // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
      // from calling the exact same function over and over in a loop within a single chat exchange
      // _but_ we should enable it to call previously used tools in a new chat interaction.
      this.deduplicator.reset("runs");
      return {
        result: completion.content,
        cost: 0,
      };
    } catch (error) {
      throw error;
    }
  }

  /**
   * Get the cost of the completion.
   *
   * @param _usage The completion to get the cost for.
   * @returns The cost of the completion.
   * Stubbed since LMStudio has no cost basis.
   */
  getCost(_usage) {
    return 0;
  }
}

module.exports = LMStudioProvider;
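
// Usage sketch (illustrative only): a minimal example of driving this provider
// directly, assuming LMSTUDIO_BASE_PATH points at a running LMStudio server and
// that this file is required as "./lmstudio.js". The `demo` function and its
// message contents are hypothetical and not part of this module.
//
//   const LMStudioProvider = require("./lmstudio.js");
//
//   async function demo() {
//     const provider = new LMStudioProvider({ model: "my-local-model" });
//     // Passing an empty functions list skips tool calling and falls through
//     // to a plain chat completion.
//     const { result, cost } = await provider.complete(
//       [{ role: "user", content: "Say hello." }],
//       []
//     );
//     console.log(result, cost); // cost is always 0 for LMStudio
//   }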