From 738c2692fb9bdeda7c801e7b4e773f81ebc1ded0 Mon Sep 17 00:00:00 2001 From: Taylor Mullen Date: Mon, 21 Apr 2025 14:32:18 -0400 Subject: [PATCH] Fix confirmations. - This fixes what it means to get confirmations in GC. Prior to this they had just been accidentally unwired as part of all of the refactorings to turns + to server/core. - The key piece of this is that we wrap the onConfirm in the gemini stream hook in order to resubmit function responses. This isn't 100% ideal but gets the job done for now. - Fixed history not updating properly with confirmations. Fixes https://b.corp.google.com/issues/412323656 --- packages/cli/src/ui/App.tsx | 70 ++++---- .../ui/components/messages/ToolMessage.tsx | 34 ---- packages/cli/src/ui/hooks/useGeminiStream.ts | 161 ++++++++++++++---- packages/server/src/core/turn.ts | 141 ++++++++------- 4 files changed, 226 insertions(+), 180 deletions(-) diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index 42b0ae5b..724036e2 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -52,15 +52,7 @@ export const App = ({ config }: AppProps) => { [history], ); - const isWaitingForToolConfirmation = history.some( - (item) => - item.type === 'tool_group' && - item.tools.some((tool) => tool.confirmationDetails !== undefined), - ); - const isInputActive = - streamingState === StreamingState.Idle && - !initError && - !isWaitingForToolConfirmation; + const isInputActive = streamingState === StreamingState.Idle && !initError; const { query, handleSubmit: handleHistorySubmit } = useInputHistory({ userMessages, @@ -88,39 +80,37 @@ export const App = ({ config }: AppProps) => { )} - {initError && - streamingState !== StreamingState.Responding && - !isWaitingForToolConfirmation && ( - - {history.find( - (item) => item.type === 'error' && item.text?.includes(initError), - )?.text ? ( + {initError && streamingState !== StreamingState.Responding && ( + + {history.find( + (item) => item.type === 'error' && item.text?.includes(initError), + )?.text ? ( + + { + history.find( + (item) => + item.type === 'error' && item.text?.includes(initError), + )?.text + } + + ) : ( + <> - { - history.find( - (item) => - item.type === 'error' && item.text?.includes(initError), - )?.text - } + Initialization Error: {initError} - ) : ( - <> - - Initialization Error: {initError} - - - {' '} - Please check API key and configuration. - - - )} - - )} + + {' '} + Please check API key and configuration. + + + )} + + )} diff --git a/packages/cli/src/ui/components/messages/ToolMessage.tsx b/packages/cli/src/ui/components/messages/ToolMessage.tsx index f33ed6cb..f21e1d28 100644 --- a/packages/cli/src/ui/components/messages/ToolMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolMessage.tsx @@ -11,11 +11,6 @@ import { IndividualToolCallDisplay, ToolCallStatus } from '../../types.js'; import { DiffRenderer } from './DiffRenderer.js'; import { FileDiff, ToolResultDisplay } from '../../../tools/tools.js'; import { Colors } from '../../colors.js'; -import { - ToolCallConfirmationDetails, - ToolEditConfirmationDetails, - ToolExecuteConfirmationDetails, -} from '@gemini-code/server'; export const ToolMessage: React.FC = ({ callId, @@ -23,12 +18,7 @@ export const ToolMessage: React.FC = ({ description, resultDisplay, status, - confirmationDetails, }) => { - // Explicitly type the props to help the type checker - const typedConfirmationDetails = confirmationDetails as - | ToolCallConfirmationDetails - | undefined; const typedResultDisplay = resultDisplay as ToolResultDisplay | undefined; let color = Colors.SubtleComment; @@ -78,30 +68,6 @@ export const ToolMessage: React.FC = ({ : ` - ${description}`} - {status === ToolCallStatus.Confirming && typedConfirmationDetails && ( - - {/* Display diff for edit/write */} - {'fileDiff' in typedConfirmationDetails && ( - - )} - {/* Display command for execute */} - {'command' in typedConfirmationDetails && ( - - Command:{' '} - { - (typedConfirmationDetails as ToolExecuteConfirmationDetails) - .command - } - - )} - {/* */} - - )} {status === ToolCallStatus.Success && typedResultDisplay && ( {typeof typedResultDisplay === 'string' ? ( diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 585554ee..62851019 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -17,8 +17,18 @@ import { Config, ToolCallConfirmationDetails, ToolCallResponseInfo, + ServerToolCallConfirmationDetails, + ToolConfirmationOutcome, + ToolResultDisplay, + ToolEditConfirmationDetails, + ToolExecuteConfirmationDetails, } from '@gemini-code/server'; -import type { Chat, PartListUnion, FunctionDeclaration } from '@google/genai'; +import { + type Chat, + type PartListUnion, + type FunctionDeclaration, + type Part, +} from '@google/genai'; import { HistoryItem, IndividualToolCallDisplay, @@ -286,36 +296,24 @@ export const useGeminiStream = ( }), ); } else if (event.type === ServerGeminiEventType.ToolCallResponse) { - updateFunctionResponseUI(event.value); + const status = event.value.error + ? ToolCallStatus.Error + : ToolCallStatus.Success; + updateFunctionResponseUI(event.value, status); } else if ( event.type === ServerGeminiEventType.ToolCallConfirmation ) { - setHistory((prevHistory) => - prevHistory.map((item) => { - if ( - item.id === currentToolGroupId && - item.type === 'tool_group' - ) { - return { - ...item, - tools: item.tools.map((tool) => - tool.callId === event.value.request.callId - ? { - ...tool, - status: ToolCallStatus.Confirming, - confirmationDetails: event.value.details, - } - : tool, - ), - }; - } - return item; - }), + const confirmationDetails = wireConfirmationSubmission(event.value); + updateConfirmingFunctionStatusUI( + event.value.request.callId, + confirmationDetails, ); setStreamingState(StreamingState.WaitingForConfirmation); return; } } + + setStreamingState(StreamingState.Idle); } catch (error: unknown) { if (!isNodeError(error) || error.name !== 'AbortError') { console.error('Error processing stream or executing tool:', error); @@ -328,16 +326,40 @@ export const useGeminiStream = ( getNextMessageId(userMessageTimestamp), ); } + setStreamingState(StreamingState.Idle); } finally { abortControllerRef.current = null; - // Only set to Idle if not waiting for confirmation. - // Passthrough commands handle their own Idle transition. - if (streamingState !== StreamingState.WaitingForConfirmation) { - setStreamingState(StreamingState.Idle); - } } - function updateFunctionResponseUI(toolResponse: ToolCallResponseInfo) { + function updateConfirmingFunctionStatusUI( + callId: string, + confirmationDetails: ToolCallConfirmationDetails | undefined, + ) { + setHistory((prevHistory) => + prevHistory.map((item) => { + if (item.id === currentToolGroupId && item.type === 'tool_group') { + return { + ...item, + tools: item.tools.map((tool) => + tool.callId === callId + ? { + ...tool, + status: ToolCallStatus.Confirming, + confirmationDetails, + } + : tool, + ), + }; + } + return item; + }), + ); + } + + function updateFunctionResponseUI( + toolResponse: ToolCallResponseInfo, + status: ToolCallStatus, + ) { setHistory((prevHistory) => prevHistory.map((item) => { if (item.id === currentToolGroupId && item.type === 'tool_group') { @@ -347,10 +369,7 @@ export const useGeminiStream = ( if (tool.callId === toolResponse.callId) { return { ...tool, - // TODO: Do we surface the error here? - status: toolResponse.error - ? ToolCallStatus.Error - : ToolCallStatus.Success, + status, resultDisplay: toolResponse.resultDisplay, }; } else { @@ -363,6 +382,82 @@ export const useGeminiStream = ( }), ); } + + function wireConfirmationSubmission( + confirmationDetails: ServerToolCallConfirmationDetails, + ): ToolCallConfirmationDetails { + const originalConfirmationDetails = confirmationDetails.details; + const request = confirmationDetails.request; + const resubmittingConfirm = async ( + outcome: ToolConfirmationOutcome, + ) => { + originalConfirmationDetails.onConfirm(outcome); + + // Reset streaming state since confirmation has been chosen. + setStreamingState(StreamingState.Idle); + + if (outcome === ToolConfirmationOutcome.Cancel) { + let resultDisplay: ToolResultDisplay | undefined; + if ('fileDiff' in originalConfirmationDetails) { + resultDisplay = { + fileDiff: ( + originalConfirmationDetails as ToolEditConfirmationDetails + ).fileDiff, + }; + } else { + resultDisplay = `~~${(originalConfirmationDetails as ToolExecuteConfirmationDetails).command}~~`; + } + const functionResponse: Part = { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: 'User rejected function call.' }, + }, + }; + + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay, + error: undefined, + }; + + updateFunctionResponseUI(responseInfo, ToolCallStatus.Error); + + await submitQuery(functionResponse); + } else { + const tool = toolRegistry.getTool(request.name); + if (!tool) { + throw new Error( + `Tool "${request.name}" not found or is not registered.`, + ); + } + const result = await tool.execute(request.args); + const functionResponse: Part = { + functionResponse: { + name: request.name, + id: request.callId, + response: { output: result.llmContent }, + }, + }; + + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay: result.returnDisplay, + error: undefined, + }; + updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + + await submitQuery(functionResponse); + } + }; + + return { + ...originalConfirmationDetails, + onConfirm: resubmittingConfirm, + }; + } }, // Dependencies need careful review - including updateGeminiMessage [ diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index 0a1c594c..31656466 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -130,83 +130,78 @@ export class Turn { yield event; } } + } - // Execute pending tool calls - const toolPromises = this.pendingToolCalls.map( - async (pendingToolCall): Promise => { - const tool = this.availableTools.get(pendingToolCall.name); - if (!tool) { - return { - ...pendingToolCall, - error: new Error( - `Tool "${pendingToolCall.name}" not found or not provided to Turn.`, - ), - confirmationDetails: undefined, - }; - } - - try { - const confirmationDetails = await tool.shouldConfirmExecute( - pendingToolCall.args, - ); - if (confirmationDetails) { - return { ...pendingToolCall, confirmationDetails }; - } else { - const result = await tool.execute(pendingToolCall.args); - return { - ...pendingToolCall, - result, - confirmationDetails: undefined, - }; - } - } catch (execError: unknown) { - return { - ...pendingToolCall, - error: new Error( - `Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`, - ), - confirmationDetails: undefined, - }; - } - }, - ); - const outcomes = await Promise.all(toolPromises); - - // Process outcomes and prepare function responses - this.pendingToolCalls = []; // Clear pending calls for this turn - - for (let i = 0; i < outcomes.length; i++) { - const outcome = outcomes[i]; - if (outcome.confirmationDetails) { - this.confirmationDetails.push(outcome.confirmationDetails); - const serverConfirmationetails: ServerToolCallConfirmationDetails = { - request: { - callId: outcome.callId, - name: outcome.name, - args: outcome.args, - }, - details: outcome.confirmationDetails, + // Execute pending tool calls + const toolPromises = this.pendingToolCalls.map( + async (pendingToolCall): Promise => { + const tool = this.availableTools.get(pendingToolCall.name); + if (!tool) { + return { + ...pendingToolCall, + error: new Error( + `Tool "${pendingToolCall.name}" not found or not provided to Turn.`, + ), + confirmationDetails: undefined, }; - yield { - type: GeminiEventType.ToolCallConfirmation, - value: serverConfirmationetails, - }; - } else { - const responsePart = this.buildFunctionResponse(outcome); - this.fnResponses.push(responsePart); - const responseInfo: ToolCallResponseInfo = { - callId: outcome.callId, - responsePart, - resultDisplay: outcome.result?.returnDisplay, - error: outcome.error, - }; - yield { type: GeminiEventType.ToolCallResponse, value: responseInfo }; } - } - // If there were function responses, the caller (GeminiService) will loop - // and call run() again with these responses. - // If no function responses, the turn ends here. + try { + const confirmationDetails = await tool.shouldConfirmExecute( + pendingToolCall.args, + ); + if (confirmationDetails) { + return { ...pendingToolCall, confirmationDetails }; + } else { + const result = await tool.execute(pendingToolCall.args); + return { + ...pendingToolCall, + result, + confirmationDetails: undefined, + }; + } + } catch (execError: unknown) { + return { + ...pendingToolCall, + error: new Error( + `Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`, + ), + confirmationDetails: undefined, + }; + } + }, + ); + const outcomes = await Promise.all(toolPromises); + + // Process outcomes and prepare function responses + this.pendingToolCalls = []; // Clear pending calls for this turn + + for (const outcome of outcomes) { + if (outcome.confirmationDetails) { + this.confirmationDetails.push(outcome.confirmationDetails); + const serverConfirmationetails: ServerToolCallConfirmationDetails = { + request: { + callId: outcome.callId, + name: outcome.name, + args: outcome.args, + }, + details: outcome.confirmationDetails, + }; + yield { + type: GeminiEventType.ToolCallConfirmation, + value: serverConfirmationetails, + }; + } else { + const responsePart = this.buildFunctionResponse(outcome); + this.fnResponses.push(responsePart); + const responseInfo: ToolCallResponseInfo = { + callId: outcome.callId, + responsePart, + resultDisplay: outcome.result?.returnDisplay, + error: outcome.error, + }; + yield { type: GeminiEventType.ToolCallResponse, value: responseInfo }; + } } }