fix: synchronization between executed tools and turn loops (#488)

This commit is contained in:
Brandon Keiji 2025-05-22 09:51:07 +00:00 committed by GitHub
parent 174fdce7d8
commit a8bfdf2d56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 93 deletions

View File

@ -383,6 +383,7 @@ export const useGeminiStream = (
toolCallRequests.push(event.value); toolCallRequests.push(event.value);
} else if (event.type === ServerGeminiEventType.UserCancelled) { } else if (event.type === ServerGeminiEventType.UserCancelled) {
handleUserCancelledEvent(userMessageTimestamp); handleUserCancelledEvent(userMessageTimestamp);
cancel();
return StreamProcessingStatus.UserCancelled; return StreamProcessingStatus.UserCancelled;
} else if (event.type === ServerGeminiEventType.Error) { } else if (event.type === ServerGeminiEventType.Error) {
handleErrorEvent(event.value, userMessageTimestamp); handleErrorEvent(event.value, userMessageTimestamp);
@ -393,12 +394,9 @@ export const useGeminiStream = (
return StreamProcessingStatus.Completed; return StreamProcessingStatus.Completed;
}; };
const streamingState: StreamingState = isResponding const streamingState: StreamingState =
? StreamingState.Responding isResponding || toolCalls.some((t) => t.status === 'awaiting_approval')
: pendingToolCalls?.tools.some( ? StreamingState.Responding
(t) => t.status === ToolCallStatus.Confirming,
)
? StreamingState.WaitingForConfirmation
: StreamingState.Idle; : StreamingState.Idle;
const submitQuery = useCallback( const submitQuery = useCallback(

View File

@ -184,61 +184,55 @@ export function useToolScheduler(
useEffect(() => { useEffect(() => {
// effect for executing scheduled tool calls // effect for executing scheduled tool calls
const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled'); if (toolCalls.every((t) => t.status === 'scheduled')) {
const awaitingConfirmation = toolCalls.some( toolCalls.forEach((c) => {
(t) => t.status === 'awaiting_approval',
);
if (!awaitingConfirmation && scheduledCalls.length) {
scheduledCalls.forEach(async (c) => {
const callId = c.request.callId; const callId = c.request.callId;
try { setToolCalls(setStatus(c.request.callId, 'executing'));
setToolCalls(setStatus(c.request.callId, 'executing')); c.tool
const result = await c.tool.execute( .execute(c.request.args, abortController.signal)
c.request.args, .then((result) => {
abortController.signal, const functionResponse: Part = {
); functionResponse: {
const functionResponse: Part = { name: c.request.name,
functionResponse: { id: callId,
name: c.request.name, response: { output: result.llmContent },
id: callId, },
response: { output: result.llmContent }, };
}, const response: ToolCallResponseInfo = {
};
const response: ToolCallResponseInfo = {
callId,
responsePart: functionResponse,
resultDisplay: result.returnDisplay,
error: undefined,
};
setToolCalls(setStatus(callId, 'success', response));
} catch (e: unknown) {
setToolCalls(
setStatus(
callId, callId,
'error', responsePart: functionResponse,
toolErrorResponse( resultDisplay: result.returnDisplay,
c.request, error: undefined,
e instanceof Error ? e : new Error(String(e)), };
setToolCalls(setStatus(callId, 'success', response));
})
.catch((e) =>
setToolCalls(
setStatus(
callId,
'error',
toolErrorResponse(
c.request,
e instanceof Error ? e : new Error(String(e)),
),
), ),
), ),
); );
}
}); });
} }
}, [toolCalls, toolRegistry, abortController.signal]); }, [toolCalls, toolRegistry, abortController.signal]);
useEffect(() => { useEffect(() => {
const completedTools = toolCalls.filter( const allDone = toolCalls.every(
(t) => (t) =>
t.status === 'success' || t.status === 'success' ||
t.status === 'error' || t.status === 'error' ||
t.status === 'cancelled', t.status === 'cancelled',
); );
const allDone = completedTools.length === toolCalls.length;
if (toolCalls.length && allDone) { if (toolCalls.length && allDone) {
onComplete(completedTools);
setToolCalls([]); setToolCalls([]);
setAbortController(new AbortController()); onComplete(toolCalls);
setAbortController(() => new AbortController());
} }
}, [toolCalls, onComplete]); }, [toolCalls, onComplete]);

View File

@ -16,7 +16,7 @@ import {
} from '@google/genai'; } from '@google/genai';
import process from 'node:process'; import process from 'node:process';
import { getFolderStructure } from '../utils/getFolderStructure.js'; import { getFolderStructure } from '../utils/getFolderStructure.js';
import { Turn, ServerGeminiStreamEvent, GeminiEventType } from './turn.js'; import { Turn, ServerGeminiStreamEvent } from './turn.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { getCoreSystemPrompt } from './prompts.js'; import { getCoreSystemPrompt } from './prompts.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { ReadManyFilesTool } from '../tools/read-many-files.js';
@ -153,43 +153,23 @@ export class GeminiClient {
chat: Chat, chat: Chat,
request: PartListUnion, request: PartListUnion,
signal?: AbortSignal, signal?: AbortSignal,
turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent> { ): AsyncGenerator<ServerGeminiStreamEvent> {
let turns = 0; if (!turns) {
while (turns < this.MAX_TURNS) { return;
turns++;
const turn = new Turn(chat);
const resultStream = turn.run(request, signal);
let seenError = false;
for await (const event of resultStream) {
seenError =
seenError === false ? false : event.type === GeminiEventType.Error;
yield event;
}
const confirmations = turn.getConfirmationDetails();
if (confirmations.length > 0) {
break;
}
const fnResponses = turn.getFunctionResponses();
if (fnResponses.length === 0) {
const nextSpeakerCheck = await checkNextSpeaker(chat, this);
if (nextSpeakerCheck?.next_speaker === 'model') {
request = [{ text: 'Please continue.' }];
continue;
} else {
break;
}
}
request = fnResponses;
if (seenError) {
// We saw an error, lets stop processing to prevent unexpected consequences.
break;
}
} }
if (turns >= this.MAX_TURNS) {
console.warn('sendMessageStream: Reached maximum tool call turns limit.'); const turn = new Turn(chat);
const resultStream = turn.run(request, signal);
for await (const event of resultStream) {
yield event;
}
if (!turn.pendingToolCalls.length) {
const nextSpeakerCheck = await checkNextSpeaker(chat, this);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
return this.sendMessageStream(chat, nextRequest, signal, turns - 1);
}
} }
} }

View File

@ -106,19 +106,15 @@ export type ServerGeminiStreamEvent =
// A turn manages the agentic loop turn within the server context. // A turn manages the agentic loop turn within the server context.
export class Turn { export class Turn {
private pendingToolCalls: Array<{ readonly pendingToolCalls: Array<{
callId: string; callId: string;
name: string; name: string;
args: Record<string, unknown>; args: Record<string, unknown>;
}>; }>;
private fnResponses: Part[];
private confirmationDetails: ToolCallConfirmationDetails[];
private debugResponses: GenerateContentResponse[]; private debugResponses: GenerateContentResponse[];
constructor(private readonly chat: Chat) { constructor(private readonly chat: Chat) {
this.pendingToolCalls = []; this.pendingToolCalls = [];
this.fnResponses = [];
this.confirmationDetails = [];
this.debugResponses = []; this.debugResponses = [];
} }
// The run method yields simpler events suitable for server logic // The run method yields simpler events suitable for server logic
@ -182,14 +178,6 @@ export class Turn {
return { type: GeminiEventType.ToolCallRequest, value }; return { type: GeminiEventType.ToolCallRequest, value };
} }
getConfirmationDetails(): ToolCallConfirmationDetails[] {
return this.confirmationDetails;
}
getFunctionResponses(): Part[] {
return this.fnResponses;
}
getDebugResponses(): GenerateContentResponse[] { getDebugResponses(): GenerateContentResponse[] {
return this.debugResponses; return this.debugResponses;
} }