diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 324a4ffa..d3ecad95 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -383,6 +383,7 @@ export const useGeminiStream = ( toolCallRequests.push(event.value); } else if (event.type === ServerGeminiEventType.UserCancelled) { handleUserCancelledEvent(userMessageTimestamp); + cancel(); return StreamProcessingStatus.UserCancelled; } else if (event.type === ServerGeminiEventType.Error) { handleErrorEvent(event.value, userMessageTimestamp); @@ -393,12 +394,9 @@ export const useGeminiStream = ( return StreamProcessingStatus.Completed; }; - const streamingState: StreamingState = isResponding - ? StreamingState.Responding - : pendingToolCalls?.tools.some( - (t) => t.status === ToolCallStatus.Confirming, - ) - ? StreamingState.WaitingForConfirmation + const streamingState: StreamingState = + isResponding || toolCalls.some((t) => t.status === 'awaiting_approval') + ? StreamingState.Responding : StreamingState.Idle; const submitQuery = useCallback( diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts index fde632df..e14241b6 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.ts @@ -184,61 +184,55 @@ export function useToolScheduler( useEffect(() => { // effect for executing scheduled tool calls - const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled'); - const awaitingConfirmation = toolCalls.some( - (t) => t.status === 'awaiting_approval', - ); - if (!awaitingConfirmation && scheduledCalls.length) { - scheduledCalls.forEach(async (c) => { + if (toolCalls.every((t) => t.status === 'scheduled')) { + toolCalls.forEach((c) => { const callId = c.request.callId; - try { - setToolCalls(setStatus(c.request.callId, 'executing')); - const result = await c.tool.execute( - c.request.args, - abortController.signal, - ); - const functionResponse: Part = { - functionResponse: { - name: c.request.name, - id: callId, - response: { output: result.llmContent }, - }, - }; - const response: ToolCallResponseInfo = { - callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - setToolCalls(setStatus(callId, 'success', response)); - } catch (e: unknown) { - setToolCalls( - setStatus( + setToolCalls(setStatus(c.request.callId, 'executing')); + c.tool + .execute(c.request.args, abortController.signal) + .then((result) => { + const functionResponse: Part = { + functionResponse: { + name: c.request.name, + id: callId, + response: { output: result.llmContent }, + }, + }; + const response: ToolCallResponseInfo = { callId, - 'error', - toolErrorResponse( - c.request, - e instanceof Error ? e : new Error(String(e)), + responsePart: functionResponse, + resultDisplay: result.returnDisplay, + error: undefined, + }; + setToolCalls(setStatus(callId, 'success', response)); + }) + .catch((e) => + setToolCalls( + setStatus( + callId, + 'error', + toolErrorResponse( + c.request, + e instanceof Error ? e : new Error(String(e)), + ), ), ), ); - } }); } }, [toolCalls, toolRegistry, abortController.signal]); useEffect(() => { - const completedTools = toolCalls.filter( + const allDone = toolCalls.every( (t) => t.status === 'success' || t.status === 'error' || t.status === 'cancelled', ); - const allDone = completedTools.length === toolCalls.length; if (toolCalls.length && allDone) { - onComplete(completedTools); setToolCalls([]); - setAbortController(new AbortController()); + onComplete(toolCalls); + setAbortController(() => new AbortController()); } }, [toolCalls, onComplete]); diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts index 489e2a0b..85850da8 100644 --- a/packages/server/src/core/client.ts +++ b/packages/server/src/core/client.ts @@ -16,7 +16,7 @@ import { } from '@google/genai'; import process from 'node:process'; import { getFolderStructure } from '../utils/getFolderStructure.js'; -import { Turn, ServerGeminiStreamEvent, GeminiEventType } from './turn.js'; +import { Turn, ServerGeminiStreamEvent } from './turn.js'; import { Config } from '../config/config.js'; import { getCoreSystemPrompt } from './prompts.js'; import { ReadManyFilesTool } from '../tools/read-many-files.js'; @@ -153,43 +153,23 @@ export class GeminiClient { chat: Chat, request: PartListUnion, signal?: AbortSignal, + turns: number = this.MAX_TURNS, ): AsyncGenerator { - let turns = 0; - while (turns < this.MAX_TURNS) { - turns++; - const turn = new Turn(chat); - const resultStream = turn.run(request, signal); - let seenError = false; - for await (const event of resultStream) { - seenError = - seenError === false ? false : event.type === GeminiEventType.Error; - yield event; - } - - const confirmations = turn.getConfirmationDetails(); - if (confirmations.length > 0) { - break; - } - - const fnResponses = turn.getFunctionResponses(); - if (fnResponses.length === 0) { - const nextSpeakerCheck = await checkNextSpeaker(chat, this); - if (nextSpeakerCheck?.next_speaker === 'model') { - request = [{ text: 'Please continue.' }]; - continue; - } else { - break; - } - } - request = fnResponses; - - if (seenError) { - // We saw an error, lets stop processing to prevent unexpected consequences. - break; - } + if (!turns) { + return; } - if (turns >= this.MAX_TURNS) { - console.warn('sendMessageStream: Reached maximum tool call turns limit.'); + + const turn = new Turn(chat); + const resultStream = turn.run(request, signal); + for await (const event of resultStream) { + yield event; + } + if (!turn.pendingToolCalls.length) { + const nextSpeakerCheck = await checkNextSpeaker(chat, this); + if (nextSpeakerCheck?.next_speaker === 'model') { + const nextRequest = [{ text: 'Please continue.' }]; + return this.sendMessageStream(chat, nextRequest, signal, turns - 1); + } } } diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index 7b2a96f9..38932041 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -106,19 +106,15 @@ export type ServerGeminiStreamEvent = // A turn manages the agentic loop turn within the server context. export class Turn { - private pendingToolCalls: Array<{ + readonly pendingToolCalls: Array<{ callId: string; name: string; args: Record; }>; - private fnResponses: Part[]; - private confirmationDetails: ToolCallConfirmationDetails[]; private debugResponses: GenerateContentResponse[]; constructor(private readonly chat: Chat) { this.pendingToolCalls = []; - this.fnResponses = []; - this.confirmationDetails = []; this.debugResponses = []; } // The run method yields simpler events suitable for server logic @@ -182,14 +178,6 @@ export class Turn { return { type: GeminiEventType.ToolCallRequest, value }; } - getConfirmationDetails(): ToolCallConfirmationDetails[] { - return this.confirmationDetails; - } - - getFunctionResponses(): Part[] { - return this.fnResponses; - } - getDebugResponses(): GenerateContentResponse[] { return this.debugResponses; }