fix: synchronization between executed tools and turn loops (#488)
This commit is contained in:
parent
174fdce7d8
commit
a8bfdf2d56
|
@ -383,6 +383,7 @@ export const useGeminiStream = (
|
|||
toolCallRequests.push(event.value);
|
||||
} else if (event.type === ServerGeminiEventType.UserCancelled) {
|
||||
handleUserCancelledEvent(userMessageTimestamp);
|
||||
cancel();
|
||||
return StreamProcessingStatus.UserCancelled;
|
||||
} else if (event.type === ServerGeminiEventType.Error) {
|
||||
handleErrorEvent(event.value, userMessageTimestamp);
|
||||
|
@ -393,12 +394,9 @@ export const useGeminiStream = (
|
|||
return StreamProcessingStatus.Completed;
|
||||
};
|
||||
|
||||
const streamingState: StreamingState = isResponding
|
||||
? StreamingState.Responding
|
||||
: pendingToolCalls?.tools.some(
|
||||
(t) => t.status === ToolCallStatus.Confirming,
|
||||
)
|
||||
? StreamingState.WaitingForConfirmation
|
||||
const streamingState: StreamingState =
|
||||
isResponding || toolCalls.some((t) => t.status === 'awaiting_approval')
|
||||
? StreamingState.Responding
|
||||
: StreamingState.Idle;
|
||||
|
||||
const submitQuery = useCallback(
|
||||
|
|
|
@ -184,61 +184,55 @@ export function useToolScheduler(
|
|||
|
||||
useEffect(() => {
|
||||
// effect for executing scheduled tool calls
|
||||
const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled');
|
||||
const awaitingConfirmation = toolCalls.some(
|
||||
(t) => t.status === 'awaiting_approval',
|
||||
);
|
||||
if (!awaitingConfirmation && scheduledCalls.length) {
|
||||
scheduledCalls.forEach(async (c) => {
|
||||
if (toolCalls.every((t) => t.status === 'scheduled')) {
|
||||
toolCalls.forEach((c) => {
|
||||
const callId = c.request.callId;
|
||||
try {
|
||||
setToolCalls(setStatus(c.request.callId, 'executing'));
|
||||
const result = await c.tool.execute(
|
||||
c.request.args,
|
||||
abortController.signal,
|
||||
);
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
name: c.request.name,
|
||||
id: callId,
|
||||
response: { output: result.llmContent },
|
||||
},
|
||||
};
|
||||
const response: ToolCallResponseInfo = {
|
||||
callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay: result.returnDisplay,
|
||||
error: undefined,
|
||||
};
|
||||
setToolCalls(setStatus(callId, 'success', response));
|
||||
} catch (e: unknown) {
|
||||
setToolCalls(
|
||||
setStatus(
|
||||
setToolCalls(setStatus(c.request.callId, 'executing'));
|
||||
c.tool
|
||||
.execute(c.request.args, abortController.signal)
|
||||
.then((result) => {
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
name: c.request.name,
|
||||
id: callId,
|
||||
response: { output: result.llmContent },
|
||||
},
|
||||
};
|
||||
const response: ToolCallResponseInfo = {
|
||||
callId,
|
||||
'error',
|
||||
toolErrorResponse(
|
||||
c.request,
|
||||
e instanceof Error ? e : new Error(String(e)),
|
||||
responsePart: functionResponse,
|
||||
resultDisplay: result.returnDisplay,
|
||||
error: undefined,
|
||||
};
|
||||
setToolCalls(setStatus(callId, 'success', response));
|
||||
})
|
||||
.catch((e) =>
|
||||
setToolCalls(
|
||||
setStatus(
|
||||
callId,
|
||||
'error',
|
||||
toolErrorResponse(
|
||||
c.request,
|
||||
e instanceof Error ? e : new Error(String(e)),
|
||||
),
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
}, [toolCalls, toolRegistry, abortController.signal]);
|
||||
|
||||
useEffect(() => {
|
||||
const completedTools = toolCalls.filter(
|
||||
const allDone = toolCalls.every(
|
||||
(t) =>
|
||||
t.status === 'success' ||
|
||||
t.status === 'error' ||
|
||||
t.status === 'cancelled',
|
||||
);
|
||||
const allDone = completedTools.length === toolCalls.length;
|
||||
if (toolCalls.length && allDone) {
|
||||
onComplete(completedTools);
|
||||
setToolCalls([]);
|
||||
setAbortController(new AbortController());
|
||||
onComplete(toolCalls);
|
||||
setAbortController(() => new AbortController());
|
||||
}
|
||||
}, [toolCalls, onComplete]);
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ import {
|
|||
} from '@google/genai';
|
||||
import process from 'node:process';
|
||||
import { getFolderStructure } from '../utils/getFolderStructure.js';
|
||||
import { Turn, ServerGeminiStreamEvent, GeminiEventType } from './turn.js';
|
||||
import { Turn, ServerGeminiStreamEvent } from './turn.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { getCoreSystemPrompt } from './prompts.js';
|
||||
import { ReadManyFilesTool } from '../tools/read-many-files.js';
|
||||
|
@ -153,43 +153,23 @@ export class GeminiClient {
|
|||
chat: Chat,
|
||||
request: PartListUnion,
|
||||
signal?: AbortSignal,
|
||||
turns: number = this.MAX_TURNS,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
let turns = 0;
|
||||
while (turns < this.MAX_TURNS) {
|
||||
turns++;
|
||||
const turn = new Turn(chat);
|
||||
const resultStream = turn.run(request, signal);
|
||||
let seenError = false;
|
||||
for await (const event of resultStream) {
|
||||
seenError =
|
||||
seenError === false ? false : event.type === GeminiEventType.Error;
|
||||
yield event;
|
||||
}
|
||||
|
||||
const confirmations = turn.getConfirmationDetails();
|
||||
if (confirmations.length > 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
const fnResponses = turn.getFunctionResponses();
|
||||
if (fnResponses.length === 0) {
|
||||
const nextSpeakerCheck = await checkNextSpeaker(chat, this);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
request = [{ text: 'Please continue.' }];
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
request = fnResponses;
|
||||
|
||||
if (seenError) {
|
||||
// We saw an error, lets stop processing to prevent unexpected consequences.
|
||||
break;
|
||||
}
|
||||
if (!turns) {
|
||||
return;
|
||||
}
|
||||
if (turns >= this.MAX_TURNS) {
|
||||
console.warn('sendMessageStream: Reached maximum tool call turns limit.');
|
||||
|
||||
const turn = new Turn(chat);
|
||||
const resultStream = turn.run(request, signal);
|
||||
for await (const event of resultStream) {
|
||||
yield event;
|
||||
}
|
||||
if (!turn.pendingToolCalls.length) {
|
||||
const nextSpeakerCheck = await checkNextSpeaker(chat, this);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
return this.sendMessageStream(chat, nextRequest, signal, turns - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -106,19 +106,15 @@ export type ServerGeminiStreamEvent =
|
|||
|
||||
// A turn manages the agentic loop turn within the server context.
|
||||
export class Turn {
|
||||
private pendingToolCalls: Array<{
|
||||
readonly pendingToolCalls: Array<{
|
||||
callId: string;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
private fnResponses: Part[];
|
||||
private confirmationDetails: ToolCallConfirmationDetails[];
|
||||
private debugResponses: GenerateContentResponse[];
|
||||
|
||||
constructor(private readonly chat: Chat) {
|
||||
this.pendingToolCalls = [];
|
||||
this.fnResponses = [];
|
||||
this.confirmationDetails = [];
|
||||
this.debugResponses = [];
|
||||
}
|
||||
// The run method yields simpler events suitable for server logic
|
||||
|
@ -182,14 +178,6 @@ export class Turn {
|
|||
return { type: GeminiEventType.ToolCallRequest, value };
|
||||
}
|
||||
|
||||
getConfirmationDetails(): ToolCallConfirmationDetails[] {
|
||||
return this.confirmationDetails;
|
||||
}
|
||||
|
||||
getFunctionResponses(): Part[] {
|
||||
return this.fnResponses;
|
||||
}
|
||||
|
||||
getDebugResponses(): GenerateContentResponse[] {
|
||||
return this.debugResponses;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue