fix: synchronization between executed tools and turn loops (#488)

This commit is contained in:
Brandon Keiji 2025-05-22 09:51:07 +00:00 committed by GitHub
parent 174fdce7d8
commit a8bfdf2d56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 93 deletions

View File

@ -383,6 +383,7 @@ export const useGeminiStream = (
toolCallRequests.push(event.value); toolCallRequests.push(event.value);
} else if (event.type === ServerGeminiEventType.UserCancelled) { } else if (event.type === ServerGeminiEventType.UserCancelled) {
handleUserCancelledEvent(userMessageTimestamp); handleUserCancelledEvent(userMessageTimestamp);
cancel();
return StreamProcessingStatus.UserCancelled; return StreamProcessingStatus.UserCancelled;
} else if (event.type === ServerGeminiEventType.Error) { } else if (event.type === ServerGeminiEventType.Error) {
handleErrorEvent(event.value, userMessageTimestamp); handleErrorEvent(event.value, userMessageTimestamp);
@ -393,12 +394,9 @@ export const useGeminiStream = (
return StreamProcessingStatus.Completed; return StreamProcessingStatus.Completed;
}; };
const streamingState: StreamingState = isResponding const streamingState: StreamingState =
isResponding || toolCalls.some((t) => t.status === 'awaiting_approval')
? StreamingState.Responding ? StreamingState.Responding
: pendingToolCalls?.tools.some(
(t) => t.status === ToolCallStatus.Confirming,
)
? StreamingState.WaitingForConfirmation
: StreamingState.Idle; : StreamingState.Idle;
const submitQuery = useCallback( const submitQuery = useCallback(

View File

@ -184,19 +184,13 @@ export function useToolScheduler(
useEffect(() => { useEffect(() => {
// effect for executing scheduled tool calls // effect for executing scheduled tool calls
const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled'); if (toolCalls.every((t) => t.status === 'scheduled')) {
const awaitingConfirmation = toolCalls.some( toolCalls.forEach((c) => {
(t) => t.status === 'awaiting_approval',
);
if (!awaitingConfirmation && scheduledCalls.length) {
scheduledCalls.forEach(async (c) => {
const callId = c.request.callId; const callId = c.request.callId;
try {
setToolCalls(setStatus(c.request.callId, 'executing')); setToolCalls(setStatus(c.request.callId, 'executing'));
const result = await c.tool.execute( c.tool
c.request.args, .execute(c.request.args, abortController.signal)
abortController.signal, .then((result) => {
);
const functionResponse: Part = { const functionResponse: Part = {
functionResponse: { functionResponse: {
name: c.request.name, name: c.request.name,
@ -211,7 +205,8 @@ export function useToolScheduler(
error: undefined, error: undefined,
}; };
setToolCalls(setStatus(callId, 'success', response)); setToolCalls(setStatus(callId, 'success', response));
} catch (e: unknown) { })
.catch((e) =>
setToolCalls( setToolCalls(
setStatus( setStatus(
callId, callId,
@ -221,24 +216,23 @@ export function useToolScheduler(
e instanceof Error ? e : new Error(String(e)), e instanceof Error ? e : new Error(String(e)),
), ),
), ),
),
); );
}
}); });
} }
}, [toolCalls, toolRegistry, abortController.signal]); }, [toolCalls, toolRegistry, abortController.signal]);
useEffect(() => { useEffect(() => {
const completedTools = toolCalls.filter( const allDone = toolCalls.every(
(t) => (t) =>
t.status === 'success' || t.status === 'success' ||
t.status === 'error' || t.status === 'error' ||
t.status === 'cancelled', t.status === 'cancelled',
); );
const allDone = completedTools.length === toolCalls.length;
if (toolCalls.length && allDone) { if (toolCalls.length && allDone) {
onComplete(completedTools);
setToolCalls([]); setToolCalls([]);
setAbortController(new AbortController()); onComplete(toolCalls);
setAbortController(() => new AbortController());
} }
}, [toolCalls, onComplete]); }, [toolCalls, onComplete]);

View File

@ -16,7 +16,7 @@ import {
} from '@google/genai'; } from '@google/genai';
import process from 'node:process'; import process from 'node:process';
import { getFolderStructure } from '../utils/getFolderStructure.js'; import { getFolderStructure } from '../utils/getFolderStructure.js';
import { Turn, ServerGeminiStreamEvent, GeminiEventType } from './turn.js'; import { Turn, ServerGeminiStreamEvent } from './turn.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { getCoreSystemPrompt } from './prompts.js'; import { getCoreSystemPrompt } from './prompts.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { ReadManyFilesTool } from '../tools/read-many-files.js';
@ -153,44 +153,24 @@ export class GeminiClient {
chat: Chat, chat: Chat,
request: PartListUnion, request: PartListUnion,
signal?: AbortSignal, signal?: AbortSignal,
turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent> { ): AsyncGenerator<ServerGeminiStreamEvent> {
let turns = 0; if (!turns) {
while (turns < this.MAX_TURNS) { return;
turns++; }
const turn = new Turn(chat); const turn = new Turn(chat);
const resultStream = turn.run(request, signal); const resultStream = turn.run(request, signal);
let seenError = false;
for await (const event of resultStream) { for await (const event of resultStream) {
seenError =
seenError === false ? false : event.type === GeminiEventType.Error;
yield event; yield event;
} }
if (!turn.pendingToolCalls.length) {
const confirmations = turn.getConfirmationDetails();
if (confirmations.length > 0) {
break;
}
const fnResponses = turn.getFunctionResponses();
if (fnResponses.length === 0) {
const nextSpeakerCheck = await checkNextSpeaker(chat, this); const nextSpeakerCheck = await checkNextSpeaker(chat, this);
if (nextSpeakerCheck?.next_speaker === 'model') { if (nextSpeakerCheck?.next_speaker === 'model') {
request = [{ text: 'Please continue.' }]; const nextRequest = [{ text: 'Please continue.' }];
continue; return this.sendMessageStream(chat, nextRequest, signal, turns - 1);
} else {
break;
} }
} }
request = fnResponses;
if (seenError) {
// We saw an error, lets stop processing to prevent unexpected consequences.
break;
}
}
if (turns >= this.MAX_TURNS) {
console.warn('sendMessageStream: Reached maximum tool call turns limit.');
}
} }
async generateJson( async generateJson(

View File

@ -106,19 +106,15 @@ export type ServerGeminiStreamEvent =
// A turn manages the agentic loop turn within the server context. // A turn manages the agentic loop turn within the server context.
export class Turn { export class Turn {
private pendingToolCalls: Array<{ readonly pendingToolCalls: Array<{
callId: string; callId: string;
name: string; name: string;
args: Record<string, unknown>; args: Record<string, unknown>;
}>; }>;
private fnResponses: Part[];
private confirmationDetails: ToolCallConfirmationDetails[];
private debugResponses: GenerateContentResponse[]; private debugResponses: GenerateContentResponse[];
constructor(private readonly chat: Chat) { constructor(private readonly chat: Chat) {
this.pendingToolCalls = []; this.pendingToolCalls = [];
this.fnResponses = [];
this.confirmationDetails = [];
this.debugResponses = []; this.debugResponses = [];
} }
// The run method yields simpler events suitable for server logic // The run method yields simpler events suitable for server logic
@ -182,14 +178,6 @@ export class Turn {
return { type: GeminiEventType.ToolCallRequest, value }; return { type: GeminiEventType.ToolCallRequest, value };
} }
getConfirmationDetails(): ToolCallConfirmationDetails[] {
return this.confirmationDetails;
}
getFunctionResponses(): Part[] {
return this.fnResponses;
}
getDebugResponses(): GenerateContentResponse[] { getDebugResponses(): GenerateContentResponse[] {
return this.debugResponses; return this.debugResponses;
} }