fix: synchronization between executed tools and turn loops (#488)
This commit is contained in:
parent
174fdce7d8
commit
a8bfdf2d56
|
@ -383,6 +383,7 @@ export const useGeminiStream = (
|
||||||
toolCallRequests.push(event.value);
|
toolCallRequests.push(event.value);
|
||||||
} else if (event.type === ServerGeminiEventType.UserCancelled) {
|
} else if (event.type === ServerGeminiEventType.UserCancelled) {
|
||||||
handleUserCancelledEvent(userMessageTimestamp);
|
handleUserCancelledEvent(userMessageTimestamp);
|
||||||
|
cancel();
|
||||||
return StreamProcessingStatus.UserCancelled;
|
return StreamProcessingStatus.UserCancelled;
|
||||||
} else if (event.type === ServerGeminiEventType.Error) {
|
} else if (event.type === ServerGeminiEventType.Error) {
|
||||||
handleErrorEvent(event.value, userMessageTimestamp);
|
handleErrorEvent(event.value, userMessageTimestamp);
|
||||||
|
@ -393,12 +394,9 @@ export const useGeminiStream = (
|
||||||
return StreamProcessingStatus.Completed;
|
return StreamProcessingStatus.Completed;
|
||||||
};
|
};
|
||||||
|
|
||||||
const streamingState: StreamingState = isResponding
|
const streamingState: StreamingState =
|
||||||
? StreamingState.Responding
|
isResponding || toolCalls.some((t) => t.status === 'awaiting_approval')
|
||||||
: pendingToolCalls?.tools.some(
|
? StreamingState.Responding
|
||||||
(t) => t.status === ToolCallStatus.Confirming,
|
|
||||||
)
|
|
||||||
? StreamingState.WaitingForConfirmation
|
|
||||||
: StreamingState.Idle;
|
: StreamingState.Idle;
|
||||||
|
|
||||||
const submitQuery = useCallback(
|
const submitQuery = useCallback(
|
||||||
|
|
|
@ -184,61 +184,55 @@ export function useToolScheduler(
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// effect for executing scheduled tool calls
|
// effect for executing scheduled tool calls
|
||||||
const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled');
|
if (toolCalls.every((t) => t.status === 'scheduled')) {
|
||||||
const awaitingConfirmation = toolCalls.some(
|
toolCalls.forEach((c) => {
|
||||||
(t) => t.status === 'awaiting_approval',
|
|
||||||
);
|
|
||||||
if (!awaitingConfirmation && scheduledCalls.length) {
|
|
||||||
scheduledCalls.forEach(async (c) => {
|
|
||||||
const callId = c.request.callId;
|
const callId = c.request.callId;
|
||||||
try {
|
setToolCalls(setStatus(c.request.callId, 'executing'));
|
||||||
setToolCalls(setStatus(c.request.callId, 'executing'));
|
c.tool
|
||||||
const result = await c.tool.execute(
|
.execute(c.request.args, abortController.signal)
|
||||||
c.request.args,
|
.then((result) => {
|
||||||
abortController.signal,
|
const functionResponse: Part = {
|
||||||
);
|
functionResponse: {
|
||||||
const functionResponse: Part = {
|
name: c.request.name,
|
||||||
functionResponse: {
|
id: callId,
|
||||||
name: c.request.name,
|
response: { output: result.llmContent },
|
||||||
id: callId,
|
},
|
||||||
response: { output: result.llmContent },
|
};
|
||||||
},
|
const response: ToolCallResponseInfo = {
|
||||||
};
|
|
||||||
const response: ToolCallResponseInfo = {
|
|
||||||
callId,
|
|
||||||
responsePart: functionResponse,
|
|
||||||
resultDisplay: result.returnDisplay,
|
|
||||||
error: undefined,
|
|
||||||
};
|
|
||||||
setToolCalls(setStatus(callId, 'success', response));
|
|
||||||
} catch (e: unknown) {
|
|
||||||
setToolCalls(
|
|
||||||
setStatus(
|
|
||||||
callId,
|
callId,
|
||||||
'error',
|
responsePart: functionResponse,
|
||||||
toolErrorResponse(
|
resultDisplay: result.returnDisplay,
|
||||||
c.request,
|
error: undefined,
|
||||||
e instanceof Error ? e : new Error(String(e)),
|
};
|
||||||
|
setToolCalls(setStatus(callId, 'success', response));
|
||||||
|
})
|
||||||
|
.catch((e) =>
|
||||||
|
setToolCalls(
|
||||||
|
setStatus(
|
||||||
|
callId,
|
||||||
|
'error',
|
||||||
|
toolErrorResponse(
|
||||||
|
c.request,
|
||||||
|
e instanceof Error ? e : new Error(String(e)),
|
||||||
|
),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
}
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}, [toolCalls, toolRegistry, abortController.signal]);
|
}, [toolCalls, toolRegistry, abortController.signal]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const completedTools = toolCalls.filter(
|
const allDone = toolCalls.every(
|
||||||
(t) =>
|
(t) =>
|
||||||
t.status === 'success' ||
|
t.status === 'success' ||
|
||||||
t.status === 'error' ||
|
t.status === 'error' ||
|
||||||
t.status === 'cancelled',
|
t.status === 'cancelled',
|
||||||
);
|
);
|
||||||
const allDone = completedTools.length === toolCalls.length;
|
|
||||||
if (toolCalls.length && allDone) {
|
if (toolCalls.length && allDone) {
|
||||||
onComplete(completedTools);
|
|
||||||
setToolCalls([]);
|
setToolCalls([]);
|
||||||
setAbortController(new AbortController());
|
onComplete(toolCalls);
|
||||||
|
setAbortController(() => new AbortController());
|
||||||
}
|
}
|
||||||
}, [toolCalls, onComplete]);
|
}, [toolCalls, onComplete]);
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ import {
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import process from 'node:process';
|
import process from 'node:process';
|
||||||
import { getFolderStructure } from '../utils/getFolderStructure.js';
|
import { getFolderStructure } from '../utils/getFolderStructure.js';
|
||||||
import { Turn, ServerGeminiStreamEvent, GeminiEventType } from './turn.js';
|
import { Turn, ServerGeminiStreamEvent } from './turn.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
import { getCoreSystemPrompt } from './prompts.js';
|
import { getCoreSystemPrompt } from './prompts.js';
|
||||||
import { ReadManyFilesTool } from '../tools/read-many-files.js';
|
import { ReadManyFilesTool } from '../tools/read-many-files.js';
|
||||||
|
@ -153,43 +153,23 @@ export class GeminiClient {
|
||||||
chat: Chat,
|
chat: Chat,
|
||||||
request: PartListUnion,
|
request: PartListUnion,
|
||||||
signal?: AbortSignal,
|
signal?: AbortSignal,
|
||||||
|
turns: number = this.MAX_TURNS,
|
||||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||||
let turns = 0;
|
if (!turns) {
|
||||||
while (turns < this.MAX_TURNS) {
|
return;
|
||||||
turns++;
|
|
||||||
const turn = new Turn(chat);
|
|
||||||
const resultStream = turn.run(request, signal);
|
|
||||||
let seenError = false;
|
|
||||||
for await (const event of resultStream) {
|
|
||||||
seenError =
|
|
||||||
seenError === false ? false : event.type === GeminiEventType.Error;
|
|
||||||
yield event;
|
|
||||||
}
|
|
||||||
|
|
||||||
const confirmations = turn.getConfirmationDetails();
|
|
||||||
if (confirmations.length > 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const fnResponses = turn.getFunctionResponses();
|
|
||||||
if (fnResponses.length === 0) {
|
|
||||||
const nextSpeakerCheck = await checkNextSpeaker(chat, this);
|
|
||||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
|
||||||
request = [{ text: 'Please continue.' }];
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
request = fnResponses;
|
|
||||||
|
|
||||||
if (seenError) {
|
|
||||||
// We saw an error, lets stop processing to prevent unexpected consequences.
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (turns >= this.MAX_TURNS) {
|
|
||||||
console.warn('sendMessageStream: Reached maximum tool call turns limit.');
|
const turn = new Turn(chat);
|
||||||
|
const resultStream = turn.run(request, signal);
|
||||||
|
for await (const event of resultStream) {
|
||||||
|
yield event;
|
||||||
|
}
|
||||||
|
if (!turn.pendingToolCalls.length) {
|
||||||
|
const nextSpeakerCheck = await checkNextSpeaker(chat, this);
|
||||||
|
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||||
|
const nextRequest = [{ text: 'Please continue.' }];
|
||||||
|
return this.sendMessageStream(chat, nextRequest, signal, turns - 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -106,19 +106,15 @@ export type ServerGeminiStreamEvent =
|
||||||
|
|
||||||
// A turn manages the agentic loop turn within the server context.
|
// A turn manages the agentic loop turn within the server context.
|
||||||
export class Turn {
|
export class Turn {
|
||||||
private pendingToolCalls: Array<{
|
readonly pendingToolCalls: Array<{
|
||||||
callId: string;
|
callId: string;
|
||||||
name: string;
|
name: string;
|
||||||
args: Record<string, unknown>;
|
args: Record<string, unknown>;
|
||||||
}>;
|
}>;
|
||||||
private fnResponses: Part[];
|
|
||||||
private confirmationDetails: ToolCallConfirmationDetails[];
|
|
||||||
private debugResponses: GenerateContentResponse[];
|
private debugResponses: GenerateContentResponse[];
|
||||||
|
|
||||||
constructor(private readonly chat: Chat) {
|
constructor(private readonly chat: Chat) {
|
||||||
this.pendingToolCalls = [];
|
this.pendingToolCalls = [];
|
||||||
this.fnResponses = [];
|
|
||||||
this.confirmationDetails = [];
|
|
||||||
this.debugResponses = [];
|
this.debugResponses = [];
|
||||||
}
|
}
|
||||||
// The run method yields simpler events suitable for server logic
|
// The run method yields simpler events suitable for server logic
|
||||||
|
@ -182,14 +178,6 @@ export class Turn {
|
||||||
return { type: GeminiEventType.ToolCallRequest, value };
|
return { type: GeminiEventType.ToolCallRequest, value };
|
||||||
}
|
}
|
||||||
|
|
||||||
getConfirmationDetails(): ToolCallConfirmationDetails[] {
|
|
||||||
return this.confirmationDetails;
|
|
||||||
}
|
|
||||||
|
|
||||||
getFunctionResponses(): Part[] {
|
|
||||||
return this.fnResponses;
|
|
||||||
}
|
|
||||||
|
|
||||||
getDebugResponses(): GenerateContentResponse[] {
|
getDebugResponses(): GenerateContentResponse[] {
|
||||||
return this.debugResponses;
|
return this.debugResponses;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue