feat: useToolScheduler hook to manage parallel tool calls (#448)

This commit is contained in:
Brandon Keiji 2025-05-22 05:57:53 +00:00 committed by GitHub
parent efee7c6cce
commit 02eec5c8ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 109 additions and 369 deletions

View File

@ -134,7 +134,7 @@ export const App = ({
cliVersion,
);
const { streamingState, submitQuery, initError, pendingHistoryItem } =
const { streamingState, submitQuery, initError, pendingHistoryItems } =
useGeminiStream(
addItem,
refreshStatic,
@ -209,7 +209,7 @@ export const App = ({
}, [terminalHeight, footerHeight]);
useEffect(() => {
if (!pendingHistoryItem) {
if (!pendingHistoryItems.length) {
return;
}
@ -223,7 +223,7 @@ export const App = ({
if (pendingItemDimensions.height > availableTerminalHeight) {
setStaticNeedsRefresh(true);
}
}, [pendingHistoryItem, availableTerminalHeight, streamingState]);
}, [pendingHistoryItems.length, availableTerminalHeight, streamingState]);
useEffect(() => {
if (streamingState === StreamingState.Idle && staticNeedsRefresh) {
@ -264,17 +264,18 @@ export const App = ({
>
{(item) => item}
</Static>
{pendingHistoryItem && (
<Box ref={pendingHistoryItemRef}>
<Box ref={pendingHistoryItemRef}>
{pendingHistoryItems.map((item, i) => (
<HistoryItemDisplay
key={i}
availableTerminalHeight={availableTerminalHeight}
// TODO(taehykim): It seems like references to ids aren't necessary in
// HistoryItemDisplay. Refactor later. Use a fake id for now.
item={{ ...pendingHistoryItem, id: 0 }}
item={{ ...item, id: 0 }}
isPending={true}
/>
</Box>
)}
))}
</Box>
{showHelp && <Help commands={slashCommands} />}
<Box flexDirection="column" ref={mainControlsRef}>

View File

@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import React from 'react';
import React, { useMemo } from 'react';
import { Box } from 'ink';
import { IndividualToolCallDisplay, ToolCallStatus } from '../../types.js';
import { ToolMessage } from './ToolMessage.js';
@ -19,7 +19,6 @@ interface ToolGroupMessageProps {
// Main component renders the border and maps the tools using ToolMessage
export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
groupId,
toolCalls,
availableTerminalHeight,
}) => {
@ -30,9 +29,13 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
const staticHeight = /* border */ 2 + /* marginBottom */ 1;
const toolAwaitingApproval = useMemo(
() => toolCalls.find((tc) => tc.status === ToolCallStatus.Confirming),
[toolCalls],
);
return (
<Box
key={groupId}
flexDirection="column"
borderStyle="round"
/*
@ -48,7 +51,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
marginBottom={1}
>
{toolCalls.map((tool) => (
<Box key={groupId + '-' + tool.callId} flexDirection="column">
<Box key={tool.callId} flexDirection="column">
<ToolMessage
key={tool.callId}
callId={tool.callId}
@ -60,6 +63,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
availableTerminalHeight={availableTerminalHeight - staticHeight}
/>
{tool.status === ToolCallStatus.Confirming &&
tool.callId === toolAwaitingApproval?.callId &&
tool.confirmationDetails && (
<ToolConfirmationMessage
confirmationDetails={tool.confirmationDetails}

View File

@ -4,34 +4,28 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { useState, useRef, useCallback, useEffect } from 'react';
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import { useInput } from 'ink';
import {
GeminiClient,
GeminiEventType as ServerGeminiEventType,
ServerGeminiStreamEvent as GeminiEvent,
ServerGeminiContentEvent as ContentEvent,
ServerGeminiToolCallRequestEvent as ToolCallRequestEvent,
ServerGeminiToolCallResponseEvent as ToolCallResponseEvent,
ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent,
ServerGeminiErrorEvent as ErrorEvent,
getErrorMessage,
isNodeError,
Config,
MessageSenderType,
ServerToolCallConfirmationDetails,
ToolCallConfirmationDetails,
ToolCallResponseInfo,
ToolConfirmationOutcome,
ToolEditConfirmationDetails,
ToolExecuteConfirmationDetails,
ToolResultDisplay,
partListUnionToString,
ToolCallRequestInfo,
} from '@gemini-code/server';
import { type Chat, type PartListUnion, type Part } from '@google/genai';
import {
StreamingState,
IndividualToolCallDisplay,
ToolCallStatus,
HistoryItemWithoutId,
HistoryItemToolGroup,
@ -44,6 +38,7 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { useLogger } from './useLogger.js';
import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
enum StreamProcessingStatus {
Completed,
@ -65,7 +60,6 @@ export const useGeminiStream = (
handleSlashCommand: (cmd: PartListUnion) => boolean,
shellModeActive: boolean,
) => {
const toolRegistry = config.getToolRegistry();
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const chatSessionRef = useRef<Chat | null>(null);
@ -74,6 +68,25 @@ export const useGeminiStream = (
const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger();
const [toolCalls, schedule, cancel] = useToolScheduler((tools) => {
if (tools.length) {
addItem(mapToDisplay(tools), Date.now());
submitQuery(
tools
.filter(
(t) =>
t.status === 'error' ||
t.status === 'cancelled' ||
t.status === 'success',
)
.map((t) => t.response.responsePart),
);
}
}, config);
const pendingToolCalls = useMemo(
() => (toolCalls.length ? mapToDisplay(toolCalls) : undefined),
[toolCalls],
);
const onExec = useCallback(async (done: Promise<void>) => {
setIsResponding(true);
@ -104,6 +117,7 @@ export const useGeminiStream = (
useInput((_input, key) => {
if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort();
cancel();
}
});
@ -215,157 +229,48 @@ export const useGeminiStream = (
);
};
const updateConfirmingFunctionStatusUI = (
callId: string,
confirmationDetails: ToolCallConfirmationDetails | undefined,
) => {
setPendingHistoryItem((item) =>
item?.type === 'tool_group'
? {
...item,
tools: item.tools.map((tool) =>
tool.callId === callId
? {
...tool,
status: ToolCallStatus.Confirming,
confirmationDetails,
}
: tool,
),
}
: item,
);
};
const wireConfirmationSubmission = (
confirmationDetails: ServerToolCallConfirmationDetails,
): ToolCallConfirmationDetails => {
const originalConfirmationDetails = confirmationDetails.details;
const request = confirmationDetails.request;
const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => {
originalConfirmationDetails.onConfirm(outcome);
if (pendingHistoryItemRef?.current?.type === 'tool_group') {
setPendingHistoryItem((item) =>
item?.type === 'tool_group'
? {
...item,
tools: item.tools.map((tool) =>
tool.callId === request.callId
? {
...tool,
confirmationDetails: undefined,
status: ToolCallStatus.Executing,
}
: tool,
),
}
: item,
);
refreshStatic();
}
if (outcome === ToolConfirmationOutcome.Cancel) {
declineToolExecution(
'User rejected function call.',
ToolCallStatus.Error,
request,
originalConfirmationDetails,
);
} else {
const tool = toolRegistry.getTool(request.name);
if (!tool) {
throw new Error(
`Tool "${request.name}" not found or is not registered.`,
);
}
try {
abortControllerRef.current = new AbortController();
const result = await tool.execute(
request.args,
abortControllerRef.current.signal,
);
if (abortControllerRef.current.signal.aborted) {
declineToolExecution(
partListUnionToString(result.llmContent),
ToolCallStatus.Canceled,
request,
originalConfirmationDetails,
);
return;
}
const functionResponse: Part = {
functionResponse: {
name: request.name,
id: request.callId,
response: { output: result.llmContent },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay: result.returnDisplay,
error: undefined,
};
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, Date.now());
setPendingHistoryItem(null);
}
setIsResponding(false);
await submitQuery(functionResponse); // Recursive call
} finally {
if (streamingState !== StreamingState.WaitingForConfirmation) {
abortControllerRef.current = null;
}
}
}
};
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
// or could be a standalone helper if more params are passed.
function declineToolExecution(
declineMessage: string,
status: ToolCallStatus,
request: ServerToolCallConfirmationDetails['request'],
originalDetails: ServerToolCallConfirmationDetails['details'],
) {
let resultDisplay: ToolResultDisplay | undefined;
if ('fileDiff' in originalDetails) {
resultDisplay = {
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
};
} else {
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
}
const functionResponse: Part = {
functionResponse: {
id: request.callId,
name: request.name,
response: { error: declineMessage },
},
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
// or could be a standalone helper if more params are passed.
// TODO: handle file diff result display stuff
function _declineToolExecution(
declineMessage: string,
status: ToolCallStatus,
request: ServerToolCallConfirmationDetails['request'],
originalDetails: ServerToolCallConfirmationDetails['details'],
) {
let resultDisplay: ToolResultDisplay | undefined;
if ('fileDiff' in originalDetails) {
resultDisplay = {
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay,
error: new Error(declineMessage),
};
const history = chatSessionRef.current?.getHistory();
if (history) {
history.push({ role: 'model', parts: [functionResponse] });
}
updateFunctionResponseUI(responseInfo, status);
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, Date.now());
setPendingHistoryItem(null);
}
setIsResponding(false);
} else {
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
}
return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm };
};
const functionResponse: Part = {
functionResponse: {
id: request.callId,
name: request.name,
response: { error: declineMessage },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay,
error: new Error(declineMessage),
};
const history = chatSessionRef.current?.getHistory();
if (history) {
history.push({ role: 'model', parts: [functionResponse] });
}
updateFunctionResponseUI(responseInfo, status);
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, Date.now());
setPendingHistoryItem(null);
}
setIsResponding(false);
}
// --- Stream Event Handlers ---
const handleContentEvent = (
@ -419,62 +324,6 @@ export const useGeminiStream = (
return newGeminiMessageBuffer;
};
const handleToolCallRequestEvent = (
eventValue: ToolCallRequestEvent['value'],
userMessageTimestamp: number,
) => {
const { callId, name, args } = eventValue;
const cliTool = toolRegistry.getTool(name);
if (!cliTool) {
console.error(`CLI Tool "${name}" not found!`);
return; // Skip this event if tool is not found
}
if (pendingHistoryItemRef.current?.type !== 'tool_group') {
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
}
setPendingHistoryItem({ type: 'tool_group', tools: [] });
}
let description: string;
try {
description = cliTool.getDescription(args);
} catch (e) {
description = `Error: Unable to get description: ${getErrorMessage(e)}`;
}
const toolCallDisplay: IndividualToolCallDisplay = {
callId,
name: cliTool.displayName,
description,
status: ToolCallStatus.Pending,
resultDisplay: undefined,
confirmationDetails: undefined,
};
setPendingHistoryItem((pending) =>
pending?.type === 'tool_group'
? { ...pending, tools: [...pending.tools, toolCallDisplay] }
: null,
);
};
const handleToolCallResponseEvent = (
eventValue: ToolCallResponseEvent['value'],
) => {
const status = eventValue.error
? ToolCallStatus.Error
: ToolCallStatus.Success;
updateFunctionResponseUI(eventValue, status);
};
const handleToolCallConfirmationEvent = (
eventValue: ToolCallConfirmationEvent['value'],
) => {
const confirmationDetails = wireConfirmationSubmission(eventValue);
updateConfirmingFunctionStatusUI(
eventValue.request.callId,
confirmationDetails,
);
};
const handleUserCancelledEvent = (userMessageTimestamp: number) => {
if (pendingHistoryItemRef.current) {
if (pendingHistoryItemRef.current.type === 'tool_group') {
@ -500,6 +349,7 @@ export const useGeminiStream = (
userMessageTimestamp,
);
setIsResponding(false);
cancel();
};
const handleErrorEvent = (
@ -521,7 +371,7 @@ export const useGeminiStream = (
userMessageTimestamp: number,
): Promise<StreamProcessingStatus> => {
let geminiMessageBuffer = '';
const toolCallRequests: ToolCallRequestInfo[] = [];
for await (const event of stream) {
if (event.type === ServerGeminiEventType.Content) {
geminiMessageBuffer = handleContentEvent(
@ -530,12 +380,7 @@ export const useGeminiStream = (
userMessageTimestamp,
);
} else if (event.type === ServerGeminiEventType.ToolCallRequest) {
handleToolCallRequestEvent(event.value, userMessageTimestamp);
} else if (event.type === ServerGeminiEventType.ToolCallResponse) {
handleToolCallResponseEvent(event.value);
} else if (event.type === ServerGeminiEventType.ToolCallConfirmation) {
handleToolCallConfirmationEvent(event.value);
return StreamProcessingStatus.PausedForConfirmation;
toolCallRequests.push(event.value);
} else if (event.type === ServerGeminiEventType.UserCancelled) {
handleUserCancelledEvent(userMessageTimestamp);
return StreamProcessingStatus.UserCancelled;
@ -544,9 +389,18 @@ export const useGeminiStream = (
return StreamProcessingStatus.Error;
}
}
schedule(toolCallRequests);
return StreamProcessingStatus.Completed;
};
const streamingState: StreamingState = isResponding
? StreamingState.Responding
: pendingToolCalls?.tools.some(
(t) => t.status === ToolCallStatus.Confirming,
)
? StreamingState.WaitingForConfirmation
: StreamingState.Idle;
const submitQuery = useCallback(
async (query: PartListUnion) => {
if (isResponding) return;
@ -625,20 +479,15 @@ export const useGeminiStream = (
],
);
const streamingState: StreamingState = isResponding
? StreamingState.Responding
: pendingConfirmations(pendingHistoryItemRef.current)
? StreamingState.WaitingForConfirmation
: StreamingState.Idle;
const pendingHistoryItems = [
pendingHistoryItemRef.current,
pendingToolCalls,
].filter((i) => i !== undefined && i !== null);
return {
streamingState,
submitQuery,
initError,
pendingHistoryItem: pendingHistoryItemRef.current,
pendingHistoryItems,
};
};
const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean =>
item?.type === 'tool_group' &&
item.tools.some((t) => t.status === ToolCallStatus.Confirming);

View File

@ -155,10 +155,9 @@ export class GeminiClient {
signal?: AbortSignal,
): AsyncGenerator<ServerGeminiStreamEvent> {
let turns = 0;
const availableTools = this.config.getToolRegistry().getAllTools();
while (turns < this.MAX_TURNS) {
turns++;
const turn = new Turn(chat, availableTools);
const turn = new Turn(chat);
const resultStream = turn.run(request, signal);
let seenError = false;
for await (const event of resultStream) {

View File

@ -21,18 +21,6 @@ import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { reportError } from '../utils/errorReporting.js';
import { getErrorMessage } from '../utils/errors.js';
// --- Types for Server Logic ---
// Define a simpler structure for Tool execution results within the server
interface ServerToolExecutionOutcome {
callId: string;
name: string;
args: Record<string, unknown>;
result?: ToolResult;
error?: Error;
confirmationDetails: ToolCallConfirmationDetails | undefined;
}
// Define a structure for tools passed to the server
export interface ServerTool {
name: string;
@ -118,7 +106,6 @@ export type ServerGeminiStreamEvent =
// A turn manages the agentic loop turn within the server context.
export class Turn {
private readonly availableTools: Map<string, ServerTool>;
private pendingToolCalls: Array<{
callId: string;
name: string;
@ -128,11 +115,7 @@ export class Turn {
private confirmationDetails: ToolCallConfirmationDetails[];
private debugResponses: GenerateContentResponse[];
constructor(
private readonly chat: Chat,
availableTools: ServerTool[],
) {
this.availableTools = new Map(availableTools.map((t) => [t.name, t]));
constructor(private readonly chat: Chat) {
this.pendingToolCalls = [];
this.fnResponses = [];
this.confirmationDetails = [];
@ -160,12 +143,9 @@ export class Turn {
yield { type: GeminiEventType.Content, value: text };
}
if (!resp.functionCalls) {
continue;
}
// Handle function calls (requesting tool execution)
for (const fnCall of resp.functionCalls) {
const functionCalls = resp.functionCalls ?? [];
for (const fnCall of functionCalls) {
const event = this.handlePendingFunctionCall(fnCall);
if (event) {
yield event;
@ -184,80 +164,6 @@ export class Turn {
yield { type: GeminiEventType.Error, value: { message: errorMessage } };
return;
}
// Execute pending tool calls
const toolPromises = this.pendingToolCalls.map(
async (pendingToolCall): Promise<ServerToolExecutionOutcome> => {
const tool = this.availableTools.get(pendingToolCall.name);
if (!tool) {
return {
...pendingToolCall,
error: new Error(
`Tool "${pendingToolCall.name}" not found or not provided to Turn.`,
),
confirmationDetails: undefined,
};
}
try {
const confirmationDetails = await tool.shouldConfirmExecute(
pendingToolCall.args,
);
if (confirmationDetails) {
return { ...pendingToolCall, confirmationDetails };
}
const result = await tool.execute(pendingToolCall.args, signal);
return {
...pendingToolCall,
result,
confirmationDetails: undefined,
};
} catch (execError: unknown) {
return {
...pendingToolCall,
error: new Error(
`Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`,
),
confirmationDetails: undefined,
};
}
},
);
const outcomes = await Promise.all(toolPromises);
// Process outcomes and prepare function responses
this.pendingToolCalls = []; // Clear pending calls for this turn
for (const outcome of outcomes) {
if (outcome.confirmationDetails) {
this.confirmationDetails.push(outcome.confirmationDetails);
const serverConfirmationetails: ServerToolCallConfirmationDetails = {
request: {
callId: outcome.callId,
name: outcome.name,
args: outcome.args,
},
details: outcome.confirmationDetails,
};
yield {
type: GeminiEventType.ToolCallConfirmation,
value: serverConfirmationetails,
};
}
const responsePart = this.buildFunctionResponse(outcome);
this.fnResponses.push(responsePart);
const responseInfo: ToolCallResponseInfo = {
callId: outcome.callId,
responsePart,
resultDisplay: outcome.result?.returnDisplay,
error: outcome.error,
};
// If aborted we're already yielding the user cancellations elsewhere.
if (!signal?.aborted) {
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
}
}
}
private handlePendingFunctionCall(
@ -276,30 +182,6 @@ export class Turn {
return { type: GeminiEventType.ToolCallRequest, value };
}
// Builds the Part array expected by the Google GenAI API
private buildFunctionResponse(outcome: ServerToolExecutionOutcome): Part {
const { name, result, error } = outcome;
if (error) {
// Format error for the LLM
const errorMessage = error?.message || String(error);
console.error(`[Server Turn] Error executing tool ${name}:`, error);
return {
functionResponse: {
name,
id: outcome.callId,
response: { error: `Tool execution failed: ${errorMessage}` },
},
};
}
return {
functionResponse: {
name,
id: outcome.callId,
response: { output: result?.llmContent ?? '' },
},
};
}
getConfirmationDetails(): ToolCallConfirmationDetails[] {
return this.confirmationDetails;
}

View File

@ -171,23 +171,28 @@ export interface FileDiff {
fileName: string;
}
export interface ToolCallConfirmationDetails {
export interface ToolCallConfirmationDetailsDefault {
title: string;
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
}
export interface ToolEditConfirmationDetails
extends ToolCallConfirmationDetails {
extends ToolCallConfirmationDetailsDefault {
fileName: string;
fileDiff: string;
}
export interface ToolExecuteConfirmationDetails
extends ToolCallConfirmationDetails {
extends ToolCallConfirmationDetailsDefault {
command: string;
rootCommand: string;
}
export type ToolCallConfirmationDetails =
| ToolCallConfirmationDetailsDefault
| ToolEditConfirmationDetails
| ToolExecuteConfirmationDetails;
export enum ToolConfirmationOutcome {
ProceedOnce,
ProceedAlways,