diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index 42613530..74c1ea5d 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -134,7 +134,7 @@ export const App = ({ cliVersion, ); - const { streamingState, submitQuery, initError, pendingHistoryItem } = + const { streamingState, submitQuery, initError, pendingHistoryItems } = useGeminiStream( addItem, refreshStatic, @@ -209,7 +209,7 @@ export const App = ({ }, [terminalHeight, footerHeight]); useEffect(() => { - if (!pendingHistoryItem) { + if (!pendingHistoryItems.length) { return; } @@ -223,7 +223,7 @@ export const App = ({ if (pendingItemDimensions.height > availableTerminalHeight) { setStaticNeedsRefresh(true); } - }, [pendingHistoryItem, availableTerminalHeight, streamingState]); + }, [pendingHistoryItems.length, availableTerminalHeight, streamingState]); useEffect(() => { if (streamingState === StreamingState.Idle && staticNeedsRefresh) { @@ -264,17 +264,18 @@ export const App = ({ > {(item) => item} - {pendingHistoryItem && ( - + + {pendingHistoryItems.map((item, i) => ( - - )} + ))} + {showHelp && } diff --git a/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx b/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx index d0ad1c5f..4b2c7dfe 100644 --- a/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React from 'react'; +import React, { useMemo } from 'react'; import { Box } from 'ink'; import { IndividualToolCallDisplay, ToolCallStatus } from '../../types.js'; import { ToolMessage } from './ToolMessage.js'; @@ -19,7 +19,6 @@ interface ToolGroupMessageProps { // Main component renders the border and maps the tools using ToolMessage export const ToolGroupMessage: React.FC = ({ - groupId, toolCalls, availableTerminalHeight, }) => { @@ -30,9 +29,13 @@ export const ToolGroupMessage: React.FC = ({ const staticHeight = /* border */ 2 + /* marginBottom */ 1; + const toolAwaitingApproval = useMemo( + () => toolCalls.find((tc) => tc.status === ToolCallStatus.Confirming), + [toolCalls], + ); + return ( = ({ marginBottom={1} > {toolCalls.map((tool) => ( - + = ({ availableTerminalHeight={availableTerminalHeight - staticHeight} /> {tool.status === ToolCallStatus.Confirming && + tool.callId === toolAwaitingApproval?.callId && tool.confirmationDetails && ( boolean, shellModeActive: boolean, ) => { - const toolRegistry = config.getToolRegistry(); const [initError, setInitError] = useState(null); const abortControllerRef = useRef(null); const chatSessionRef = useRef(null); @@ -74,6 +68,25 @@ export const useGeminiStream = ( const [pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef(null); const logger = useLogger(); + const [toolCalls, schedule, cancel] = useToolScheduler((tools) => { + if (tools.length) { + addItem(mapToDisplay(tools), Date.now()); + submitQuery( + tools + .filter( + (t) => + t.status === 'error' || + t.status === 'cancelled' || + t.status === 'success', + ) + .map((t) => t.response.responsePart), + ); + } + }, config); + const pendingToolCalls = useMemo( + () => (toolCalls.length ? mapToDisplay(toolCalls) : undefined), + [toolCalls], + ); const onExec = useCallback(async (done: Promise) => { setIsResponding(true); @@ -104,6 +117,7 @@ export const useGeminiStream = ( useInput((_input, key) => { if (streamingState !== StreamingState.Idle && key.escape) { abortControllerRef.current?.abort(); + cancel(); } }); @@ -215,157 +229,48 @@ export const useGeminiStream = ( ); }; - const updateConfirmingFunctionStatusUI = ( - callId: string, - confirmationDetails: ToolCallConfirmationDetails | undefined, - ) => { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === callId - ? { - ...tool, - status: ToolCallStatus.Confirming, - confirmationDetails, - } - : tool, - ), - } - : item, - ); - }; - - const wireConfirmationSubmission = ( - confirmationDetails: ServerToolCallConfirmationDetails, - ): ToolCallConfirmationDetails => { - const originalConfirmationDetails = confirmationDetails.details; - const request = confirmationDetails.request; - const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => { - originalConfirmationDetails.onConfirm(outcome); - if (pendingHistoryItemRef?.current?.type === 'tool_group') { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === request.callId - ? { - ...tool, - confirmationDetails: undefined, - status: ToolCallStatus.Executing, - } - : tool, - ), - } - : item, - ); - refreshStatic(); - } - - if (outcome === ToolConfirmationOutcome.Cancel) { - declineToolExecution( - 'User rejected function call.', - ToolCallStatus.Error, - request, - originalConfirmationDetails, - ); - } else { - const tool = toolRegistry.getTool(request.name); - if (!tool) { - throw new Error( - `Tool "${request.name}" not found or is not registered.`, - ); - } - try { - abortControllerRef.current = new AbortController(); - const result = await tool.execute( - request.args, - abortControllerRef.current.signal, - ); - if (abortControllerRef.current.signal.aborted) { - declineToolExecution( - partListUnionToString(result.llmContent), - ToolCallStatus.Canceled, - request, - originalConfirmationDetails, - ); - return; - } - - const functionResponse: Part = { - functionResponse: { - name: request.name, - id: request.callId, - response: { output: result.llmContent }, - }, - }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setIsResponding(false); - await submitQuery(functionResponse); // Recursive call - } finally { - if (streamingState !== StreamingState.WaitingForConfirmation) { - abortControllerRef.current = null; - } - } - } - }; - - // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure - // or could be a standalone helper if more params are passed. - function declineToolExecution( - declineMessage: string, - status: ToolCallStatus, - request: ServerToolCallConfirmationDetails['request'], - originalDetails: ServerToolCallConfirmationDetails['details'], - ) { - let resultDisplay: ToolResultDisplay | undefined; - if ('fileDiff' in originalDetails) { - resultDisplay = { - fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff, - fileName: (originalDetails as ToolEditConfirmationDetails).fileName, - }; - } else { - resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`; - } - const functionResponse: Part = { - functionResponse: { - id: request.callId, - name: request.name, - response: { error: declineMessage }, - }, + // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure + // or could be a standalone helper if more params are passed. + // TODO: handle file diff result display stuff + function _declineToolExecution( + declineMessage: string, + status: ToolCallStatus, + request: ServerToolCallConfirmationDetails['request'], + originalDetails: ServerToolCallConfirmationDetails['details'], + ) { + let resultDisplay: ToolResultDisplay | undefined; + if ('fileDiff' in originalDetails) { + resultDisplay = { + fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff, + fileName: (originalDetails as ToolEditConfirmationDetails).fileName, }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay, - error: new Error(declineMessage), - }; - const history = chatSessionRef.current?.getHistory(); - if (history) { - history.push({ role: 'model', parts: [functionResponse] }); - } - updateFunctionResponseUI(responseInfo, status); - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setIsResponding(false); + } else { + resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`; } - - return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm }; - }; + const functionResponse: Part = { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: declineMessage }, + }, + }; + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay, + error: new Error(declineMessage), + }; + const history = chatSessionRef.current?.getHistory(); + if (history) { + history.push({ role: 'model', parts: [functionResponse] }); + } + updateFunctionResponseUI(responseInfo, status); + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, Date.now()); + setPendingHistoryItem(null); + } + setIsResponding(false); + } // --- Stream Event Handlers --- const handleContentEvent = ( @@ -419,62 +324,6 @@ export const useGeminiStream = ( return newGeminiMessageBuffer; }; - const handleToolCallRequestEvent = ( - eventValue: ToolCallRequestEvent['value'], - userMessageTimestamp: number, - ) => { - const { callId, name, args } = eventValue; - const cliTool = toolRegistry.getTool(name); - if (!cliTool) { - console.error(`CLI Tool "${name}" not found!`); - return; // Skip this event if tool is not found - } - if (pendingHistoryItemRef.current?.type !== 'tool_group') { - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - } - setPendingHistoryItem({ type: 'tool_group', tools: [] }); - } - let description: string; - try { - description = cliTool.getDescription(args); - } catch (e) { - description = `Error: Unable to get description: ${getErrorMessage(e)}`; - } - const toolCallDisplay: IndividualToolCallDisplay = { - callId, - name: cliTool.displayName, - description, - status: ToolCallStatus.Pending, - resultDisplay: undefined, - confirmationDetails: undefined, - }; - setPendingHistoryItem((pending) => - pending?.type === 'tool_group' - ? { ...pending, tools: [...pending.tools, toolCallDisplay] } - : null, - ); - }; - - const handleToolCallResponseEvent = ( - eventValue: ToolCallResponseEvent['value'], - ) => { - const status = eventValue.error - ? ToolCallStatus.Error - : ToolCallStatus.Success; - updateFunctionResponseUI(eventValue, status); - }; - - const handleToolCallConfirmationEvent = ( - eventValue: ToolCallConfirmationEvent['value'], - ) => { - const confirmationDetails = wireConfirmationSubmission(eventValue); - updateConfirmingFunctionStatusUI( - eventValue.request.callId, - confirmationDetails, - ); - }; - const handleUserCancelledEvent = (userMessageTimestamp: number) => { if (pendingHistoryItemRef.current) { if (pendingHistoryItemRef.current.type === 'tool_group') { @@ -500,6 +349,7 @@ export const useGeminiStream = ( userMessageTimestamp, ); setIsResponding(false); + cancel(); }; const handleErrorEvent = ( @@ -521,7 +371,7 @@ export const useGeminiStream = ( userMessageTimestamp: number, ): Promise => { let geminiMessageBuffer = ''; - + const toolCallRequests: ToolCallRequestInfo[] = []; for await (const event of stream) { if (event.type === ServerGeminiEventType.Content) { geminiMessageBuffer = handleContentEvent( @@ -530,12 +380,7 @@ export const useGeminiStream = ( userMessageTimestamp, ); } else if (event.type === ServerGeminiEventType.ToolCallRequest) { - handleToolCallRequestEvent(event.value, userMessageTimestamp); - } else if (event.type === ServerGeminiEventType.ToolCallResponse) { - handleToolCallResponseEvent(event.value); - } else if (event.type === ServerGeminiEventType.ToolCallConfirmation) { - handleToolCallConfirmationEvent(event.value); - return StreamProcessingStatus.PausedForConfirmation; + toolCallRequests.push(event.value); } else if (event.type === ServerGeminiEventType.UserCancelled) { handleUserCancelledEvent(userMessageTimestamp); return StreamProcessingStatus.UserCancelled; @@ -544,9 +389,18 @@ export const useGeminiStream = ( return StreamProcessingStatus.Error; } } + schedule(toolCallRequests); return StreamProcessingStatus.Completed; }; + const streamingState: StreamingState = isResponding + ? StreamingState.Responding + : pendingToolCalls?.tools.some( + (t) => t.status === ToolCallStatus.Confirming, + ) + ? StreamingState.WaitingForConfirmation + : StreamingState.Idle; + const submitQuery = useCallback( async (query: PartListUnion) => { if (isResponding) return; @@ -625,20 +479,15 @@ export const useGeminiStream = ( ], ); - const streamingState: StreamingState = isResponding - ? StreamingState.Responding - : pendingConfirmations(pendingHistoryItemRef.current) - ? StreamingState.WaitingForConfirmation - : StreamingState.Idle; + const pendingHistoryItems = [ + pendingHistoryItemRef.current, + pendingToolCalls, + ].filter((i) => i !== undefined && i !== null); return { streamingState, submitQuery, initError, - pendingHistoryItem: pendingHistoryItemRef.current, + pendingHistoryItems, }; }; - -const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean => - item?.type === 'tool_group' && - item.tools.some((t) => t.status === ToolCallStatus.Confirming); diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts index 51895802..489e2a0b 100644 --- a/packages/server/src/core/client.ts +++ b/packages/server/src/core/client.ts @@ -155,10 +155,9 @@ export class GeminiClient { signal?: AbortSignal, ): AsyncGenerator { let turns = 0; - const availableTools = this.config.getToolRegistry().getAllTools(); while (turns < this.MAX_TURNS) { turns++; - const turn = new Turn(chat, availableTools); + const turn = new Turn(chat); const resultStream = turn.run(request, signal); let seenError = false; for await (const event of resultStream) { diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index dd2a08ce..7b2a96f9 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -21,18 +21,6 @@ import { getResponseText } from '../utils/generateContentResponseUtilities.js'; import { reportError } from '../utils/errorReporting.js'; import { getErrorMessage } from '../utils/errors.js'; -// --- Types for Server Logic --- - -// Define a simpler structure for Tool execution results within the server -interface ServerToolExecutionOutcome { - callId: string; - name: string; - args: Record; - result?: ToolResult; - error?: Error; - confirmationDetails: ToolCallConfirmationDetails | undefined; -} - // Define a structure for tools passed to the server export interface ServerTool { name: string; @@ -118,7 +106,6 @@ export type ServerGeminiStreamEvent = // A turn manages the agentic loop turn within the server context. export class Turn { - private readonly availableTools: Map; private pendingToolCalls: Array<{ callId: string; name: string; @@ -128,11 +115,7 @@ export class Turn { private confirmationDetails: ToolCallConfirmationDetails[]; private debugResponses: GenerateContentResponse[]; - constructor( - private readonly chat: Chat, - availableTools: ServerTool[], - ) { - this.availableTools = new Map(availableTools.map((t) => [t.name, t])); + constructor(private readonly chat: Chat) { this.pendingToolCalls = []; this.fnResponses = []; this.confirmationDetails = []; @@ -160,12 +143,9 @@ export class Turn { yield { type: GeminiEventType.Content, value: text }; } - if (!resp.functionCalls) { - continue; - } - // Handle function calls (requesting tool execution) - for (const fnCall of resp.functionCalls) { + const functionCalls = resp.functionCalls ?? []; + for (const fnCall of functionCalls) { const event = this.handlePendingFunctionCall(fnCall); if (event) { yield event; @@ -184,80 +164,6 @@ export class Turn { yield { type: GeminiEventType.Error, value: { message: errorMessage } }; return; } - - // Execute pending tool calls - const toolPromises = this.pendingToolCalls.map( - async (pendingToolCall): Promise => { - const tool = this.availableTools.get(pendingToolCall.name); - if (!tool) { - return { - ...pendingToolCall, - error: new Error( - `Tool "${pendingToolCall.name}" not found or not provided to Turn.`, - ), - confirmationDetails: undefined, - }; - } - - try { - const confirmationDetails = await tool.shouldConfirmExecute( - pendingToolCall.args, - ); - if (confirmationDetails) { - return { ...pendingToolCall, confirmationDetails }; - } - const result = await tool.execute(pendingToolCall.args, signal); - return { - ...pendingToolCall, - result, - confirmationDetails: undefined, - }; - } catch (execError: unknown) { - return { - ...pendingToolCall, - error: new Error( - `Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`, - ), - confirmationDetails: undefined, - }; - } - }, - ); - const outcomes = await Promise.all(toolPromises); - - // Process outcomes and prepare function responses - this.pendingToolCalls = []; // Clear pending calls for this turn - - for (const outcome of outcomes) { - if (outcome.confirmationDetails) { - this.confirmationDetails.push(outcome.confirmationDetails); - const serverConfirmationetails: ServerToolCallConfirmationDetails = { - request: { - callId: outcome.callId, - name: outcome.name, - args: outcome.args, - }, - details: outcome.confirmationDetails, - }; - yield { - type: GeminiEventType.ToolCallConfirmation, - value: serverConfirmationetails, - }; - } - const responsePart = this.buildFunctionResponse(outcome); - this.fnResponses.push(responsePart); - const responseInfo: ToolCallResponseInfo = { - callId: outcome.callId, - responsePart, - resultDisplay: outcome.result?.returnDisplay, - error: outcome.error, - }; - - // If aborted we're already yielding the user cancellations elsewhere. - if (!signal?.aborted) { - yield { type: GeminiEventType.ToolCallResponse, value: responseInfo }; - } - } } private handlePendingFunctionCall( @@ -276,30 +182,6 @@ export class Turn { return { type: GeminiEventType.ToolCallRequest, value }; } - // Builds the Part array expected by the Google GenAI API - private buildFunctionResponse(outcome: ServerToolExecutionOutcome): Part { - const { name, result, error } = outcome; - if (error) { - // Format error for the LLM - const errorMessage = error?.message || String(error); - console.error(`[Server Turn] Error executing tool ${name}:`, error); - return { - functionResponse: { - name, - id: outcome.callId, - response: { error: `Tool execution failed: ${errorMessage}` }, - }, - }; - } - return { - functionResponse: { - name, - id: outcome.callId, - response: { output: result?.llmContent ?? '' }, - }, - }; - } - getConfirmationDetails(): ToolCallConfirmationDetails[] { return this.confirmationDetails; } diff --git a/packages/server/src/tools/tools.ts b/packages/server/src/tools/tools.ts index 329010bc..58209166 100644 --- a/packages/server/src/tools/tools.ts +++ b/packages/server/src/tools/tools.ts @@ -171,23 +171,28 @@ export interface FileDiff { fileName: string; } -export interface ToolCallConfirmationDetails { +export interface ToolCallConfirmationDetailsDefault { title: string; onConfirm: (outcome: ToolConfirmationOutcome) => Promise; } export interface ToolEditConfirmationDetails - extends ToolCallConfirmationDetails { + extends ToolCallConfirmationDetailsDefault { fileName: string; fileDiff: string; } export interface ToolExecuteConfirmationDetails - extends ToolCallConfirmationDetails { + extends ToolCallConfirmationDetailsDefault { command: string; rootCommand: string; } +export type ToolCallConfirmationDetails = + | ToolCallConfirmationDetailsDefault + | ToolEditConfirmationDetails + | ToolExecuteConfirmationDetails; + export enum ToolConfirmationOutcome { ProceedOnce, ProceedAlways,