diff --git a/packages/bbrt/src/llm/conversation.ts b/packages/bbrt/src/llm/conversation.ts index 4facef3ede..19cc4790fb 100644 --- a/packages/bbrt/src/llm/conversation.ts +++ b/packages/bbrt/src/llm/conversation.ts @@ -24,6 +24,8 @@ import type { JsonSerializableObject } from "../util/json-serializable.js"; import { coercePresentableError } from "../util/presentable-error.js"; import type { Result } from "../util/result.js"; +const MAX_TOOL_ITERATIONS = 5; + export interface ConversationOptions { state: ReactiveSessionState; drivers: Map; @@ -99,25 +101,22 @@ export class Conversation { const done = (async (): Promise => { const activeTools = await this.#getActiveTools(); - const { functionCalls } = await this.#callModel( - driver.value, - activeTools, - initialTimestamp, - systemPrompt - ); - if (functionCalls.length > 0) { - await Promise.all( - functionCalls.map((call) => - this.#executeFunctionCall(call, activeTools) + let remainingModelCalls = 1 + Math.max(0, MAX_TOOL_ITERATIONS); + let functionCalls: ReactiveFunctionCallState[]; + do { + const allowFunctionCalls = remainingModelCalls > 1; + functionCalls = ( + await this.#callModel( + driver.value, + allowFunctionCalls ? activeTools : undefined, + systemPrompt ) - ); - await this.#callModel( - driver.value, - undefined, - this.#clock.now(), - systemPrompt - ); - } + ).functionCalls; + if (functionCalls.length > 0) { + await this.#executeFunctionCalls(functionCalls, activeTools); + } + remainingModelCalls--; + } while (functionCalls.length > 0); this.#status = "ready"; })(); @@ -162,9 +161,9 @@ export class Conversation { async #callModel( driver: BBRTDriver, tools: Map | undefined, - timestamp: number, systemPrompt: string ): Promise<{ functionCalls: ReactiveFunctionCallState[] }> { + const timestamp = this.#clock.now(); // TODO(aomarks) This is a little weird. The natural thing to do would seem // to be to create a ReactiveSessionEventTurn, and pass it in. But in fact // our State constructors treat all initializer data as pure data, so that @@ -187,7 +186,8 @@ export class Conversation { this.state.events.push(event); const turn = (event.detail as ReactiveSessionEventTurn).turn; const functionCalls = []; - // Don't include the pending turn we just created. + // Don't include the pending turn we just created. We want the user to see + // it, but not the model. const slice = this.state.turns.slice(0, -1); try { const chunks = driver.send({ @@ -197,7 +197,7 @@ export class Conversation { }); for await (const chunk of chunks) { if (chunk.kind === "function-call") { - if (tools !== undefined) { + if (tools?.size) { const call = new ReactiveFunctionCallState(chunk.call); functionCalls.push(call); turn.chunks.push({ @@ -227,6 +227,15 @@ export class Conversation { return { functionCalls }; } + #executeFunctionCalls( + calls: ReactiveFunctionCallState[], + tools: Map + ): Promise { + return Promise.all( + calls.map((call) => this.#executeFunctionCall(call, tools)) + ); + } + async #executeFunctionCall( call: ReactiveFunctionCallState, tools: Map