fix: synchronization between executed tools and turn loops (#488)

This commit is contained in:
Brandon Keiji 2025-05-22 09:51:07 +00:00 committed by GitHub
parent 174fdce7d8
commit a8bfdf2d56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 93 deletions

View File

@ -383,6 +383,7 @@ export const useGeminiStream = (
toolCallRequests.push(event.value);
} else if (event.type === ServerGeminiEventType.UserCancelled) {
handleUserCancelledEvent(userMessageTimestamp);
cancel();
return StreamProcessingStatus.UserCancelled;
} else if (event.type === ServerGeminiEventType.Error) {
handleErrorEvent(event.value, userMessageTimestamp);
@ -393,12 +394,9 @@ export const useGeminiStream = (
return StreamProcessingStatus.Completed;
};
const streamingState: StreamingState = isResponding
? StreamingState.Responding
: pendingToolCalls?.tools.some(
(t) => t.status === ToolCallStatus.Confirming,
)
? StreamingState.WaitingForConfirmation
const streamingState: StreamingState =
isResponding || toolCalls.some((t) => t.status === 'awaiting_approval')
? StreamingState.Responding
: StreamingState.Idle;
const submitQuery = useCallback(

View File

@ -184,61 +184,55 @@ export function useToolScheduler(
useEffect(() => {
// effect for executing scheduled tool calls
const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled');
const awaitingConfirmation = toolCalls.some(
(t) => t.status === 'awaiting_approval',
);
if (!awaitingConfirmation && scheduledCalls.length) {
scheduledCalls.forEach(async (c) => {
if (toolCalls.every((t) => t.status === 'scheduled')) {
toolCalls.forEach((c) => {
const callId = c.request.callId;
try {
setToolCalls(setStatus(c.request.callId, 'executing'));
const result = await c.tool.execute(
c.request.args,
abortController.signal,
);
const functionResponse: Part = {
functionResponse: {
name: c.request.name,
id: callId,
response: { output: result.llmContent },
},
};
const response: ToolCallResponseInfo = {
callId,
responsePart: functionResponse,
resultDisplay: result.returnDisplay,
error: undefined,
};
setToolCalls(setStatus(callId, 'success', response));
} catch (e: unknown) {
setToolCalls(
setStatus(
setToolCalls(setStatus(c.request.callId, 'executing'));
c.tool
.execute(c.request.args, abortController.signal)
.then((result) => {
const functionResponse: Part = {
functionResponse: {
name: c.request.name,
id: callId,
response: { output: result.llmContent },
},
};
const response: ToolCallResponseInfo = {
callId,
'error',
toolErrorResponse(
c.request,
e instanceof Error ? e : new Error(String(e)),
responsePart: functionResponse,
resultDisplay: result.returnDisplay,
error: undefined,
};
setToolCalls(setStatus(callId, 'success', response));
})
.catch((e) =>
setToolCalls(
setStatus(
callId,
'error',
toolErrorResponse(
c.request,
e instanceof Error ? e : new Error(String(e)),
),
),
),
);
}
});
}
}, [toolCalls, toolRegistry, abortController.signal]);
useEffect(() => {
const completedTools = toolCalls.filter(
const allDone = toolCalls.every(
(t) =>
t.status === 'success' ||
t.status === 'error' ||
t.status === 'cancelled',
);
const allDone = completedTools.length === toolCalls.length;
if (toolCalls.length && allDone) {
onComplete(completedTools);
setToolCalls([]);
setAbortController(new AbortController());
onComplete(toolCalls);
setAbortController(() => new AbortController());
}
}, [toolCalls, onComplete]);

View File

@ -16,7 +16,7 @@ import {
} from '@google/genai';
import process from 'node:process';
import { getFolderStructure } from '../utils/getFolderStructure.js';
import { Turn, ServerGeminiStreamEvent, GeminiEventType } from './turn.js';
import { Turn, ServerGeminiStreamEvent } from './turn.js';
import { Config } from '../config/config.js';
import { getCoreSystemPrompt } from './prompts.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js';
@ -153,43 +153,23 @@ export class GeminiClient {
chat: Chat,
request: PartListUnion,
signal?: AbortSignal,
turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent> {
let turns = 0;
while (turns < this.MAX_TURNS) {
turns++;
const turn = new Turn(chat);
const resultStream = turn.run(request, signal);
let seenError = false;
for await (const event of resultStream) {
seenError =
seenError === false ? false : event.type === GeminiEventType.Error;
yield event;
}
const confirmations = turn.getConfirmationDetails();
if (confirmations.length > 0) {
break;
}
const fnResponses = turn.getFunctionResponses();
if (fnResponses.length === 0) {
const nextSpeakerCheck = await checkNextSpeaker(chat, this);
if (nextSpeakerCheck?.next_speaker === 'model') {
request = [{ text: 'Please continue.' }];
continue;
} else {
break;
}
}
request = fnResponses;
if (seenError) {
// We saw an error, lets stop processing to prevent unexpected consequences.
break;
}
if (!turns) {
return;
}
if (turns >= this.MAX_TURNS) {
console.warn('sendMessageStream: Reached maximum tool call turns limit.');
const turn = new Turn(chat);
const resultStream = turn.run(request, signal);
for await (const event of resultStream) {
yield event;
}
if (!turn.pendingToolCalls.length) {
const nextSpeakerCheck = await checkNextSpeaker(chat, this);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
return this.sendMessageStream(chat, nextRequest, signal, turns - 1);
}
}
}

View File

@ -106,19 +106,15 @@ export type ServerGeminiStreamEvent =
// A turn manages the agentic loop turn within the server context.
export class Turn {
private pendingToolCalls: Array<{
readonly pendingToolCalls: Array<{
callId: string;
name: string;
args: Record<string, unknown>;
}>;
private fnResponses: Part[];
private confirmationDetails: ToolCallConfirmationDetails[];
private debugResponses: GenerateContentResponse[];
constructor(private readonly chat: Chat) {
this.pendingToolCalls = [];
this.fnResponses = [];
this.confirmationDetails = [];
this.debugResponses = [];
}
// The run method yields simpler events suitable for server logic
@ -182,14 +178,6 @@ export class Turn {
return { type: GeminiEventType.ToolCallRequest, value };
}
getConfirmationDetails(): ToolCallConfirmationDetails[] {
return this.confirmationDetails;
}
getFunctionResponses(): Part[] {
return this.fnResponses;
}
getDebugResponses(): GenerateContentResponse[] {
return this.debugResponses;
}