diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts new file mode 100644 index 00000000..2cb27141 --- /dev/null +++ b/packages/cli/src/ui/hooks/useToolScheduler.ts @@ -0,0 +1,464 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Config, + ToolCallRequestInfo, + ToolCallResponseInfo, + ToolConfirmationOutcome, + Tool, +} from '@gemini-code/server'; +import { Part } from '@google/genai'; +import { useCallback, useEffect, useState } from 'react'; +import { + HistoryItemToolGroup, + IndividualToolCallDisplay, + ToolCallStatus, +} from '../types.js'; + +type ScheduledToolCall = { + status: 'scheduled'; + request: ToolCallRequestInfo; + tool: Tool; +}; + +type ErroredToolCall = { + status: 'error'; + request: ToolCallRequestInfo; + response: ToolCallResponseInfo; +}; + +type SuccessfulToolCall = { + status: 'success'; + request: ToolCallRequestInfo; + tool: Tool; + response: ToolCallResponseInfo; +}; + +type ExecutingToolCall = { + status: 'executing'; + request: ToolCallRequestInfo; + tool: Tool; +}; + +type CancelledToolCall = { + status: 'cancelled'; + request: ToolCallRequestInfo; + response: ToolCallResponseInfo; + tool: Tool; +}; + +type WaitingToolCall = { + status: 'awaiting_approval'; + request: ToolCallRequestInfo; + tool: Tool; + confirm: (outcome: ToolConfirmationOutcome) => Promise; +}; + +export type Status = ToolCall['status']; + +export type ToolCall = + | ScheduledToolCall + | ErroredToolCall + | SuccessfulToolCall + | ExecutingToolCall + | CancelledToolCall + | WaitingToolCall; + +export type ScheduleFn = ( + request: ToolCallRequestInfo | ToolCallRequestInfo[], +) => void; +export type CancelFn = () => void; +export type CompletedToolCall = + | SuccessfulToolCall + | CancelledToolCall + | ErroredToolCall; + +export function useToolScheduler( + onComplete: (tools: CompletedToolCall[]) => void, + config: Config, +): [ToolCall[], ScheduleFn, CancelFn] { + const [toolRegistry] = useState(() => config.getToolRegistry()); + const [toolCalls, setToolCalls] = useState([]); + const [abortController, setAbortController] = useState( + () => new AbortController(), + ); + + const isRunning = toolCalls.some( + (t) => t.status === 'executing' || t.status === 'awaiting_approval', + ); + // Note: request array[] typically signal pending tool calls + const schedule = useCallback( + async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => { + if (isRunning) { + throw new Error( + 'Cannot schedule tool calls while other tool calls are running', + ); + } + const requests = Array.isArray(request) ? request : [request]; + const newCalls: ToolCall[] = await Promise.all( + requests.map(async (r): Promise => { + const tool = toolRegistry.getTool(r.name); + if (!tool) { + return { + status: 'error', + request: r, + response: toolErrorResponse( + r, + new Error(`tool ${r.name} does not exist`), + ), + }; + } + + const userApproval = await tool.shouldConfirmExecute(r.args); + if (userApproval) { + return { + status: 'awaiting_approval', + request: r, + tool, + confirm: async (outcome) => { + await userApproval.onConfirm(outcome); + setToolCalls( + outcome === ToolConfirmationOutcome.Cancel + ? setStatus( + r.callId, + 'cancelled', + 'User did not allow tool call', + ) + : setStatus(r.callId, 'scheduled'), + ); + }, + }; + } + + return { + status: 'scheduled', + request: r, + tool, + }; + }), + ); + setToolCalls((t) => t.concat(newCalls)); + }, + [isRunning, setToolCalls, toolRegistry], + ); + + const cancel = useCallback( + (reason: string = 'unspecified') => { + abortController.abort(); + setAbortController(new AbortController()); + setToolCalls((tc) => + tc.map((c) => + c.status !== 'error' + ? { + ...c, + status: 'cancelled', + response: { + callId: c.request.callId, + responsePart: { + functionResponse: { + id: c.request.callId, + name: c.request.name, + response: { + error: `[Operation Cancelled] Reason: ${reason}`, + }, + }, + }, + resultDisplay: undefined, + error: undefined, + }, + } + : c, + ), + ); + }, + [abortController], + ); + + useEffect(() => { + // effect for executing scheduled tool calls + const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled'); + const awaitingConfirmation = toolCalls.some( + (t) => t.status === 'awaiting_approval', + ); + if (!awaitingConfirmation && scheduledCalls.length) { + scheduledCalls.forEach(async (c) => { + const callId = c.request.callId; + try { + setToolCalls(setStatus(c.request.callId, 'executing')); + const result = await c.tool.execute( + c.request.args, + abortController.signal, + ); + const functionResponse: Part = { + functionResponse: { + name: c.request.name, + id: callId, + response: { output: result.llmContent }, + }, + }; + const response: ToolCallResponseInfo = { + callId, + responsePart: functionResponse, + resultDisplay: result.returnDisplay, + error: undefined, + }; + setToolCalls(setStatus(callId, 'success', response)); + } catch (e: unknown) { + setToolCalls( + setStatus( + callId, + 'error', + toolErrorResponse( + c.request, + e instanceof Error ? e : new Error(String(e)), + ), + ), + ); + } + }); + } + }, [toolCalls, toolRegistry, abortController.signal]); + + useEffect(() => { + const completedTools = toolCalls.filter( + (t) => + t.status === 'success' || + t.status === 'error' || + t.status === 'cancelled', + ); + const allDone = completedTools.length === toolCalls.length; + if (toolCalls.length && allDone) { + onComplete(completedTools); + setToolCalls([]); + setAbortController(new AbortController()); + } + }, [toolCalls, onComplete]); + + return [toolCalls, schedule, cancel]; +} + +function setStatus( + targetCallId: string, + status: 'success', + response: ToolCallResponseInfo, +): (t: ToolCall[]) => ToolCall[]; +function setStatus( + targetCallId: string, + status: 'awaiting_approval', + confirm: (t: ToolConfirmationOutcome) => Promise, +): (t: ToolCall[]) => ToolCall[]; +function setStatus( + targetCallId: string, + status: 'error', + response: ToolCallResponseInfo, +): (t: ToolCall[]) => ToolCall[]; +function setStatus( + targetCallId: string, + status: 'cancelled', + reason: string, +): (t: ToolCall[]) => ToolCall[]; +function setStatus( + targetCallId: string, + status: 'executing' | 'scheduled', +): (t: ToolCall[]) => ToolCall[]; +function setStatus( + targetCallId: string, + status: Status, + auxiliaryData?: unknown, +): (t: ToolCall[]) => ToolCall[] { + return function (tc: ToolCall[]): ToolCall[] { + return tc.map((t) => { + if (t.request.callId !== targetCallId || t.status === 'error') { + return t; + } + switch (status) { + case 'success': { + const next: SuccessfulToolCall = { + ...t, + status: 'success', + response: auxiliaryData as ToolCallResponseInfo, + }; + return next; + } + case 'error': { + const next: ErroredToolCall = { + ...t, + status: 'error', + response: auxiliaryData as ToolCallResponseInfo, + }; + return next; + } + case 'awaiting_approval': { + const next: WaitingToolCall = { + ...t, + status: 'awaiting_approval', + confirm: auxiliaryData as ( + o: ToolConfirmationOutcome, + ) => Promise, + }; + return next; + } + case 'scheduled': { + const next: ScheduledToolCall = { + ...t, + status: 'scheduled', + }; + return next; + } + case 'cancelled': { + const next: CancelledToolCall = { + ...t, + status: 'cancelled', + response: { + callId: t.request.callId, + responsePart: { + functionResponse: { + id: t.request.callId, + name: t.request.name, + response: { + error: `[Operation Cancelled] Reason: ${auxiliaryData}`, + }, + }, + }, + resultDisplay: undefined, + error: undefined, + }, + }; + return next; + } + case 'executing': { + const next: ExecutingToolCall = { + ...t, + status: 'executing', + }; + return next; + } + default: { + // ensures every case is checked for above + const exhaustiveCheck: never = status; + return exhaustiveCheck; + } + } + }); + }; +} + +const toolErrorResponse = ( + request: ToolCallRequestInfo, + error: Error, +): ToolCallResponseInfo => ({ + callId: request.callId, + error, + responsePart: { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: error.message }, + }, + }, + resultDisplay: error.message, +}); + +function mapStatus(status: Status): ToolCallStatus { + switch (status) { + case 'awaiting_approval': + return ToolCallStatus.Confirming; + case 'executing': + return ToolCallStatus.Executing; + case 'success': + return ToolCallStatus.Success; + case 'cancelled': + return ToolCallStatus.Canceled; + case 'error': + return ToolCallStatus.Error; + case 'scheduled': + return ToolCallStatus.Pending; + default: { + // ensures every case is checked for above + const exhaustiveCheck: never = status; + return exhaustiveCheck; + } + } +} + +// convenient function for callers to map ToolCall back to a HistoryItem +export function mapToDisplay( + tool: ToolCall[] | ToolCall, +): HistoryItemToolGroup { + const tools = Array.isArray(tool) ? tool : [tool]; + const toolsDisplays = tools.map((t): IndividualToolCallDisplay => { + switch (t.status) { + case 'success': + return { + callId: t.request.callId, + name: t.tool.displayName, + description: t.tool.getDescription(t.request.args), + resultDisplay: t.response.resultDisplay, + status: mapStatus(t.status), + confirmationDetails: undefined, + }; + case 'error': + return { + callId: t.request.callId, + name: t.request.name, + description: '', + resultDisplay: t.response.resultDisplay, + status: mapStatus(t.status), + confirmationDetails: undefined, + }; + case 'cancelled': + return { + callId: t.request.callId, + name: t.tool.displayName, + description: t.tool.getDescription(t.request.args), + resultDisplay: t.response.resultDisplay, + status: mapStatus(t.status), + confirmationDetails: undefined, + }; + case 'awaiting_approval': + return { + callId: t.request.callId, + name: t.tool.displayName, + description: t.tool.getDescription(t.request.args), + resultDisplay: undefined, + status: mapStatus(t.status), + confirmationDetails: { + title: t.request.name, + onConfirm: t.confirm, + }, + }; + case 'executing': + return { + callId: t.request.callId, + name: t.tool.displayName, + description: t.tool.getDescription(t.request.args), + resultDisplay: undefined, + status: mapStatus(t.status), + confirmationDetails: undefined, + }; + case 'scheduled': + return { + callId: t.request.callId, + name: t.tool.displayName, + description: t.tool.getDescription(t.request.args), + resultDisplay: undefined, + status: mapStatus(t.status), + confirmationDetails: undefined, + }; + default: { + // ensures every case is checked for above + const exhaustiveCheck: never = t; + return exhaustiveCheck; + } + } + }); + const historyItem: HistoryItemToolGroup = { + type: 'tool_group', + tools: toolsDisplays, + }; + return historyItem; +}