diff --git a/packages/cli/src/core/gemini-client.ts b/packages/cli/src/core/gemini-client.ts index be338754..19dba40f 100644 --- a/packages/cli/src/core/gemini-client.ts +++ b/packages/cli/src/core/gemini-client.ts @@ -15,31 +15,17 @@ import { Content, } from '@google/genai'; import { CoreSystemPrompt } from './prompts.js'; -import { - type ToolCallEvent, - type ToolCallConfirmationDetails, - ToolCallStatus, -} from '../ui/types.js'; import process from 'node:process'; import { toolRegistry } from '../tools/tool-registry.js'; -import { ToolResult } from '../tools/tools.js'; import { getFolderStructure } from '../utils/getFolderStructure.js'; import { GeminiEventType, GeminiStream } from './gemini-stream.js'; import { Config } from '../config/config.js'; - -type ToolExecutionOutcome = { - callId: string; - name: string; - args: Record; - result?: ToolResult; - error?: Error; - confirmationDetails?: ToolCallConfirmationDetails; -}; +import { Turn } from './turn.js'; export class GeminiClient { private config: Config; private ai: GoogleGenAI; - private defaultHyperParameters: GenerateContentConfig = { + private generateContentConfig: GenerateContentConfig = { temperature: 0, topP: 1, }; @@ -50,14 +36,9 @@ export class GeminiClient { this.ai = new GoogleGenAI({ apiKey: config.getApiKey() }); } - async startChat(): Promise { - const tools = toolRegistry.getToolSchemas(); - const model = this.config.getModel(); - - // --- Get environmental information --- + private async getEnvironment(): Promise { const cwd = process.cwd(); const today = new Date().toLocaleDateString(undefined, { - // Use locale-aware date formatting weekday: 'long', year: 'numeric', month: 'long', @@ -65,33 +46,37 @@ export class GeminiClient { }); const platform = process.platform; - // --- Format information into a conversational multi-line string --- const folderStructure = await getFolderStructure(cwd); - // --- End folder structure formatting ---) - const initialContextText = ` -Okay, just setting 
up the context for our chat. -Today is ${today}. -My operating system is: ${platform} -I'm currently working in the directory: ${cwd} -${folderStructure} - `.trim(); - const initialContextPart: Part = { text: initialContextText }; - // --- End environmental information formatting --- + const context = ` + Okay, just setting up the context for our chat. + Today is ${today}. + My operating system is: ${platform} + I'm currently working in the directory: ${cwd} + ${folderStructure} + `.trim(); + + return { text: context }; + } + + async startChat(): Promise { + const envPart = await this.getEnvironment(); + const model = this.config.getModel(); + const tools = toolRegistry.getToolSchemas(); try { const chat = this.ai.chats.create({ model, config: { systemInstruction: CoreSystemPrompt, - ...this.defaultHyperParameters, + ...this.generateContentConfig, tools, }, history: [ // --- Add the context as a single part in the initial user message --- { role: 'user', - parts: [initialContextPart], // Pass the single Part object in an array + parts: [envPart], // Pass the single Part object in an array }, // --- Add an empty model response to balance the history --- { @@ -109,308 +94,113 @@ ${folderStructure} } } - addMessageToHistory(chat: Chat, message: Content): void { - const history = chat.getHistory(); - history.push(message); - } - async *sendMessageStream( chat: Chat, request: PartListUnion, signal?: AbortSignal, ): GeminiStream { - let currentMessageToSend: PartListUnion = request; let turns = 0; try { while (turns < this.MAX_TURNS) { turns++; - const resultStream = await chat.sendMessageStream({ - message: currentMessageToSend, - }); - let functionResponseParts: Part[] = []; - let pendingToolCalls: Array<{ - callId: string; - name: string; - args: Record; - }> = []; - let yieldedTextInTurn = false; - const chunksForDebug = []; + // A turn either yields a text response or returns + // function responses to be used in the next turn. 
+ // This callsite is responsible to handle the buffered + // function responses and use it on the next turn. + const turn = new Turn(chat); + const resultStream = turn.run(request, signal); - for await (const chunk of resultStream) { - chunksForDebug.push(chunk); - if (signal?.aborted) { - const abortError = new Error( - 'Request cancelled by user during stream.', - ); - abortError.name = 'AbortError'; - throw abortError; - } - - const functionCalls = chunk.functionCalls; - if (functionCalls && functionCalls.length > 0) { - for (const call of functionCalls) { - const callId = - call.id ?? - `${call.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`; - const name = call.name || 'undefined_tool_name'; - const args = (call.args || {}) as Record; - - pendingToolCalls.push({ callId, name, args }); - const evtValue: ToolCallEvent = { - type: 'tool_call', - status: ToolCallStatus.Pending, - callId, - name, - args, - resultDisplay: undefined, - confirmationDetails: undefined, - }; - yield { - type: GeminiEventType.ToolCallInfo, - value: evtValue, - }; - } - } else { - const text = chunk.text; - if (text) { - yieldedTextInTurn = true; - yield { - type: GeminiEventType.Content, - value: text, - }; - } - } + for await (const event of resultStream) { + yield event; + } + const fnResponses = turn.getFunctionResponses(); + if (fnResponses.length > 0) { + request = fnResponses; + continue; // use the responses in the next turn } - if (pendingToolCalls.length > 0) { - const toolPromises: Array> = - pendingToolCalls.map(async (pendingToolCall) => { - const tool = toolRegistry.getTool(pendingToolCall.name); + const history = chat.getHistory(); + const checkPrompt = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you). 
- if (!tool) { - // Directly return error outcome if tool not found - return { - ...pendingToolCall, - error: new Error( - `Tool "${pendingToolCall.name}" not found or is not registered.`, - ), - }; - } + **Decision Rules (apply in order):** - try { - const confirmation = await tool.shouldConfirmExecute( - pendingToolCall.args, - ); - if (confirmation) { - return { - ...pendingToolCall, - confirmationDetails: confirmation, - }; - } - } catch (error) { - return { - ...pendingToolCall, - error: new Error( - `Tool failed to check tool confirmation: ${error}`, - ), - }; - } + 1. **Model Continues:** If your last response explicitly states an immediate next action *you* intend to take (e.g., "Next, I will...", "Now I'll process...", "Moving on to analyze...", indicates an intended tool call that didn't execute), OR if the response seems clearly incomplete (cut off mid-thought without a natural conclusion), then the **'model'** should speak next. + 2. **Question to User:** If your last response ends with a direct question specifically addressed *to the user*, then the **'user'** should speak next. + 3. **Waiting for User:** If your last response completed a thought, statement, or task *and* does not meet the criteria for Rule 1 (Model Continues) or Rule 2 (Question to User), it implies a pause expecting user input or reaction. In this case, the **'user'** should speak next. - try { - const result = await tool.execute(pendingToolCall.args); - return { ...pendingToolCall, result }; - } catch (error) { - return { - ...pendingToolCall, - error: new Error(`Tool failed to execute: ${error}`), - }; - } - }); - const toolExecutionOutcomes: ToolExecutionOutcome[] = - await Promise.all(toolPromises); + **Output Format:** - for (const executedTool of toolExecutionOutcomes) { - const { callId, name, args, result, error, confirmationDetails } = - executedTool; + Respond *only* in JSON format according to the following schema. Do not include any text outside the JSON structure. 
- if (error) { - const errorMessage = error?.message || String(error); - yield { - type: GeminiEventType.Content, - value: `[Error invoking tool ${name}: ${errorMessage}]`, - }; - } else if ( - result && - typeof result === 'object' && - result !== null && - 'error' in result - ) { - const errorMessage = String(result.error); - yield { - type: GeminiEventType.Content, - value: `[Error executing tool ${name}: ${errorMessage}]`, - }; - } else { - const status = confirmationDetails - ? ToolCallStatus.Confirming - : ToolCallStatus.Invoked; - const evtValue: ToolCallEvent = { - type: 'tool_call', - status, - callId, - name, - args, - resultDisplay: result?.returnDisplay, - confirmationDetails, - }; - yield { - type: GeminiEventType.ToolCallInfo, - value: evtValue, - }; - } - } - - pendingToolCalls = []; - - const waitingOnConfirmations = - toolExecutionOutcomes.filter( - (outcome) => outcome.confirmationDetails, - ).length > 0; - if (waitingOnConfirmations) { - // Stop processing content, wait for user. - // TODO: Kill token processing once API supports signals. 
- break; - } - - functionResponseParts = toolExecutionOutcomes.map( - (executedTool: ToolExecutionOutcome): Part => { - const { name, result, error } = executedTool; - const output = { output: result?.llmContent }; - let toolOutcomePayload: Record; - - if (error) { - const errorMessage = error?.message || String(error); - toolOutcomePayload = { - error: `Invocation failed: ${errorMessage}`, - }; - console.error( - `[Turn ${turns}] Critical error invoking tool ${name}:`, - error, - ); - } else if ( - result && - typeof result === 'object' && - result !== null && - 'error' in result - ) { - toolOutcomePayload = output; - console.warn( - `[Turn ${turns}] Tool ${name} returned an error structure:`, - result.error, - ); - } else { - toolOutcomePayload = output; - } - - return { - functionResponse: { - name, - id: executedTool.callId, - response: toolOutcomePayload, - }, - }; - }, - ); - currentMessageToSend = functionResponseParts; - } else if (yieldedTextInTurn) { - const history = chat.getHistory(); - const checkPrompt = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you). - -**Decision Rules (apply in order):** - -1. **Model Continues:** If your last response explicitly states an immediate next action *you* intend to take (e.g., "Next, I will...", "Now I'll process...", "Moving on to analyze...", indicates an intended tool call that didn't execute), OR if the response seems clearly incomplete (cut off mid-thought without a natural conclusion), then the **'model'** should speak next. -2. **Question to User:** If your last response ends with a direct question specifically addressed *to the user*, then the **'user'** should speak next. -3. 
**Waiting for User:** If your last response completed a thought, statement, or task *and* does not meet the criteria for Rule 1 (Model Continues) or Rule 2 (Question to User), it implies a pause expecting user input or reaction. In this case, the **'user'** should speak next. - -**Output Format:** - -Respond *only* in JSON format according to the following schema. Do not include any text outside the JSON structure. - -\`\`\`json -{ - "type": "object", - "properties": { - "reasoning": { + \`\`\`json + { + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn." + }, + "next_speaker": { "type": "string", - "description": "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn." + "enum": ["user", "model"], + "description": "Who should speak next based *only* on the preceding turn and the decision rules." + } }, - "next_speaker": { - "type": "string", - "enum": ["user", "model"], - "description": "Who should speak next based *only* on the preceding turn and the decision rules." 
- } - }, - "required": ["next_speaker", "reasoning"] -\`\`\` -}`; + "required": ["next_speaker", "reasoning"] + \`\`\` + }`; - // Schema Idea - const responseSchema: SchemaUnion = { - type: Type.OBJECT, - properties: { - reasoning: { - type: Type.STRING, - description: - "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn.", - }, - next_speaker: { - type: Type.STRING, - enum: ['user', 'model'], // Enforce the choices - description: - 'Who should speak next based *only* on the preceding turn and the decision rules', - }, + // Schema Idea + const responseSchema: SchemaUnion = { + type: Type.OBJECT, + properties: { + reasoning: { + type: Type.STRING, + description: + "Brief explanation justifying the 'next_speaker' choice based *strictly* on the applicable rule and the content/structure of the preceding turn.", }, - required: ['reasoning', 'next_speaker'], - }; + next_speaker: { + type: Type.STRING, + enum: ['user', 'model'], // Enforce the choices + description: + 'Who should speak next based *only* on the preceding turn and the decision rules', + }, + }, + required: ['reasoning', 'next_speaker'], + }; - try { - // Use the new generateJson method, passing the history and the check prompt - const parsedResponse = await this.generateJson( - [ - ...history, - { - role: 'user', - parts: [{ text: checkPrompt }], - }, - ], - responseSchema, - ); + try { + // Use the new generateJson method, passing the history and the check prompt + const parsedResponse = await this.generateJson( + [ + ...history, + { + role: 'user', + parts: [{ text: checkPrompt }], + }, + ], + responseSchema, + ); - // Safely extract the next speaker value - const nextSpeaker: string | undefined = - typeof parsedResponse?.next_speaker === 'string' - ? 
parsedResponse.next_speaker - : undefined; + // Safely extract the next speaker value + const nextSpeaker: string | undefined = + typeof parsedResponse?.next_speaker === 'string' + ? parsedResponse.next_speaker + : undefined; - if (nextSpeaker === 'model') { - currentMessageToSend = { text: 'alright' }; // Or potentially a more meaningful continuation prompt - } else { - // 'user' should speak next, or value is missing/invalid. End the turn. - break; - } - } catch (error) { - console.error( - `[Turn ${turns}] Failed to get or parse next speaker check:`, - error, - ); - // If the check fails, assume user should speak next to avoid infinite loops + if (nextSpeaker === 'model') { + request = { text: 'alright' }; // Or potentially a more meaningful continuation prompt + } else { + // 'user' should speak next, or value is missing/invalid. End the turn. break; } - } else { - console.warn( - `[Turn ${turns}] No text or function calls received from Gemini. Ending interaction.`, + } catch (error) { + console.error( + `[Turn ${turns}] Failed to get or parse next speaker check:`, + error, ); + // If the check fails, assume user should speak next to avoid infinite loops break; } } @@ -426,6 +216,8 @@ Respond *only* in JSON format according to the following schema. Do not include }; } } catch (error: unknown) { + // TODO(jbd): There is so much of packing/unpacking of error types. + // Figure out a way to remove the redundant work. if (error instanceof Error && error.name === 'AbortError') { console.log('Gemini stream request aborted by user.'); throw error; @@ -457,7 +249,7 @@ Respond *only* in JSON format according to the following schema. 
Do not include const result = await this.ai.models.generateContent({ model, config: { - ...this.defaultHyperParameters, + ...this.generateContentConfig, systemInstruction: CoreSystemPrompt, responseSchema: schema, responseMimeType: 'application/json', diff --git a/packages/cli/src/core/turn.ts b/packages/cli/src/core/turn.ts new file mode 100644 index 00000000..e8c4ef78 --- /dev/null +++ b/packages/cli/src/core/turn.ts @@ -0,0 +1,233 @@ +import { + Part, + Chat, + PartListUnion, + GenerateContentResponse, + FunctionCall, +} from '@google/genai'; +import { + type ToolCallConfirmationDetails, + ToolCallStatus, + ToolCallEvent, +} from '../ui/types.js'; +import { ToolResult } from '../tools/tools.js'; +import { toolRegistry } from '../tools/tool-registry.js'; +import { GeminiEventType, GeminiStream } from './gemini-stream.js'; + +export type ToolExecutionOutcome = { + callId: string; + name: string; + args: Record; + result?: ToolResult; + error?: Error; + confirmationDetails?: ToolCallConfirmationDetails; +}; + +// TODO(jbd): Move ToolExecutionOutcome to somewhere else? + +// A turn manages the agentic loop turn. +// Turn.run emits throught the turn events that could be used +// as immediate feedback to the user. 
/** A tool call requested by the model that has not been executed yet. */
type PendingToolCall = {
  callId: string;
  name: string;
  args: Record<string, unknown>;
};

/**
 * Runs one request/response exchange against a {@link Chat}.
 *
 * `run` streams the model output as events, collects every function call the
 * model requested, executes the matching tools once the stream has ended and
 * buffers one function-response part per call. The caller reads the buffered
 * parts with `getFunctionResponses()` and is responsible for sending them as
 * the next request.
 */
export class Turn {
  private readonly chat: Chat;
  private pendingToolCalls: PendingToolCall[];
  private fnResponses: Part[];
  private debugResponses: GenerateContentResponse[];

  constructor(chat: Chat) {
    this.chat = chat;
    this.pendingToolCalls = [];
    this.fnResponses = [];
    this.debugResponses = [];
  }

  /**
   * Sends `req` to the chat and yields events while the response streams in.
   *
   * Text is yielded as `Content` events; every requested function call is
   * yielded as a pending `ToolCallInfo` event. After the stream ends all
   * collected calls are executed concurrently, their outcomes are yielded, and
   * the function responses are buffered for `getFunctionResponses()`.
   *
   * @param req - The message parts to send to the model.
   * @param signal - Optional abort signal, checked once per received chunk.
   * @throws Error named `AbortError` when `signal` is aborted mid-stream.
   */
  async *run(req: PartListUnion, signal?: AbortSignal): GeminiStream {
    // Start from a clean buffer so a reused Turn never reports stale responses.
    this.fnResponses = [];

    const responseStream = await this.chat.sendMessageStream({
      message: req,
    });
    for await (const resp of responseStream) {
      this.debugResponses.push(resp);
      if (signal?.aborted) {
        throw this.abortError();
      }

      // A single chunk may carry both text and function calls; handle both
      // instead of dropping the calls whenever text is present.
      const text = resp.text;
      if (text) {
        yield {
          type: GeminiEventType.Content,
          value: text,
        };
      }
      for (const fnCall of resp.functionCalls ?? []) {
        yield* this.handlePendingFunctionCall(fnCall);
      }
    }

    if (this.pendingToolCalls.length === 0) {
      return;
    }

    // Execute only after the whole stream was consumed so that calls spread
    // over several chunks end up in one complete set of function responses.
    const calls = this.pendingToolCalls;
    this.pendingToolCalls = [];
    const outcomes = await Promise.all(
      calls.map((call) => this.executeToolCall(call)),
    );
    yield* this.handleToolOutcomes(outcomes);

    // NOTE(review): outcomes that only carry confirmationDetails (tool not
    // executed yet) still produce a function response with an undefined
    // output. Confirm whether the caller should stop and wait for the user
    // instead of replying to the model in that case.
    // TODO(jbd): Make it harder for the caller to ignore the
    // buffered function responses.
    this.fnResponses = this.buildFunctionResponses(outcomes);
  }

  /**
   * Resolves and runs a single tool call.
   *
   * Never rejects: an unknown tool, a failing confirmation check and a failing
   * execution are all reported through the `error` field of the outcome so one
   * bad tool cannot abort the other calls of the same turn.
   */
  private async executeToolCall(
    call: PendingToolCall,
  ): Promise<ToolExecutionOutcome> {
    const tool = toolRegistry.getTool(call.name);
    if (!tool) {
      return {
        ...call,
        error: new Error(
          `Tool "${call.name}" not found or is not registered.`,
        ),
      };
    }

    try {
      const shouldConfirm = await tool.shouldConfirmExecute(call.args);
      if (shouldConfirm) {
        // TODO(jbd): Should confirm isn't confirmation details.
        return { ...call, confirmationDetails: shouldConfirm };
      }
    } catch (error: unknown) {
      return {
        ...call,
        error: new Error(
          `Tool failed to check tool confirmation: ${Turn.errorText(error)}`,
        ),
      };
    }

    try {
      const result = await tool.execute(call.args);
      return { ...call, result };
    } catch (error: unknown) {
      return {
        ...call,
        error: new Error(`Tool failed to execute: ${Turn.errorText(error)}`),
      };
    }
  }

  /**
   * Registers a function call requested by the model as pending and yields
   * the matching `Pending` tool-call event.
   */
  private async *handlePendingFunctionCall(fnCall: FunctionCall): GeminiStream {
    // Fall back to a locally generated id when the model did not provide one.
    // TODO(jbd): replace with uuid.
    const callId =
      fnCall.id ??
      `${fnCall.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`;
    const name = fnCall.name || 'undefined_tool_name';
    const args = (fnCall.args || {}) as Record<string, unknown>;

    this.pendingToolCalls.push({ callId, name, args });
    const value: ToolCallEvent = {
      type: 'tool_call',
      status: ToolCallStatus.Pending,
      callId,
      name,
      args,
      resultDisplay: undefined,
      confirmationDetails: undefined,
    };
    yield {
      type: GeminiEventType.ToolCallInfo,
      value,
    };
  }

  /**
   * Yields one event per tool outcome: a `Content` event describing the
   * failure for errored outcomes, otherwise a `ToolCallInfo` event whose
   * status is `Confirming` (confirmation details present) or `Invoked`.
   */
  private async *handleToolOutcomes(
    outcomes: ToolExecutionOutcome[],
  ): GeminiStream {
    for (const outcome of outcomes) {
      const { callId, name, args, result, error, confirmationDetails } =
        outcome;
      if (error) {
        // TODO(jbd): Error handling needs a cleanup.
        const errorMessage = error.message || String(error);
        yield {
          type: GeminiEventType.Content,
          value: `[Error invoking tool ${name}: ${errorMessage}]`,
        };
        // Keep going: the remaining outcomes still need their events.
        continue;
      }
      if (result && typeof result === 'object' && 'error' in result) {
        const errorMessage = String(result.error);
        yield {
          type: GeminiEventType.Content,
          value: `[Error executing tool ${name}: ${errorMessage}]`,
        };
        continue;
      }
      const status = confirmationDetails
        ? ToolCallStatus.Confirming
        : ToolCallStatus.Invoked;
      const value: ToolCallEvent = {
        type: 'tool_call',
        status,
        callId,
        name,
        args,
        resultDisplay: result?.returnDisplay,
        confirmationDetails,
      };
      yield {
        type: GeminiEventType.ToolCallInfo,
        value,
      };
    }
  }

  /**
   * Converts tool outcomes into `functionResponse` parts, one per outcome and
   * in the same order. Invocation failures are reported to the model under an
   * `error` key, everything else under `output`.
   */
  private buildFunctionResponses(outcomes: ToolExecutionOutcome[]): Part[] {
    return outcomes.map((outcome: ToolExecutionOutcome): Part => {
      const { name, result, error } = outcome;
      const output = { output: result?.llmContent };
      let fnResponse: Record<string, unknown>;

      if (error) {
        const errorMessage = error.message || String(error);
        fnResponse = {
          error: `Invocation failed: ${errorMessage}`,
        };
        console.error(`[Turn] Critical error invoking tool ${name}:`, error);
      } else if (result && typeof result === 'object' && 'error' in result) {
        // The tool ran but reported an error itself; the model still gets the
        // tool's own content.
        fnResponse = output;
        console.warn(
          `[Turn] Tool ${name} returned an error structure:`,
          result.error,
        );
      } else {
        fnResponse = output;
      }

      return {
        functionResponse: {
          name,
          id: outcome.callId,
          response: fnResponse,
        },
      };
    });
  }

  /** Builds (does not throw) the error used to signal a user cancellation. */
  private abortError(): Error {
    // TODO(jbd): Move it out of this class.
    const error = new Error('Request cancelled by user during stream.');
    error.name = 'AbortError';
    return error;
  }

  /** Extracts a readable message from an arbitrary thrown value. */
  private static errorText(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }

  /**
   * Returns the function-response parts buffered by the last `run`; empty
   * when the model requested no tool calls.
   */
  getFunctionResponses(): Part[] {
    return this.fnResponses;
  }

  /** Returns every raw response chunk received so far, for debugging. */
  getDebugResponses(): GenerateContentResponse[] {
    return this.debugResponses;
  }
}