diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 2b18f0a1..035f3e85 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -9,6 +9,12 @@ import { useInput } from 'ink'; import { GeminiClient, GeminiEventType as ServerGeminiEventType, + ServerGeminiStreamEvent as GeminiEvent, + ServerGeminiContentEvent as ContentEvent, + ServerGeminiToolCallRequestEvent as ToolCallRequestEvent, + ServerGeminiToolCallResponseEvent as ToolCallResponseEvent, + ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent, + ServerGeminiErrorEvent as ErrorEvent, getErrorMessage, isNodeError, Config, @@ -26,6 +32,7 @@ import { IndividualToolCallDisplay, ToolCallStatus, HistoryItemWithoutId, + HistoryItemToolGroup, MessageType, } from '../types.js'; import { isAtCommand } from '../utils/commandUtils.js'; @@ -35,6 +42,17 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { useStateAndRef } from './useStateAndRef.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; +enum StreamProcessingStatus { + Completed, + PausedForConfirmation, + UserCancelled, + Error, +} + +/** + * Hook to manage the Gemini stream, handle user input, process commands, + * and interact with the Gemini API and history manager. + */ export const useGeminiStream = ( addItem: UseHistoryManagerReturn['addItem'], _clearItems: UseHistoryManagerReturn['clearItems'], @@ -82,240 +100,490 @@ export const useGeminiStream = ( } }); + const prepareQueryForGemini = async ( + query: PartListUnion, + userMessageTimestamp: number, + signal: AbortSignal, + ): Promise<{ queryToSend: PartListUnion | null; shouldProceed: boolean }> => { + if (typeof query === 'string' && query.trim().length === 0) { + return { queryToSend: null, shouldProceed: false }; + } + + let localQueryToSendToGemini: PartListUnion | null = null; + + if (typeof query === 'string') { + const trimmedQuery = query.trim(); + onDebugMessage(`User query: '${trimmedQuery}'`); + + // Handle UI-only commands first + if (handleSlashCommand(trimmedQuery)) { + return { queryToSend: null, shouldProceed: false }; + } + if (handleShellCommand(trimmedQuery)) { + return { queryToSend: null, shouldProceed: false }; + } + + // Handle @-commands (which might involve tool calls) + if (isAtCommand(trimmedQuery)) { + const atCommandResult = await handleAtCommand({ + query: trimmedQuery, + config, + addItem, + onDebugMessage, + messageId: userMessageTimestamp, + signal, + }); + if (!atCommandResult.shouldProceed) { + return { queryToSend: null, shouldProceed: false }; + } + localQueryToSendToGemini = atCommandResult.processedQuery; + } else { + // Normal query for Gemini + addItem( + { type: MessageType.USER, text: trimmedQuery }, + userMessageTimestamp, + ); + localQueryToSendToGemini = trimmedQuery; + } + } else { + // It's a function response (PartListUnion that isn't a string) + localQueryToSendToGemini = query; + } + + if (localQueryToSendToGemini === null) { + onDebugMessage( + 'Query processing resulted in null, not sending to Gemini.', + ); + return { queryToSend: null, shouldProceed: false }; + } + return { queryToSend: localQueryToSendToGemini, shouldProceed: true }; + }; + + const ensureChatSession = async (): Promise<{ + client: GeminiClient | null; + chat: Chat | null; + }> => { + const currentClient = geminiClientRef.current; + if (!currentClient) { + const errorMsg = 'Gemini client is not available.'; + setInitError(errorMsg); + addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); + return { client: null, chat: null }; + } + + if (!chatSessionRef.current) { + try { + chatSessionRef.current = await currentClient.startChat(); + } catch (err: unknown) { + const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`; + setInitError(errorMsg); + addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); + setStreamingState(StreamingState.Idle); + return { client: currentClient, chat: null }; + } + } + return { client: currentClient, chat: chatSessionRef.current }; + }; + + // --- UI Helper Functions (used by event handlers) --- + const updateFunctionResponseUI = ( + toolResponse: ToolCallResponseInfo, + status: ToolCallStatus, + ) => { + setPendingHistoryItem((item) => + item?.type === 'tool_group' + ? { + ...item, + tools: item.tools.map((tool) => { + if (tool.callId === toolResponse.callId) { + return { + ...tool, + status, + resultDisplay: toolResponse.resultDisplay, + }; + } else { + return tool; + } + }), + } + : null, + ); + }; + + const updateConfirmingFunctionStatusUI = ( + callId: string, + confirmationDetails: ToolCallConfirmationDetails | undefined, + ) => { + if (pendingHistoryItemRef.current?.type !== 'tool_group') return; + setPendingHistoryItem((item) => + item?.type === 'tool_group' + ? { + ...item, + tools: item.tools.map((tool) => + tool.callId === callId + ? { + ...tool, + status: ToolCallStatus.Confirming, + confirmationDetails, + } + : tool, + ), + } + : null, + ); + }; + + // This function will be fully refactored in a later step + const wireConfirmationSubmission = ( + confirmationDetails: ServerToolCallConfirmationDetails, + ): ToolCallConfirmationDetails => { + const originalConfirmationDetails = confirmationDetails.details; + const request = confirmationDetails.request; + const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => { + originalConfirmationDetails.onConfirm(outcome); + if (pendingHistoryItemRef?.current?.type === 'tool_group') { + setPendingHistoryItem((item) => + item?.type === 'tool_group' + ? { + ...item, + tools: item.tools.map((tool) => + tool.callId === request.callId + ? { + ...tool, + confirmationDetails: undefined, + status: ToolCallStatus.Executing, + } + : tool, + ), + } + : item, + ); + refreshStatic(); + } + + if (outcome === ToolConfirmationOutcome.Cancel) { + declineToolExecution( + 'User rejected function call.', + ToolCallStatus.Error, + request, + originalConfirmationDetails, + ); + } else { + const tool = toolRegistry.getTool(request.name); + if (!tool) { + throw new Error( + `Tool "${request.name}" not found or is not registered.`, + ); + } + try { + abortControllerRef.current = new AbortController(); + const result = await tool.execute( + request.args, + abortControllerRef.current.signal, + ); + if (abortControllerRef.current.signal.aborted) { + declineToolExecution( + result.llmContent, + ToolCallStatus.Canceled, + request, + originalConfirmationDetails, + ); + return; + } + const functionResponse: Part = { + functionResponse: { + name: request.name, + id: request.callId, + response: { output: result.llmContent }, + }, + }; + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay: result.returnDisplay, + error: undefined, + }; + updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, Date.now()); + setPendingHistoryItem(null); + } + setStreamingState(StreamingState.Idle); + await submitQuery(functionResponse); // Recursive call + } finally { + if (streamingState !== StreamingState.WaitingForConfirmation) { + abortControllerRef.current = null; + } + } + } + }; + + // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure + // or could be a standalone helper if more params are passed. + function declineToolExecution( + declineMessage: string, + status: ToolCallStatus, + request: ServerToolCallConfirmationDetails['request'], + originalDetails: ServerToolCallConfirmationDetails['details'], + ) { + let resultDisplay: ToolResultDisplay | undefined; + if ('fileDiff' in originalDetails) { + resultDisplay = { + fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff, + }; + } else { + resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`; + } + const functionResponse: Part = { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: declineMessage }, + }, + }; + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay, + error: new Error(declineMessage), + }; + const history = chatSessionRef.current?.getHistory(); + if (history) { + history.push({ role: 'model', parts: [functionResponse] }); + } + updateFunctionResponseUI(responseInfo, status); + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, Date.now()); + setPendingHistoryItem(null); + } + setStreamingState(StreamingState.Idle); + } + + return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm }; + }; + + // --- Stream Event Handlers --- + const handleContentEvent = ( + eventValue: ContentEvent['value'], + currentGeminiMessageBuffer: string, + userMessageTimestamp: number, + ): string => { + let newGeminiMessageBuffer = currentGeminiMessageBuffer + eventValue; + if ( + pendingHistoryItemRef.current?.type !== 'gemini' && + pendingHistoryItemRef.current?.type !== 'gemini_content' + ) { + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } + setPendingHistoryItem({ type: 'gemini', text: '' }); + newGeminiMessageBuffer = eventValue; + } + // Split large messages for better rendering performance. Ideally, + // we should maximize the amount of output sent to . + const splitPoint = findLastSafeSplitPoint(newGeminiMessageBuffer); + if (splitPoint === newGeminiMessageBuffer.length) { + // Update the existing message with accumulated content + setPendingHistoryItem((item) => ({ + type: item?.type as 'gemini' | 'gemini_content', + text: newGeminiMessageBuffer, + })); + } else { + // This indicates that we need to split up this Gemini Message. + // Splitting a message is primarily a performance consideration. There is a + // component at the root of App.tsx which takes care of rendering + // content statically or dynamically. Everything but the last message is + // treated as static in order to prevent re-rendering an entire message history + // multiple times per-second (as streaming occurs). Prior to this change you'd + // see heavy flickering of the terminal. This ensures that larger messages get + // broken up so that there are more "statically" rendered. + const beforeText = newGeminiMessageBuffer.substring(0, splitPoint); + const afterText = newGeminiMessageBuffer.substring(splitPoint); + addItem( + { + type: pendingHistoryItemRef.current?.type as + | 'gemini' + | 'gemini_content', + text: beforeText, + }, + userMessageTimestamp, + ); + setPendingHistoryItem({ type: 'gemini_content', text: afterText }); + newGeminiMessageBuffer = afterText; + } + return newGeminiMessageBuffer; + }; + + const handleToolCallRequestEvent = ( + eventValue: ToolCallRequestEvent['value'], + userMessageTimestamp: number, + ) => { + const { callId, name, args } = eventValue; + const cliTool = toolRegistry.getTool(name); + if (!cliTool) { + console.error(`CLI Tool "${name}" not found!`); + return; // Skip this event if tool is not found + } + if (pendingHistoryItemRef.current?.type !== 'tool_group') { + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } + setPendingHistoryItem({ type: 'tool_group', tools: [] }); + } + let description: string; + try { + description = cliTool.getDescription(args); + } catch (e) { + description = `Error: Unable to get description: ${getErrorMessage(e)}`; + } + const toolCallDisplay: IndividualToolCallDisplay = { + callId, + name: cliTool.displayName, + description, + status: ToolCallStatus.Pending, + resultDisplay: undefined, + confirmationDetails: undefined, + }; + setPendingHistoryItem((pending) => + pending?.type === 'tool_group' + ? { ...pending, tools: [...pending.tools, toolCallDisplay] } + : null, + ); + }; + + const handleToolCallResponseEvent = ( + eventValue: ToolCallResponseEvent['value'], + ) => { + const status = eventValue.error + ? ToolCallStatus.Error + : ToolCallStatus.Success; + updateFunctionResponseUI(eventValue, status); + }; + + const handleToolCallConfirmationEvent = ( + eventValue: ToolCallConfirmationEvent['value'], + ) => { + const confirmationDetails = wireConfirmationSubmission(eventValue); + updateConfirmingFunctionStatusUI( + eventValue.request.callId, + confirmationDetails, + ); + setStreamingState(StreamingState.WaitingForConfirmation); + }; + + const handleUserCancelledEvent = (userMessageTimestamp: number) => { + if (pendingHistoryItemRef.current) { + if (pendingHistoryItemRef.current.type === 'tool_group') { + const updatedTools = pendingHistoryItemRef.current.tools.map((tool) => + tool.status === ToolCallStatus.Pending || + tool.status === ToolCallStatus.Confirming || + tool.status === ToolCallStatus.Executing + ? { ...tool, status: ToolCallStatus.Canceled } + : tool, + ); + const pendingItem: HistoryItemToolGroup = { + ...pendingHistoryItemRef.current, + tools: updatedTools, + }; + addItem(pendingItem, userMessageTimestamp); + } else { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } + setPendingHistoryItem(null); + } + addItem( + { type: MessageType.INFO, text: 'User cancelled the request.' }, + userMessageTimestamp, + ); + setStreamingState(StreamingState.Idle); + }; + + const handleErrorEvent = ( + eventValue: ErrorEvent['value'], + userMessageTimestamp: number, + ) => { + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + setPendingHistoryItem(null); + } + addItem( + { type: MessageType.ERROR, text: `[API Error: ${eventValue.message}]` }, + userMessageTimestamp, + ); + }; + + const processGeminiStreamEvents = async ( + stream: AsyncIterable, + userMessageTimestamp: number, + ): Promise => { + let geminiMessageBuffer = ''; + + for await (const event of stream) { + if (event.type === ServerGeminiEventType.Content) { + geminiMessageBuffer = handleContentEvent( + event.value, + geminiMessageBuffer, + userMessageTimestamp, + ); + } else if (event.type === ServerGeminiEventType.ToolCallRequest) { + handleToolCallRequestEvent(event.value, userMessageTimestamp); + } else if (event.type === ServerGeminiEventType.ToolCallResponse) { + handleToolCallResponseEvent(event.value); + } else if (event.type === ServerGeminiEventType.ToolCallConfirmation) { + handleToolCallConfirmationEvent(event.value); + return StreamProcessingStatus.PausedForConfirmation; // Explicit return as this pauses the stream + } else if (event.type === ServerGeminiEventType.UserCancelled) { + handleUserCancelledEvent(userMessageTimestamp); + return StreamProcessingStatus.UserCancelled; + } else if (event.type === ServerGeminiEventType.Error) { + handleErrorEvent(event.value, userMessageTimestamp); + return StreamProcessingStatus.Error; + } + } + return StreamProcessingStatus.Completed; + }; + const submitQuery = useCallback( async (query: PartListUnion) => { if (streamingState === StreamingState.Responding) return; - if (typeof query === 'string' && query.trim().length === 0) return; const userMessageTimestamp = Date.now(); - let queryToSendToGemini: PartListUnion | null = null; - setShowHelp(false); abortControllerRef.current ??= new AbortController(); const signal = abortControllerRef.current.signal; - if (typeof query === 'string') { - const trimmedQuery = query.trim(); - onDebugMessage(`User query: '${trimmedQuery}'`); + const { queryToSend, shouldProceed } = await prepareQueryForGemini( + query, + userMessageTimestamp, + signal, + ); - if (handleSlashCommand(trimmedQuery)) return; - if (handleShellCommand(trimmedQuery)) return; - - if (isAtCommand(trimmedQuery)) { - const atCommandResult = await handleAtCommand({ - query: trimmedQuery, - config, - addItem, - onDebugMessage, - messageId: userMessageTimestamp, - signal, - }); - if (!atCommandResult.shouldProceed) return; - queryToSendToGemini = atCommandResult.processedQuery; - } else { - addItem( - { type: MessageType.USER, text: trimmedQuery }, - userMessageTimestamp, - ); - queryToSendToGemini = trimmedQuery; - } - } else { - queryToSendToGemini = query; - } - - if (queryToSendToGemini === null) { - onDebugMessage( - 'Query processing resulted in null, not sending to Gemini.', - ); + if (!shouldProceed || queryToSend === null) { return; } - const client = geminiClientRef.current; - if (!client) { - const errorMsg = 'Gemini client is not available.'; - setInitError(errorMsg); - addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); - return; - } + const { client, chat } = await ensureChatSession(); - if (!chatSessionRef.current) { - try { - chatSessionRef.current = await client.startChat(); - } catch (err: unknown) { - const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`; - setInitError(errorMsg); - addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); - setStreamingState(StreamingState.Idle); - return; - } + if (!client || !chat) { + return; } setStreamingState(StreamingState.Responding); setInitError(null); - const chat = chatSessionRef.current; try { - const stream = client.sendMessageStream( - chat, - queryToSendToGemini, - signal, + const stream = client.sendMessageStream(chat, queryToSend, signal); + const processingStatus = await processGeminiStreamEvents( + stream, + userMessageTimestamp, ); - let geminiMessageBuffer = ''; - - for await (const event of stream) { - if (event.type === ServerGeminiEventType.Content) { - if ( - pendingHistoryItemRef.current?.type !== 'gemini' && - pendingHistoryItemRef.current?.type !== 'gemini_content' - ) { - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - } - setPendingHistoryItem({ - type: 'gemini', - text: '', - }); - geminiMessageBuffer = ''; - } - - geminiMessageBuffer += event.value; - - // Split large messages for better rendering performance. Ideally, - // we should maximize the amount of output sent to . - const splitPoint = findLastSafeSplitPoint(geminiMessageBuffer); - if (splitPoint === geminiMessageBuffer.length) { - // Update the existing message with accumulated content - setPendingHistoryItem((item) => ({ - type: item?.type as 'gemini' | 'gemini_content', - text: geminiMessageBuffer, - })); - } else { - // This indicates that we need to split up this Gemini Message. - // Splitting a message is primarily a performance consideration. There is a - // component at the root of App.tsx which takes care of rendering - // content statically or dynamically. Everything but the last message is - // treated as static in order to prevent re-rendering an entire message history - // multiple times per-second (as streaming occurs). Prior to this change you'd - // see heavy flickering of the terminal. This ensures that larger messages get - // broken up so that there are more "statically" rendered. - const beforeText = geminiMessageBuffer.substring(0, splitPoint); - const afterText = geminiMessageBuffer.substring(splitPoint); - geminiMessageBuffer = afterText; - addItem( - { - type: pendingHistoryItemRef.current?.type as - | 'gemini' - | 'gemini_content', - text: beforeText, - }, - userMessageTimestamp, - ); - setPendingHistoryItem({ - type: 'gemini_content', - text: afterText, - }); - } - } else if (event.type === ServerGeminiEventType.ToolCallRequest) { - const { callId, name, args } = event.value; - const cliTool = toolRegistry.getTool(name); - if (!cliTool) { - console.error(`CLI Tool "${name}" not found!`); - continue; - } - - if (pendingHistoryItemRef.current?.type !== 'tool_group') { - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - } - setPendingHistoryItem({ - type: 'tool_group', - tools: [], - }); - } - - let description: string; - try { - description = cliTool.getDescription(args); - } catch (e) { - description = `Error: Unable to get description: ${getErrorMessage(e)}`; - } - - const toolCallDisplay: IndividualToolCallDisplay = { - callId, - name: cliTool.displayName, - description, - status: ToolCallStatus.Pending, - resultDisplay: undefined, - confirmationDetails: undefined, - }; - - setPendingHistoryItem((pending) => - pending?.type === 'tool_group' - ? { - ...pending, - tools: [...pending.tools, toolCallDisplay], - } - : null, - ); - } else if (event.type === ServerGeminiEventType.ToolCallResponse) { - const status = event.value.error - ? ToolCallStatus.Error - : ToolCallStatus.Success; - updateFunctionResponseUI(event.value, status); - } else if ( - event.type === ServerGeminiEventType.ToolCallConfirmation - ) { - const confirmationDetails = wireConfirmationSubmission(event.value); - updateConfirmingFunctionStatusUI( - event.value.request.callId, - confirmationDetails, - ); - setStreamingState(StreamingState.WaitingForConfirmation); - return; - } else if (event.type === ServerGeminiEventType.UserCancelled) { - if (pendingHistoryItemRef.current) { - if (pendingHistoryItemRef.current.type === 'tool_group') { - const updatedTools = pendingHistoryItemRef.current.tools.map( - (tool) => { - if ( - tool.status === ToolCallStatus.Pending || - tool.status === ToolCallStatus.Confirming || - tool.status === ToolCallStatus.Executing - ) { - return { ...tool, status: ToolCallStatus.Canceled }; - } - return tool; - }, - ); - const pendingHistoryItem = pendingHistoryItemRef.current; - pendingHistoryItem.tools = updatedTools; - addItem(pendingHistoryItem, userMessageTimestamp); - } else { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - } - setPendingHistoryItem(null); - } - addItem( - { type: MessageType.INFO, text: 'User cancelled the request.' }, - userMessageTimestamp, - ); - setStreamingState(StreamingState.Idle); - return; - } else if (event.type === ServerGeminiEventType.Error) { - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - setPendingHistoryItem(null); - } - addItem( - { - type: MessageType.ERROR, - text: `[API Error: ${event.value.message}]`, - }, - userMessageTimestamp, - ); - } + if ( + processingStatus === StreamProcessingStatus.PausedForConfirmation || + processingStatus === StreamProcessingStatus.UserCancelled + ) { + return; } if (pendingHistoryItemRef.current) { @@ -323,7 +591,12 @@ export const useGeminiStream = ( setPendingHistoryItem(null); } - setStreamingState(StreamingState.Idle); + if ( + processingStatus === StreamProcessingStatus.Completed || + processingStatus === StreamProcessingStatus.Error + ) { + setStreamingState(StreamingState.Idle); + } } catch (error: unknown) { if (!isNodeError(error) || error.name !== 'AbortError') { addItem( @@ -336,191 +609,12 @@ export const useGeminiStream = ( } setStreamingState(StreamingState.Idle); } finally { - abortControllerRef.current = null; - } - - function updateConfirmingFunctionStatusUI( - callId: string, - confirmationDetails: ToolCallConfirmationDetails | undefined, - ) { - if (pendingHistoryItemRef.current?.type !== 'tool_group') return; - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === callId - ? { - ...tool, - status: ToolCallStatus.Confirming, - confirmationDetails, - } - : tool, - ), - } - : null, - ); - } - - function updateFunctionResponseUI( - toolResponse: ToolCallResponseInfo, - status: ToolCallStatus, - ) { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => { - if (tool.callId === toolResponse.callId) { - return { - ...tool, - status, - resultDisplay: toolResponse.resultDisplay, - }; - } else { - return tool; - } - }), - } - : null, - ); - } - - function wireConfirmationSubmission( - confirmationDetails: ServerToolCallConfirmationDetails, - ): ToolCallConfirmationDetails { - const originalConfirmationDetails = confirmationDetails.details; - const request = confirmationDetails.request; - const resubmittingConfirm = async ( - outcome: ToolConfirmationOutcome, - ) => { - originalConfirmationDetails.onConfirm(outcome); - - if (pendingHistoryItemRef?.current?.type === 'tool_group') { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === request.callId - ? { - ...tool, - confirmationDetails: undefined, - status: ToolCallStatus.Executing, - } - : tool, - ), - } - : item, - ); - refreshStatic(); - } - - if (outcome === ToolConfirmationOutcome.Cancel) { - declineToolExecution( - 'User rejected function call.', - ToolCallStatus.Error, - ); - } else { - const tool = toolRegistry.getTool(request.name); - if (!tool) { - throw new Error( - `Tool "${request.name}" not found or is not registered.`, - ); - } - - try { - abortControllerRef.current = new AbortController(); - const result = await tool.execute( - request.args, - abortControllerRef.current.signal, - ); - - if (abortControllerRef.current.signal.aborted) { - declineToolExecution( - result.llmContent, - ToolCallStatus.Canceled, - ); - return; - } - - const functionResponse: Part = { - functionResponse: { - name: request.name, - id: request.callId, - response: { output: result.llmContent }, - }, - }; - - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setStreamingState(StreamingState.Idle); - await submitQuery(functionResponse); - } finally { - abortControllerRef.current = null; - } - } - - function declineToolExecution( - declineMessage: string, - status: ToolCallStatus, - ) { - let resultDisplay: ToolResultDisplay | undefined; - if ('fileDiff' in originalConfirmationDetails) { - resultDisplay = { - fileDiff: ( - originalConfirmationDetails as ToolEditConfirmationDetails - ).fileDiff, - }; - } else { - resultDisplay = `~~${(originalConfirmationDetails as ToolExecuteConfirmationDetails).command}~~`; - } - const functionResponse: Part = { - functionResponse: { - id: request.callId, - name: request.name, - response: { error: declineMessage }, - }, - }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay, - error: new Error(declineMessage), - }; - - const history = chatSessionRef.current?.getHistory(); - if (history) { - history.push({ - role: 'model', - parts: [functionResponse], - }); - } - - updateFunctionResponseUI(responseInfo, status); - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setStreamingState(StreamingState.Idle); - } - }; - - return { - ...originalConfirmationDetails, - onConfirm: resubmittingConfirm, - }; + if (streamingState !== StreamingState.WaitingForConfirmation) { + abortControllerRef.current = null; + } } }, + // eslint-disable-next-line react-hooks/exhaustive-deps [ streamingState, setShowHelp, @@ -528,11 +622,10 @@ export const useGeminiStream = ( handleShellCommand, config, addItem, - pendingHistoryItemRef, - setPendingHistoryItem, - toolRegistry, - refreshStatic, onDebugMessage, + refreshStatic, + setInitError, + setStreamingState, ], ); diff --git a/packages/cli/src/ui/types.ts b/packages/cli/src/ui/types.ts index a2102418..60508c05 100644 --- a/packages/cli/src/ui/types.ts +++ b/packages/cli/src/ui/types.ts @@ -55,18 +55,47 @@ export interface HistoryItemBase { text?: string; // Text content for user/gemini/info/error messages } +export type HistoryItemUser = HistoryItemBase & { + type: 'user'; + text: string; +}; + +export type HistoryItemGemini = HistoryItemBase & { + type: 'gemini'; + text: string; +}; + +export type HistoryItemGeminiContent = HistoryItemBase & { + type: 'gemini_content'; + text: string; +}; + +export type HistoryItemInfo = HistoryItemBase & { + type: 'info'; + text: string; +}; + +export type HistoryItemError = HistoryItemBase & { + type: 'error'; + text: string; +}; + +export type HistoryItemToolGroup = HistoryItemBase & { + type: 'tool_group'; + tools: IndividualToolCallDisplay[]; +}; + // Using Omit seems to have some issues with typescript's // type inference e.g. historyItem.type === 'tool_group' isn't auto-inferring that // 'tools' in historyItem. -export type HistoryItemWithoutId = HistoryItemBase & - ( - | { type: 'user'; text: string } - | { type: 'gemini'; text: string } - | { type: 'gemini_content'; text: string } - | { type: 'info'; text: string } - | { type: 'error'; text: string } - | { type: 'tool_group'; tools: IndividualToolCallDisplay[] } - ); +// Individually exported types extending HistoryItemBase +export type HistoryItemWithoutId = + | HistoryItemUser + | HistoryItemGemini + | HistoryItemGeminiContent + | HistoryItemInfo + | HistoryItemError + | HistoryItemToolGroup; export type HistoryItem = HistoryItemWithoutId & { id: number }; diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index 8f473986..dd2a08ce 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -78,16 +78,43 @@ export interface ServerToolCallConfirmationDetails { details: ToolCallConfirmationDetails; } +export type ServerGeminiContentEvent = { + type: GeminiEventType.Content; + value: string; +}; + +export type ServerGeminiToolCallRequestEvent = { + type: GeminiEventType.ToolCallRequest; + value: ToolCallRequestInfo; +}; + +export type ServerGeminiToolCallResponseEvent = { + type: GeminiEventType.ToolCallResponse; + value: ToolCallResponseInfo; +}; + +export type ServerGeminiToolCallConfirmationEvent = { + type: GeminiEventType.ToolCallConfirmation; + value: ServerToolCallConfirmationDetails; +}; + +export type ServerGeminiUserCancelledEvent = { + type: GeminiEventType.UserCancelled; +}; + +export type ServerGeminiErrorEvent = { + type: GeminiEventType.Error; + value: GeminiErrorEventValue; +}; + +// The original union type, now composed of the individual types export type ServerGeminiStreamEvent = - | { type: GeminiEventType.Content; value: string } - | { type: GeminiEventType.ToolCallRequest; value: ToolCallRequestInfo } - | { type: GeminiEventType.ToolCallResponse; value: ToolCallResponseInfo } - | { - type: GeminiEventType.ToolCallConfirmation; - value: ServerToolCallConfirmationDetails; - } - | { type: GeminiEventType.UserCancelled } - | { type: GeminiEventType.Error; value: GeminiErrorEventValue }; + | ServerGeminiContentEvent + | ServerGeminiToolCallRequestEvent + | ServerGeminiToolCallResponseEvent + | ServerGeminiToolCallConfirmationEvent + | ServerGeminiUserCancelledEvent + | ServerGeminiErrorEvent; // A turn manages the agentic loop turn within the server context. export class Turn {