From f2a8d39f42ae88c1b7a9a5a75854363a53444ca2 Mon Sep 17 00:00:00 2001 From: "N. Taylor Mullen" Date: Sun, 1 Jun 2025 14:16:24 -0700 Subject: [PATCH] refactor: Centralize tool scheduling logic and simplify React hook (#670) --- packages/cli/src/ui/App.tsx | 1 - .../cli/src/ui/hooks/useGeminiStream.test.tsx | 6 +- packages/cli/src/ui/hooks/useGeminiStream.ts | 219 +++--- .../cli/src/ui/hooks/useReactToolScheduler.ts | 301 +++++++++ .../cli/src/ui/hooks/useToolScheduler.test.ts | 18 +- packages/cli/src/ui/hooks/useToolScheduler.ts | 626 ------------------ packages/core/src/core/coreToolScheduler.ts | 520 +++++++++++++++ packages/core/src/index.ts | 3 +- packages/core/src/tools/tools.ts | 2 +- 9 files changed, 938 insertions(+), 758 deletions(-) create mode 100644 packages/cli/src/ui/hooks/useReactToolScheduler.ts delete mode 100644 packages/cli/src/ui/hooks/useToolScheduler.ts create mode 100644 packages/core/src/core/coreToolScheduler.ts diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index 2f216db7..baab7fcc 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -200,7 +200,6 @@ export const App = ({ const { streamingState, submitQuery, initError, pendingHistoryItems } = useGeminiStream( addItem, - refreshStatic, setShowHelp, config, setDebugMessage, diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 6959d9a7..44013059 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -9,11 +9,11 @@ import { mergePartListUnions } from './useGeminiStream.js'; import { Part, PartListUnion } from '@google/genai'; // Mock useToolScheduler -vi.mock('./useToolScheduler', async () => { - const actual = await vi.importActual('./useToolScheduler'); +vi.mock('./useReactToolScheduler', async () => { + const actual = await vi.importActual('./useReactToolScheduler'); return { ...actual, // We need mapToDisplay from actual - useToolScheduler: vi.fn(), + useReactToolScheduler: vi.fn(), }; }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 77be6879..35e5a26a 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -16,20 +16,15 @@ import { isNodeError, Config, MessageSenderType, - ServerToolCallConfirmationDetails, - ToolCallResponseInfo, - ToolEditConfirmationDetails, - ToolExecuteConfirmationDetails, - ToolResultDisplay, ToolCallRequestInfo, } from '@gemini-code/core'; -import { type PartListUnion, type Part } from '@google/genai'; +import { type PartListUnion } from '@google/genai'; import { StreamingState, - ToolCallStatus, HistoryItemWithoutId, HistoryItemToolGroup, MessageType, + ToolCallStatus, } from '../types.js'; import { isAtCommand } from '../utils/commandUtils.js'; import { useShellCommandProcessor } from './shellCommandProcessor.js'; @@ -38,7 +33,13 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { useStateAndRef } from './useStateAndRef.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useLogger } from './useLogger.js'; -import { useToolScheduler, mapToDisplay } from './useToolScheduler.js'; +import { + useReactToolScheduler, + mapToDisplay as mapTrackedToolCallsToDisplay, + TrackedToolCall, + TrackedCompletedToolCall, + TrackedCancelledToolCall, +} from './useReactToolScheduler.js'; import { GeminiChat } from '@gemini-code/core/src/core/geminiChat.js'; export function mergePartListUnions(list: PartListUnion[]): PartListUnion { @@ -60,12 +61,11 @@ enum StreamProcessingStatus { } /** - * Hook to manage the Gemini stream, handle user input, process commands, - * and interact with the Gemini API and history manager. + * Manages the Gemini stream, including user input, command processing, + * API interaction, and tool call lifecycle. */ export const useGeminiStream = ( addItem: UseHistoryManagerReturn['addItem'], - refreshStatic: () => void, setShowHelp: React.Dispatch>, config: Config, onDebugMessage: (message: string) => void, @@ -82,27 +82,33 @@ export const useGeminiStream = ( const [pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef(null); const logger = useLogger(); - const [toolCalls, schedule, cancel] = useToolScheduler( - (tools) => { - if (tools.length) { - addItem(mapToDisplay(tools), Date.now()); - const toolResponses = tools - .filter( - (t) => - t.status === 'error' || - t.status === 'cancelled' || - t.status === 'success', - ) - .map((t) => t.response.responseParts); - submitQuery(mergePartListUnions(toolResponses)); + const [ + toolCalls, + scheduleToolCalls, + cancelAllToolCalls, + markToolsAsSubmitted, + ] = useReactToolScheduler( + (completedToolCallsFromScheduler) => { + // This onComplete is called when ALL scheduled tools for a given batch are done. + if (completedToolCallsFromScheduler.length > 0) { + // Add the final state of these tools to the history for display. + // The new useEffect will handle submitting their responses. + addItem( + mapTrackedToolCallsToDisplay( + completedToolCallsFromScheduler as TrackedToolCall[], + ), + Date.now(), + ); } }, config, setPendingHistoryItem, ); - const pendingToolCalls = useMemo( - () => (toolCalls.length ? mapToDisplay(toolCalls) : undefined), + + const pendingToolCallGroupDisplay = useMemo( + () => + toolCalls.length ? mapTrackedToolCallsToDisplay(toolCalls) : undefined, [toolCalls], ); @@ -120,16 +126,16 @@ export const useGeminiStream = ( ); const streamingState = useMemo(() => { - if (toolCalls.some((t) => t.status === 'awaiting_approval')) { + if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) { return StreamingState.WaitingForConfirmation; } if ( isResponding || toolCalls.some( - (t) => - t.status === 'executing' || - t.status === 'scheduled' || - t.status === 'validating', + (tc) => + tc.status === 'executing' || + tc.status === 'scheduled' || + tc.status === 'validating', ) ) { return StreamingState.Responding; @@ -153,7 +159,7 @@ export const useGeminiStream = ( useInput((_input, key) => { if (streamingState !== StreamingState.Idle && key.escape) { abortControllerRef.current?.abort(); - cancel(); + cancelAllToolCalls(); // Also cancel any pending/executing tool calls } }); @@ -194,7 +200,7 @@ export const useGeminiStream = ( name: toolName, args: toolArgs, }; - schedule([toolCallRequest]); // schedule expects an array or single object + scheduleToolCalls([toolCallRequest]); } return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool } @@ -246,7 +252,7 @@ export const useGeminiStream = ( handleSlashCommand, logger, shellModeActive, - schedule, + scheduleToolCalls, ], ); @@ -275,73 +281,6 @@ export const useGeminiStream = ( return { client: currentClient, chat: chatSessionRef.current }; }, [addItem]); - // --- UI Helper Functions (used by event handlers) --- - const updateFunctionResponseUI = ( - toolResponse: ToolCallResponseInfo, - status: ToolCallStatus, - ) => { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === toolResponse.callId - ? { - ...tool, - status, - resultDisplay: toolResponse.resultDisplay, - } - : tool, - ), - } - : item, - ); - }; - - // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure - // or could be a standalone helper if more params are passed. - // TODO: handle file diff result display stuff - function _declineToolExecution( - declineMessage: string, - status: ToolCallStatus, - request: ServerToolCallConfirmationDetails['request'], - originalDetails: ServerToolCallConfirmationDetails['details'], - ) { - let resultDisplay: ToolResultDisplay | undefined; - if ('fileDiff' in originalDetails) { - resultDisplay = { - fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff, - fileName: (originalDetails as ToolEditConfirmationDetails).fileName, - }; - } else { - resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`; - } - const functionResponse: Part = { - functionResponse: { - id: request.callId, - name: request.name, - response: { error: declineMessage }, - }, - }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responseParts: functionResponse, - resultDisplay, - error: new Error(declineMessage), - }; - const history = chatSessionRef.current?.getHistory(); - if (history) { - history.push({ role: 'model', parts: [functionResponse] }); - } - updateFunctionResponseUI(responseInfo, status); - - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setIsResponding(false); - } - // --- Stream Event Handlers --- const handleContentEvent = useCallback( @@ -425,9 +364,9 @@ export const useGeminiStream = ( userMessageTimestamp, ); setIsResponding(false); - cancel(); + cancelAllToolCalls(); }, - [addItem, pendingHistoryItemRef, setPendingHistoryItem, cancel], + [addItem, pendingHistoryItemRef, setPendingHistoryItem, cancelAllToolCalls], ); const handleErrorEvent = useCallback( @@ -462,22 +401,22 @@ export const useGeminiStream = ( toolCallRequests.push(event.value); } else if (event.type === ServerGeminiEventType.UserCancelled) { handleUserCancelledEvent(userMessageTimestamp); - cancel(); return StreamProcessingStatus.UserCancelled; } else if (event.type === ServerGeminiEventType.Error) { handleErrorEvent(event.value, userMessageTimestamp); return StreamProcessingStatus.Error; } } - schedule(toolCallRequests); + if (toolCallRequests.length > 0) { + scheduleToolCalls(toolCallRequests); + } return StreamProcessingStatus.Completed; }, [ handleContentEvent, handleUserCancelledEvent, - cancel, handleErrorEvent, - schedule, + scheduleToolCalls, ], ); @@ -545,21 +484,69 @@ export const useGeminiStream = ( } }, [ - setShowHelp, - addItem, - setInitError, - ensureChatSession, - prepareQueryForGemini, - processGeminiStreamEvents, - setPendingHistoryItem, - pendingHistoryItemRef, streamingState, + setShowHelp, + prepareQueryForGemini, + ensureChatSession, + processGeminiStreamEvents, + pendingHistoryItemRef, + addItem, + setPendingHistoryItem, + setInitError, ], ); + /** + * Automatically submits responses for completed tool calls. + * This effect runs when `toolCalls` or `isResponding` changes. + * It ensures that tool responses are sent back to Gemini only when + * all processing for a given set of tools is finished and Gemini + * is not already generating a response. + */ + useEffect(() => { + if (isResponding) { + return; + } + + const completedAndReadyToSubmitTools = toolCalls.filter( + ( + tc: TrackedToolCall, + ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => { + const isTerminalState = + tc.status === 'success' || + tc.status === 'error' || + tc.status === 'cancelled'; + + if (isTerminalState) { + const completedOrCancelledCall = tc as + | TrackedCompletedToolCall + | TrackedCancelledToolCall; + return ( + !completedOrCancelledCall.responseSubmittedToGemini && + completedOrCancelledCall.response?.responseParts !== undefined + ); + } + return false; + }, + ); + + if (completedAndReadyToSubmitTools.length > 0) { + const responsesToSend: PartListUnion[] = + completedAndReadyToSubmitTools.map( + (toolCall) => toolCall.response.responseParts, + ); + const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map( + (toolCall) => toolCall.request.callId, + ); + + markToolsAsSubmitted(callIdsToMarkAsSubmitted); + submitQuery(mergePartListUnions(responsesToSend)); + } + }, [toolCalls, isResponding, submitQuery, markToolsAsSubmitted, addItem]); + const pendingHistoryItems = [ pendingHistoryItemRef.current, - pendingToolCalls, + pendingToolCallGroupDisplay, ].filter((i) => i !== undefined && i !== null); return { diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts new file mode 100644 index 00000000..12333d92 --- /dev/null +++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts @@ -0,0 +1,301 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Config, + ToolCallRequestInfo, + ExecutingToolCall, + ScheduledToolCall, + ValidatingToolCall, + WaitingToolCall, + CompletedToolCall, + CancelledToolCall, + CoreToolScheduler, + OutputUpdateHandler, + AllToolCallsCompleteHandler, + ToolCallsUpdateHandler, + Tool, + ToolCall, + Status as CoreStatus, +} from '@gemini-code/core'; +import { useCallback, useEffect, useState, useRef } from 'react'; +import { + HistoryItemToolGroup, + IndividualToolCallDisplay, + ToolCallStatus, + HistoryItemWithoutId, +} from '../types.js'; + +export type ScheduleFn = ( + request: ToolCallRequestInfo | ToolCallRequestInfo[], +) => void; +export type CancelFn = (reason?: string) => void; +export type MarkToolsAsSubmittedFn = (callIds: string[]) => void; + +export type TrackedScheduledToolCall = ScheduledToolCall & { + responseSubmittedToGemini?: boolean; +}; +export type TrackedValidatingToolCall = ValidatingToolCall & { + responseSubmittedToGemini?: boolean; +}; +export type TrackedWaitingToolCall = WaitingToolCall & { + responseSubmittedToGemini?: boolean; +}; +export type TrackedExecutingToolCall = ExecutingToolCall & { + responseSubmittedToGemini?: boolean; +}; +export type TrackedCompletedToolCall = CompletedToolCall & { + responseSubmittedToGemini?: boolean; +}; +export type TrackedCancelledToolCall = CancelledToolCall & { + responseSubmittedToGemini?: boolean; +}; + +export type TrackedToolCall = + | TrackedScheduledToolCall + | TrackedValidatingToolCall + | TrackedWaitingToolCall + | TrackedExecutingToolCall + | TrackedCompletedToolCall + | TrackedCancelledToolCall; + +export function useReactToolScheduler( + onComplete: (tools: CompletedToolCall[]) => void, + config: Config, + setPendingHistoryItem: React.Dispatch< + React.SetStateAction + >, +): [TrackedToolCall[], ScheduleFn, CancelFn, MarkToolsAsSubmittedFn] { + const [toolCallsForDisplay, setToolCallsForDisplay] = useState< + TrackedToolCall[] + >([]); + const schedulerRef = useRef(null); + + useEffect(() => { + const outputUpdateHandler: OutputUpdateHandler = ( + toolCallId, + outputChunk, + ) => { + setPendingHistoryItem((prevItem) => { + if (prevItem?.type === 'tool_group') { + return { + ...prevItem, + tools: prevItem.tools.map((toolDisplay) => + toolDisplay.callId === toolCallId && + toolDisplay.status === ToolCallStatus.Executing + ? { ...toolDisplay, resultDisplay: outputChunk } + : toolDisplay, + ), + }; + } + return prevItem; + }); + + setToolCallsForDisplay((prevCalls) => + prevCalls.map((tc) => { + if (tc.request.callId === toolCallId && tc.status === 'executing') { + const executingTc = tc as TrackedExecutingToolCall; + return { ...executingTc, liveOutput: outputChunk }; + } + return tc; + }), + ); + }; + + const allToolCallsCompleteHandler: AllToolCallsCompleteHandler = ( + completedToolCalls, + ) => { + onComplete(completedToolCalls); + }; + + const toolCallsUpdateHandler: ToolCallsUpdateHandler = ( + updatedCoreToolCalls: ToolCall[], + ) => { + setToolCallsForDisplay((prevTrackedCalls) => + updatedCoreToolCalls.map((coreTc) => { + const existingTrackedCall = prevTrackedCalls.find( + (ptc) => ptc.request.callId === coreTc.request.callId, + ); + const newTrackedCall: TrackedToolCall = { + ...coreTc, + responseSubmittedToGemini: + existingTrackedCall?.responseSubmittedToGemini ?? false, + } as TrackedToolCall; + return newTrackedCall; + }), + ); + }; + + schedulerRef.current = new CoreToolScheduler({ + toolRegistry: config.getToolRegistry(), + outputUpdateHandler, + onAllToolCallsComplete: allToolCallsCompleteHandler, + onToolCallsUpdate: toolCallsUpdateHandler, + }); + }, [config, onComplete, setPendingHistoryItem]); + + const schedule: ScheduleFn = useCallback( + async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => { + schedulerRef.current?.schedule(request); + }, + [], + ); + + const cancel: CancelFn = useCallback((reason: string = 'unspecified') => { + schedulerRef.current?.cancelAll(reason); + }, []); + + const markToolsAsSubmitted: MarkToolsAsSubmittedFn = useCallback( + (callIdsToMark: string[]) => { + setToolCallsForDisplay((prevCalls) => + prevCalls.map((tc) => + callIdsToMark.includes(tc.request.callId) + ? { ...tc, responseSubmittedToGemini: true } + : tc, + ), + ); + }, + [], + ); + + return [toolCallsForDisplay, schedule, cancel, markToolsAsSubmitted]; +} + +/** + * Maps a CoreToolScheduler status to the UI's ToolCallStatus enum. + */ +function mapCoreStatusToDisplayStatus(coreStatus: CoreStatus): ToolCallStatus { + switch (coreStatus) { + case 'validating': + return ToolCallStatus.Executing; + case 'awaiting_approval': + return ToolCallStatus.Confirming; + case 'executing': + return ToolCallStatus.Executing; + case 'success': + return ToolCallStatus.Success; + case 'cancelled': + return ToolCallStatus.Canceled; + case 'error': + return ToolCallStatus.Error; + case 'scheduled': + return ToolCallStatus.Pending; + default: { + const exhaustiveCheck: never = coreStatus; + console.warn(`Unknown core status encountered: ${exhaustiveCheck}`); + return ToolCallStatus.Error; + } + } +} + +/** + * Transforms `TrackedToolCall` objects into `HistoryItemToolGroup` objects for UI display. + */ +export function mapToDisplay( + toolOrTools: TrackedToolCall[] | TrackedToolCall, +): HistoryItemToolGroup { + const toolCalls = Array.isArray(toolOrTools) ? toolOrTools : [toolOrTools]; + + const toolDisplays = toolCalls.map( + (trackedCall): IndividualToolCallDisplay => { + let displayName = trackedCall.request.name; + let description = ''; + let renderOutputAsMarkdown = false; + + const currentToolInstance = + 'tool' in trackedCall && trackedCall.tool + ? (trackedCall as { tool: Tool }).tool + : undefined; + + if (currentToolInstance) { + displayName = currentToolInstance.displayName; + description = currentToolInstance.getDescription( + trackedCall.request.args, + ); + renderOutputAsMarkdown = currentToolInstance.isOutputMarkdown; + } + + if (trackedCall.status === 'error') { + description = ''; + } + + const baseDisplayProperties: Omit< + IndividualToolCallDisplay, + 'status' | 'resultDisplay' | 'confirmationDetails' + > = { + callId: trackedCall.request.callId, + name: displayName, + description, + renderOutputAsMarkdown, + }; + + switch (trackedCall.status) { + case 'success': + return { + ...baseDisplayProperties, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: trackedCall.response.resultDisplay, + confirmationDetails: undefined, + }; + case 'error': + return { + ...baseDisplayProperties, + name: currentToolInstance?.displayName ?? trackedCall.request.name, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: trackedCall.response.resultDisplay, + confirmationDetails: undefined, + }; + case 'cancelled': + return { + ...baseDisplayProperties, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: trackedCall.response.resultDisplay, + confirmationDetails: undefined, + }; + case 'awaiting_approval': + return { + ...baseDisplayProperties, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: undefined, + confirmationDetails: trackedCall.confirmationDetails, + }; + case 'executing': + return { + ...baseDisplayProperties, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: + (trackedCall as TrackedExecutingToolCall).liveOutput ?? undefined, + confirmationDetails: undefined, + }; + case 'validating': // Fallthrough + case 'scheduled': + return { + ...baseDisplayProperties, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: undefined, + confirmationDetails: undefined, + }; + default: { + const exhaustiveCheck: never = trackedCall; + return { + callId: (exhaustiveCheck as TrackedToolCall).request.callId, + name: 'Unknown Tool', + description: 'Encountered an unknown tool call state.', + status: ToolCallStatus.Error, + resultDisplay: 'Unknown tool call state', + confirmationDetails: undefined, + renderOutputAsMarkdown: false, + }; + } + } + }, + ); + + return { + type: 'tool_group', + tools: toolDisplays, + }; +} diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index ebdfed24..92bff2bc 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -8,12 +8,9 @@ import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest'; import { renderHook, act } from '@testing-library/react'; import { - useToolScheduler, - formatLlmContentForFunctionResponse, + useReactToolScheduler, mapToDisplay, - ToolCall, - Status as ToolCallStatusType, // Renamed to avoid conflict -} from './useToolScheduler.js'; +} from './useReactToolScheduler.js'; import { Part, PartListUnion, @@ -29,6 +26,9 @@ import { ToolCallConfirmationDetails, ToolConfirmationOutcome, ToolCallResponseInfo, + formatLlmContentForFunctionResponse, // Import from core + ToolCall, // Import from core + Status as ToolCallStatusType, // Import from core } from '@gemini-code/core'; import { HistoryItemWithoutId, @@ -205,7 +205,7 @@ describe('formatLlmContentForFunctionResponse', () => { }); }); -describe('useToolScheduler', () => { +describe('useReactToolScheduler', () => { // TODO(ntaylormullen): The following tests are skipped due to difficulties in // reliably testing the asynchronous state updates and interactions with timers. // These tests involve complex sequences of events, including confirmations, @@ -276,7 +276,7 @@ describe('useToolScheduler', () => { const renderScheduler = () => renderHook(() => - useToolScheduler( + useReactToolScheduler( onComplete, mockConfig as unknown as Config, setPendingHistoryItem, @@ -367,7 +367,7 @@ describe('useToolScheduler', () => { request, response: expect.objectContaining({ error: expect.objectContaining({ - message: 'tool nonExistentTool does not exist', + message: 'Tool "nonExistentTool" not found in registry.', }), }), }), @@ -1050,7 +1050,7 @@ describe('mapToDisplay', () => { }, expectedStatus: ToolCallStatus.Error, expectedResultDisplay: 'Execution failed display', - expectedName: baseTool.name, + expectedName: baseTool.displayName, // Changed from baseTool.name expectedDescription: '', }, { diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts deleted file mode 100644 index 9233ebcf..00000000 --- a/packages/cli/src/ui/hooks/useToolScheduler.ts +++ /dev/null @@ -1,626 +0,0 @@ -/** - * @license - * Copyright 2025 Google LLC - * SPDX-License-Identifier: Apache-2.0 - */ - -import { - Config, - ToolCallRequestInfo, - ToolCallResponseInfo, - ToolConfirmationOutcome, - Tool, - ToolCallConfirmationDetails, - ToolResult, -} from '@gemini-code/core'; -import { Part, PartUnion, PartListUnion } from '@google/genai'; -import { useCallback, useEffect, useState } from 'react'; -import { - HistoryItemToolGroup, - IndividualToolCallDisplay, - ToolCallStatus, - HistoryItemWithoutId, -} from '../types.js'; - -type ValidatingToolCall = { - status: 'validating'; - request: ToolCallRequestInfo; - tool: Tool; -}; - -type ScheduledToolCall = { - status: 'scheduled'; - request: ToolCallRequestInfo; - tool: Tool; -}; - -type ErroredToolCall = { - status: 'error'; - request: ToolCallRequestInfo; - response: ToolCallResponseInfo; -}; - -type SuccessfulToolCall = { - status: 'success'; - request: ToolCallRequestInfo; - tool: Tool; - response: ToolCallResponseInfo; -}; - -export type ExecutingToolCall = { - status: 'executing'; - request: ToolCallRequestInfo; - tool: Tool; - liveOutput?: string; -}; - -type CancelledToolCall = { - status: 'cancelled'; - request: ToolCallRequestInfo; - response: ToolCallResponseInfo; - tool: Tool; -}; - -type WaitingToolCall = { - status: 'awaiting_approval'; - request: ToolCallRequestInfo; - tool: Tool; - confirmationDetails: ToolCallConfirmationDetails; -}; - -export type Status = ToolCall['status']; - -export type ToolCall = - | ValidatingToolCall - | ScheduledToolCall - | ErroredToolCall - | SuccessfulToolCall - | ExecutingToolCall - | CancelledToolCall - | WaitingToolCall; - -export type ScheduleFn = ( - request: ToolCallRequestInfo | ToolCallRequestInfo[], -) => void; -export type CancelFn = () => void; -export type CompletedToolCall = - | SuccessfulToolCall - | CancelledToolCall - | ErroredToolCall; - -/** - * Formats a PartListUnion response from a tool into JSON suitable for a Gemini - * FunctionResponse and additional Parts to include after that response. - * - * This is required because FunctionReponse appears to only support JSON - * and not arbitrary parts. Including parts like inlineData or fileData - * directly in a FunctionResponse confuses the model resulting in a failure - * to interpret the multimodal content and context window exceeded errors. - */ - -export function formatLlmContentForFunctionResponse( - llmContent: PartListUnion, -): { - functionResponseJson: Record; - additionalParts: PartUnion[]; -} { - const additionalParts: PartUnion[] = []; - let functionResponseJson: Record; - - if (Array.isArray(llmContent) && llmContent.length === 1) { - // Ensure that length 1 arrays are treated as a single Part. - llmContent = llmContent[0]; - } - - if (typeof llmContent === 'string') { - functionResponseJson = { output: llmContent }; - } else if (Array.isArray(llmContent)) { - functionResponseJson = { status: 'Tool execution succeeded.' }; - additionalParts.push(...llmContent); - } else { - if ( - llmContent.inlineData !== undefined || - llmContent.fileData !== undefined - ) { - // For Parts like inlineData or fileData, use the returnDisplay as the textual output for the functionResponse. - // The actual Part will be added to additionalParts. - functionResponseJson = { - status: `Binary content of type ${llmContent.inlineData?.mimeType || llmContent.fileData?.mimeType || 'unknown'} was processed.`, - }; - additionalParts.push(llmContent); - } else if (llmContent.text !== undefined) { - functionResponseJson = { output: llmContent.text }; - } else { - functionResponseJson = { status: 'Tool execution succeeded.' }; - additionalParts.push(llmContent); - } - } - - return { - functionResponseJson, - additionalParts, - }; -} - -export function useToolScheduler( - onComplete: (tools: CompletedToolCall[]) => void, - config: Config, - setPendingHistoryItem: React.Dispatch< - React.SetStateAction - >, -): [ToolCall[], ScheduleFn, CancelFn] { - const [toolRegistry] = useState(() => config.getToolRegistry()); - const [toolCalls, setToolCalls] = useState([]); - const [abortController, setAbortController] = useState( - () => new AbortController(), - ); - - const isRunning = toolCalls.some( - (t) => t.status === 'executing' || t.status === 'awaiting_approval', - ); - // Note: request array[] typically signal pending tool calls - const schedule = useCallback( - async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => { - if (isRunning) { - throw new Error( - 'Cannot schedule tool calls while other tool calls are running', - ); - } - const requestsToProcess = Array.isArray(request) ? request : [request]; - - // Step 1: Create initial calls with 'validating' status (or 'error' if tool not found) - // and add them to the state immediately to make the UI busy. - const initialNewCalls: ToolCall[] = requestsToProcess.map( - (r): ToolCall => { - const tool = toolRegistry.getTool(r.name); - if (!tool) { - return { - status: 'error', - request: r, - response: toolErrorResponse( - r, - new Error(`tool ${r.name} does not exist`), - ), - }; - } - // Set to 'validating' immediately. This will make streamingState 'Responding'. - return { status: 'validating', request: r, tool }; - }, - ); - setToolCalls((prevCalls) => prevCalls.concat(initialNewCalls)); - - // Step 2: Asynchronously check for confirmation and update status for each new call. - initialNewCalls.forEach(async (initialCall) => { - // If the call was already marked as an error (tool not found), skip further processing. - if (initialCall.status !== 'validating') return; - - const { request: r, tool } = initialCall; - try { - const userApproval = await tool.shouldConfirmExecute( - r.args, - abortController.signal, - ); - if (userApproval) { - // Confirmation is needed. Update status to 'awaiting_approval'. - setToolCalls( - setStatus(r.callId, 'awaiting_approval', { - ...userApproval, - onConfirm: async (outcome) => { - // This onConfirm is triggered by user interaction later. - await userApproval.onConfirm(outcome); - setToolCalls( - outcome === ToolConfirmationOutcome.Cancel - ? setStatus( - r.callId, - 'cancelled', - 'User did not allow tool call', - ) - : // If confirmed, it goes to 'scheduled' to be picked up by the execution effect. - setStatus(r.callId, 'scheduled'), - ); - }, - }), - ); - } else { - // No confirmation needed, move to 'scheduled' for execution. - setToolCalls(setStatus(r.callId, 'scheduled')); - } - } catch (e) { - // Handle errors from tool.shouldConfirmExecute() itself. - setToolCalls( - setStatus( - r.callId, - 'error', - toolErrorResponse( - r, - e instanceof Error ? e : new Error(String(e)), - ), - ), - ); - } - }); - }, - [isRunning, setToolCalls, toolRegistry, abortController.signal], - ); - - const cancel = useCallback( - (reason: string = 'unspecified') => { - abortController.abort(); - setAbortController(new AbortController()); - setToolCalls((tc) => - tc.map((c) => - c.status !== 'error' && c.status !== 'executing' - ? { - ...c, - status: 'cancelled', - response: { - callId: c.request.callId, - responseParts: { - functionResponse: { - id: c.request.callId, - name: c.request.name, - response: { - error: `[Operation Cancelled] Reason: ${reason}`, - }, - }, - }, - resultDisplay: undefined, - error: undefined, - }, - } - : c, - ), - ); - }, - [abortController], - ); - - useEffect(() => { - // effect for executing scheduled tool calls - const allToolsConfirmed = toolCalls.every( - (t) => t.status === 'scheduled' || t.status === 'cancelled', - ); - if (allToolsConfirmed) { - const signal = abortController.signal; - toolCalls - .filter((t) => t.status === 'scheduled') - .forEach((t) => { - const callId = t.request.callId; - setToolCalls(setStatus(t.request.callId, 'executing')); - - const updateOutput = t.tool.canUpdateOutput - ? (output: string) => { - setPendingHistoryItem( - (prevItem: HistoryItemWithoutId | null) => { - if (prevItem?.type === 'tool_group') { - return { - ...prevItem, - tools: prevItem.tools.map( - (toolDisplay: IndividualToolCallDisplay) => - toolDisplay.callId === callId && - toolDisplay.status === ToolCallStatus.Executing - ? { - ...toolDisplay, - resultDisplay: output, - } - : toolDisplay, - ), - }; - } - return prevItem; - }, - ); - // Also update the toolCall itself so that mapToDisplay - // can pick up the live output if the item is not pending - // (e.g. if it's being re-rendered from history) - setToolCalls((prevToolCalls) => - prevToolCalls.map((tc) => - tc.request.callId === callId && tc.status === 'executing' - ? { ...tc, liveOutput: output } - : tc, - ), - ); - } - : undefined; - - t.tool - .execute(t.request.args, signal, updateOutput) - .then((result: ToolResult) => { - if (signal.aborted) { - // TODO(jacobr): avoid stringifying the LLM content. - setToolCalls( - setStatus(callId, 'cancelled', String(result.llmContent)), - ); - return; - } - const { functionResponseJson, additionalParts } = - formatLlmContentForFunctionResponse(result.llmContent); - const functionResponse: Part = { - functionResponse: { - name: t.request.name, - id: callId, - response: functionResponseJson, - }, - }; - const response: ToolCallResponseInfo = { - callId, - responseParts: [functionResponse, ...additionalParts], - resultDisplay: result.returnDisplay, - error: undefined, - }; - setToolCalls(setStatus(callId, 'success', response)); - }) - .catch((e: Error) => - setToolCalls( - setStatus( - callId, - 'error', - toolErrorResponse( - t.request, - e instanceof Error ? e : new Error(String(e)), - ), - ), - ), - ); - }); - } - }, [toolCalls, toolRegistry, abortController.signal, setPendingHistoryItem]); - - useEffect(() => { - const allDone = toolCalls.every( - (t) => - t.status === 'success' || - t.status === 'error' || - t.status === 'cancelled', - ); - if (toolCalls.length && allDone) { - setToolCalls([]); - onComplete(toolCalls); - setAbortController(() => new AbortController()); - } - }, [toolCalls, onComplete]); - - return [toolCalls, schedule, cancel]; -} - -function setStatus( - targetCallId: string, - status: 'success', - response: ToolCallResponseInfo, -): (t: ToolCall[]) => ToolCall[]; -function setStatus( - targetCallId: string, - status: 'awaiting_approval', - confirm: ToolCallConfirmationDetails, -): (t: ToolCall[]) => ToolCall[]; -function setStatus( - targetCallId: string, - status: 'error', - response: ToolCallResponseInfo, -): (t: ToolCall[]) => ToolCall[]; -function setStatus( - targetCallId: string, - status: 'cancelled', - reason: string, -): (t: ToolCall[]) => ToolCall[]; -function setStatus( - targetCallId: string, - status: 'executing' | 'scheduled' | 'validating', -): (t: ToolCall[]) => ToolCall[]; -function setStatus( - targetCallId: string, - status: Status, - auxiliaryData?: unknown, -): (t: ToolCall[]) => ToolCall[] { - return function (tc: ToolCall[]): ToolCall[] { - return tc.map((t) => { - if (t.request.callId !== targetCallId || t.status === 'error') { - return t; - } - switch (status) { - case 'success': { - const next: SuccessfulToolCall = { - ...t, - status: 'success', - response: auxiliaryData as ToolCallResponseInfo, - }; - return next; - } - case 'error': { - const next: ErroredToolCall = { - ...t, - status: 'error', - response: auxiliaryData as ToolCallResponseInfo, - }; - return next; - } - case 'awaiting_approval': { - const next: WaitingToolCall = { - ...t, - status: 'awaiting_approval', - confirmationDetails: auxiliaryData as ToolCallConfirmationDetails, - }; - return next; - } - case 'scheduled': { - const next: ScheduledToolCall = { - ...t, - status: 'scheduled', - }; - return next; - } - case 'cancelled': { - const next: CancelledToolCall = { - ...t, - status: 'cancelled', - response: { - callId: t.request.callId, - responseParts: { - functionResponse: { - id: t.request.callId, - name: t.request.name, - response: { - error: `[Operation Cancelled] Reason: ${auxiliaryData}`, - }, - }, - }, - resultDisplay: undefined, - error: undefined, - }, - }; - return next; - } - case 'validating': { - const next: ValidatingToolCall = { - ...(t as ValidatingToolCall), // Added type assertion for safety - status: 'validating', - }; - return next; - } - case 'executing': { - const next: ExecutingToolCall = { - ...t, - status: 'executing', - }; - return next; - } - default: { - // ensures every case is checked for above - const exhaustiveCheck: never = status; - return exhaustiveCheck; - } - } - }); - }; -} - -const toolErrorResponse = ( - request: ToolCallRequestInfo, - error: Error, -): ToolCallResponseInfo => ({ - callId: request.callId, - error, - responseParts: { - functionResponse: { - id: request.callId, - name: request.name, - response: { error: error.message }, - }, - }, - resultDisplay: error.message, -}); - -function mapStatus(status: Status): ToolCallStatus { - switch (status) { - case 'validating': - return ToolCallStatus.Executing; - case 'awaiting_approval': - return ToolCallStatus.Confirming; - case 'executing': - return ToolCallStatus.Executing; - case 'success': - return ToolCallStatus.Success; - case 'cancelled': - return ToolCallStatus.Canceled; - case 'error': - return ToolCallStatus.Error; - case 'scheduled': - return ToolCallStatus.Pending; - default: { - // ensures every case is checked for above - const exhaustiveCheck: never = status; - return exhaustiveCheck; - } - } -} - -// convenient function for callers to map ToolCall back to a HistoryItem -export function mapToDisplay( - tool: ToolCall[] | ToolCall, -): HistoryItemToolGroup { - const tools = Array.isArray(tool) ? tool : [tool]; - const toolsDisplays = tools.map((t): IndividualToolCallDisplay => { - switch (t.status) { - case 'success': - return { - callId: t.request.callId, - name: t.tool.displayName, - description: t.tool.getDescription(t.request.args), - resultDisplay: t.response.resultDisplay, - status: mapStatus(t.status), - confirmationDetails: undefined, - renderOutputAsMarkdown: t.tool.isOutputMarkdown, - }; - case 'error': - return { - callId: t.request.callId, - name: t.request.name, // Use request.name as tool might be undefined - description: '', // No description available if tool is undefined - resultDisplay: t.response.resultDisplay, - status: mapStatus(t.status), - confirmationDetails: undefined, - renderOutputAsMarkdown: false, - }; - case 'cancelled': - return { - callId: t.request.callId, - name: t.tool.displayName, - description: t.tool.getDescription(t.request.args), - resultDisplay: t.response.resultDisplay, - status: mapStatus(t.status), - confirmationDetails: undefined, - renderOutputAsMarkdown: t.tool.isOutputMarkdown, - }; - case 'awaiting_approval': - return { - callId: t.request.callId, - name: t.tool.displayName, - description: t.tool.getDescription(t.request.args), - resultDisplay: undefined, - status: mapStatus(t.status), - confirmationDetails: t.confirmationDetails, - renderOutputAsMarkdown: t.tool.isOutputMarkdown, - }; - case 'executing': - return { - callId: t.request.callId, - name: t.tool.displayName, - description: t.tool.getDescription(t.request.args), - resultDisplay: t.liveOutput ?? undefined, - status: mapStatus(t.status), - confirmationDetails: undefined, - renderOutputAsMarkdown: t.tool.isOutputMarkdown, - }; - case 'validating': // Add this case - return { - callId: t.request.callId, - name: t.tool.displayName, - description: t.tool.getDescription(t.request.args), - resultDisplay: undefined, - status: mapStatus(t.status), - confirmationDetails: undefined, - renderOutputAsMarkdown: t.tool.isOutputMarkdown, - }; - case 'scheduled': - return { - callId: t.request.callId, - name: t.tool.displayName, - description: t.tool.getDescription(t.request.args), - resultDisplay: undefined, - status: mapStatus(t.status), - confirmationDetails: undefined, - renderOutputAsMarkdown: t.tool.isOutputMarkdown, - }; - default: { - // ensures every case is checked for above - const exhaustiveCheck: never = t; - return exhaustiveCheck; - } - } - }); - const historyItem: HistoryItemToolGroup = { - type: 'tool_group', - tools: toolsDisplays, - }; - return historyItem; -} diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts new file mode 100644 index 00000000..1278d468 --- /dev/null +++ b/packages/core/src/core/coreToolScheduler.ts @@ -0,0 +1,520 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + ToolCallRequestInfo, + ToolCallResponseInfo, + ToolConfirmationOutcome, + Tool, + ToolCallConfirmationDetails, + ToolResult, + ToolRegistry, +} from '../index.js'; +import { Part, PartUnion, PartListUnion } from '@google/genai'; + +export type ValidatingToolCall = { + status: 'validating'; + request: ToolCallRequestInfo; + tool: Tool; +}; + +export type ScheduledToolCall = { + status: 'scheduled'; + request: ToolCallRequestInfo; + tool: Tool; +}; + +export type ErroredToolCall = { + status: 'error'; + request: ToolCallRequestInfo; + response: ToolCallResponseInfo; +}; + +export type SuccessfulToolCall = { + status: 'success'; + request: ToolCallRequestInfo; + tool: Tool; + response: ToolCallResponseInfo; +}; + +export type ExecutingToolCall = { + status: 'executing'; + request: ToolCallRequestInfo; + tool: Tool; + liveOutput?: string; +}; + +export type CancelledToolCall = { + status: 'cancelled'; + request: ToolCallRequestInfo; + response: ToolCallResponseInfo; + tool: Tool; +}; + +export type WaitingToolCall = { + status: 'awaiting_approval'; + request: ToolCallRequestInfo; + tool: Tool; + confirmationDetails: ToolCallConfirmationDetails; +}; + +export type Status = ToolCall['status']; + +export type ToolCall = + | ValidatingToolCall + | ScheduledToolCall + | ErroredToolCall + | SuccessfulToolCall + | ExecutingToolCall + | CancelledToolCall + | WaitingToolCall; + +export type CompletedToolCall = + | SuccessfulToolCall + | CancelledToolCall + | ErroredToolCall; + +export type ConfirmHandler = ( + toolCall: WaitingToolCall, +) => Promise; + +export type OutputUpdateHandler = ( + toolCallId: string, + outputChunk: string, +) => void; + +export type AllToolCallsCompleteHandler = ( + completedToolCalls: CompletedToolCall[], +) => void; + +export type ToolCallsUpdateHandler = (toolCalls: ToolCall[]) => void; + +/** + * Formats tool output for a Gemini FunctionResponse. + */ +export function formatLlmContentForFunctionResponse( + llmContent: PartListUnion, +): { + functionResponseJson: Record; + additionalParts: PartUnion[]; +} { + const additionalParts: PartUnion[] = []; + let functionResponseJson: Record; + + const contentToProcess = + Array.isArray(llmContent) && llmContent.length === 1 + ? llmContent[0] + : llmContent; + + if (typeof contentToProcess === 'string') { + functionResponseJson = { output: contentToProcess }; + } else if (Array.isArray(contentToProcess)) { + functionResponseJson = { + status: 'Tool execution succeeded.', + }; + additionalParts.push(...contentToProcess); + } else if (contentToProcess.inlineData || contentToProcess.fileData) { + const mimeType = + contentToProcess.inlineData?.mimeType || + contentToProcess.fileData?.mimeType || + 'unknown'; + functionResponseJson = { + status: `Binary content of type ${mimeType} was processed.`, + }; + additionalParts.push(contentToProcess); + } else if (contentToProcess.text !== undefined) { + functionResponseJson = { output: contentToProcess.text }; + } else { + functionResponseJson = { status: 'Tool execution succeeded.' }; + additionalParts.push(contentToProcess); + } + + return { + functionResponseJson, + additionalParts, + }; +} + +const createErrorResponse = ( + request: ToolCallRequestInfo, + error: Error, +): ToolCallResponseInfo => ({ + callId: request.callId, + error, + responseParts: { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: error.message }, + }, + }, + resultDisplay: error.message, +}); + +interface CoreToolSchedulerOptions { + toolRegistry: ToolRegistry; + outputUpdateHandler?: OutputUpdateHandler; + onAllToolCallsComplete?: AllToolCallsCompleteHandler; + onToolCallsUpdate?: ToolCallsUpdateHandler; +} + +export class CoreToolScheduler { + private toolRegistry: ToolRegistry; + private toolCalls: ToolCall[] = []; + private abortController: AbortController; + private outputUpdateHandler?: OutputUpdateHandler; + private onAllToolCallsComplete?: AllToolCallsCompleteHandler; + private onToolCallsUpdate?: ToolCallsUpdateHandler; + + constructor(options: CoreToolSchedulerOptions) { + this.toolRegistry = options.toolRegistry; + this.outputUpdateHandler = options.outputUpdateHandler; + this.onAllToolCallsComplete = options.onAllToolCallsComplete; + this.onToolCallsUpdate = options.onToolCallsUpdate; + this.abortController = new AbortController(); + } + + private setStatusInternal( + targetCallId: string, + status: 'success', + response: ToolCallResponseInfo, + ): void; + private setStatusInternal( + targetCallId: string, + status: 'awaiting_approval', + confirmationDetails: ToolCallConfirmationDetails, + ): void; + private setStatusInternal( + targetCallId: string, + status: 'error', + response: ToolCallResponseInfo, + ): void; + private setStatusInternal( + targetCallId: string, + status: 'cancelled', + reason: string, + ): void; + private setStatusInternal( + targetCallId: string, + status: 'executing' | 'scheduled' | 'validating', + ): void; + private setStatusInternal( + targetCallId: string, + newStatus: Status, + auxiliaryData?: unknown, + ): void { + this.toolCalls = this.toolCalls.map((currentCall) => { + if ( + currentCall.request.callId !== targetCallId || + currentCall.status === 'error' + ) { + return currentCall; + } + + const callWithToolContext = currentCall as ToolCall & { tool: Tool }; + + switch (newStatus) { + case 'success': + return { + ...callWithToolContext, + status: 'success', + response: auxiliaryData as ToolCallResponseInfo, + } as SuccessfulToolCall; + case 'error': + return { + request: currentCall.request, + status: 'error', + response: auxiliaryData as ToolCallResponseInfo, + } as ErroredToolCall; + case 'awaiting_approval': + return { + ...callWithToolContext, + status: 'awaiting_approval', + confirmationDetails: auxiliaryData as ToolCallConfirmationDetails, + } as WaitingToolCall; + case 'scheduled': + return { + ...callWithToolContext, + status: 'scheduled', + } as ScheduledToolCall; + case 'cancelled': + return { + ...callWithToolContext, + status: 'cancelled', + response: { + callId: currentCall.request.callId, + responseParts: { + functionResponse: { + id: currentCall.request.callId, + name: currentCall.request.name, + response: { + error: `[Operation Cancelled] Reason: ${auxiliaryData}`, + }, + }, + }, + resultDisplay: undefined, + error: undefined, + }, + } as CancelledToolCall; + case 'validating': + return { + ...(currentCall as ValidatingToolCall), + status: 'validating', + } as ValidatingToolCall; + case 'executing': + return { + ...callWithToolContext, + status: 'executing', + } as ExecutingToolCall; + default: { + const exhaustiveCheck: never = newStatus; + return exhaustiveCheck; + } + } + }); + this.notifyToolCallsUpdate(); + this.checkAndNotifyCompletion(); + } + + private isRunning(): boolean { + return this.toolCalls.some( + (call) => + call.status === 'executing' || call.status === 'awaiting_approval', + ); + } + + async schedule( + request: ToolCallRequestInfo | ToolCallRequestInfo[], + ): Promise { + if (this.isRunning()) { + throw new Error( + 'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).', + ); + } + const requestsToProcess = Array.isArray(request) ? request : [request]; + + const newToolCalls: ToolCall[] = requestsToProcess.map( + (reqInfo): ToolCall => { + const toolInstance = this.toolRegistry.getTool(reqInfo.name); + if (!toolInstance) { + return { + status: 'error', + request: reqInfo, + response: createErrorResponse( + reqInfo, + new Error(`Tool "${reqInfo.name}" not found in registry.`), + ), + }; + } + return { status: 'validating', request: reqInfo, tool: toolInstance }; + }, + ); + + this.toolCalls = this.toolCalls.concat(newToolCalls); + this.notifyToolCallsUpdate(); + + for (const toolCall of newToolCalls) { + if (toolCall.status !== 'validating') { + continue; + } + + const { request: reqInfo, tool: toolInstance } = toolCall; + try { + const confirmationDetails = await toolInstance.shouldConfirmExecute( + reqInfo.args, + this.abortController.signal, + ); + + if (confirmationDetails) { + const originalOnConfirm = confirmationDetails.onConfirm; + const wrappedConfirmationDetails: ToolCallConfirmationDetails = { + ...confirmationDetails, + onConfirm: (outcome: ToolConfirmationOutcome) => + this.handleConfirmationResponse( + reqInfo.callId, + originalOnConfirm, + outcome, + ), + }; + this.setStatusInternal( + reqInfo.callId, + 'awaiting_approval', + wrappedConfirmationDetails, + ); + } else { + this.setStatusInternal(reqInfo.callId, 'scheduled'); + } + } catch (error) { + this.setStatusInternal( + reqInfo.callId, + 'error', + createErrorResponse( + reqInfo, + error instanceof Error ? error : new Error(String(error)), + ), + ); + } + } + this.attemptExecutionOfScheduledCalls(); + this.checkAndNotifyCompletion(); + } + + async handleConfirmationResponse( + callId: string, + originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise, + outcome: ToolConfirmationOutcome, + ): Promise { + const toolCall = this.toolCalls.find( + (c) => c.request.callId === callId && c.status === 'awaiting_approval', + ); + + if (toolCall && toolCall.status === 'awaiting_approval') { + await originalOnConfirm(outcome); + } + + if (outcome === ToolConfirmationOutcome.Cancel) { + this.setStatusInternal( + callId, + 'cancelled', + 'User did not allow tool call', + ); + } else { + this.setStatusInternal(callId, 'scheduled'); + } + this.attemptExecutionOfScheduledCalls(); + } + + private attemptExecutionOfScheduledCalls(): void { + const allCallsFinalOrScheduled = this.toolCalls.every( + (call) => + call.status === 'scheduled' || + call.status === 'cancelled' || + call.status === 'success' || + call.status === 'error', + ); + + if (allCallsFinalOrScheduled) { + const callsToExecute = this.toolCalls.filter( + (call) => call.status === 'scheduled', + ); + + callsToExecute.forEach((toolCall) => { + if (toolCall.status !== 'scheduled') return; + + const scheduledCall = toolCall as ScheduledToolCall; + const { callId, name: toolName } = scheduledCall.request; + this.setStatusInternal(callId, 'executing'); + + const liveOutputCallback = + scheduledCall.tool.canUpdateOutput && this.outputUpdateHandler + ? (outputChunk: string) => { + if (this.outputUpdateHandler) { + this.outputUpdateHandler(callId, outputChunk); + } + this.toolCalls = this.toolCalls.map((tc) => + tc.request.callId === callId && tc.status === 'executing' + ? { ...(tc as ExecutingToolCall), liveOutput: outputChunk } + : tc, + ); + this.notifyToolCallsUpdate(); + } + : undefined; + + scheduledCall.tool + .execute( + scheduledCall.request.args, + this.abortController.signal, + liveOutputCallback, + ) + .then((toolResult: ToolResult) => { + if (this.abortController.signal.aborted) { + this.setStatusInternal( + callId, + 'cancelled', + this.abortController.signal.reason || 'Execution aborted.', + ); + return; + } + + const { functionResponseJson, additionalParts } = + formatLlmContentForFunctionResponse(toolResult.llmContent); + + const functionResponsePart: Part = { + functionResponse: { + name: toolName, + id: callId, + response: functionResponseJson, + }, + }; + + const successResponse: ToolCallResponseInfo = { + callId, + responseParts: [functionResponsePart, ...additionalParts], + resultDisplay: toolResult.returnDisplay, + error: undefined, + }; + this.setStatusInternal(callId, 'success', successResponse); + }) + .catch((executionError: Error) => { + this.setStatusInternal( + callId, + 'error', + createErrorResponse( + scheduledCall.request, + executionError instanceof Error + ? executionError + : new Error(String(executionError)), + ), + ); + }); + }); + } + } + + private checkAndNotifyCompletion(): void { + const allCallsAreTerminal = this.toolCalls.every( + (call) => + call.status === 'success' || + call.status === 'error' || + call.status === 'cancelled', + ); + + if (this.toolCalls.length > 0 && allCallsAreTerminal) { + const completedCalls = [...this.toolCalls] as CompletedToolCall[]; + this.toolCalls = []; + + if (this.onAllToolCallsComplete) { + this.onAllToolCallsComplete(completedCalls); + } + this.abortController = new AbortController(); + this.notifyToolCallsUpdate(); + } + } + + cancelAll(reason: string = 'User initiated cancellation.'): void { + if (!this.abortController.signal.aborted) { + this.abortController.abort(reason); + } + this.abortController = new AbortController(); + + const callsToCancel = [...this.toolCalls]; + callsToCancel.forEach((call) => { + if ( + call.status !== 'error' && + call.status !== 'success' && + call.status !== 'cancelled' + ) { + this.setStatusInternal(call.request.callId, 'cancelled', reason); + } + }); + } + + private notifyToolCallsUpdate(): void { + if (this.onToolCallsUpdate) { + this.onToolCallsUpdate([...this.toolCalls]); + } + } +} diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 70426d57..f8c42336 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -13,8 +13,7 @@ export * from './core/logger.js'; export * from './core/prompts.js'; export * from './core/turn.js'; export * from './core/geminiRequest.js'; -// Potentially export types from turn.ts if needed externally -// export { GeminiEventType } from './core/turn.js'; // Example +export * from './core/coreToolScheduler.js'; // Export utilities export * from './utils/paths.js'; diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index a2e7fa06..1b932229 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -218,7 +218,7 @@ export interface ToolMcpConfirmationDetails { serverName: string; toolName: string; toolDisplayName: string; - onConfirm: (outcome: ToolConfirmationOutcome) => Promise | void; + onConfirm: (outcome: ToolConfirmationOutcome) => Promise; } export type ToolCallConfirmationDetails =