Fix confirmations.

- This fixes what it means to get confirmations in GC. Prior to this they had just been accidentally unwired as part of all of the refactorings to turns + to server/core.
  - The key piece of this is that we wrap the onConfirm in the gemini stream hook in order to resubmit function responses. This isn't 100% ideal but gets the job done for now.
- Fixed history not updating properly with confirmations.

Fixes https://b.corp.google.com/issues/412323656
This commit is contained in:
Taylor Mullen 2025-04-21 14:32:18 -04:00 committed by N. Taylor Mullen
parent 618f8a43cf
commit 738c2692fb
4 changed files with 226 additions and 180 deletions

View File

@ -52,15 +52,7 @@ export const App = ({ config }: AppProps) => {
[history],
);
const isWaitingForToolConfirmation = history.some(
(item) =>
item.type === 'tool_group' &&
item.tools.some((tool) => tool.confirmationDetails !== undefined),
);
const isInputActive =
streamingState === StreamingState.Idle &&
!initError &&
!isWaitingForToolConfirmation;
const isInputActive = streamingState === StreamingState.Idle && !initError;
const { query, handleSubmit: handleHistorySubmit } = useInputHistory({
userMessages,
@ -88,39 +80,37 @@ export const App = ({ config }: AppProps) => {
</Box>
)}
{initError &&
streamingState !== StreamingState.Responding &&
!isWaitingForToolConfirmation && (
<Box
borderStyle="round"
borderColor={Colors.AccentRed}
paddingX={1}
marginBottom={1}
>
{history.find(
(item) => item.type === 'error' && item.text?.includes(initError),
)?.text ? (
{initError && streamingState !== StreamingState.Responding && (
<Box
borderStyle="round"
borderColor={Colors.AccentRed}
paddingX={1}
marginBottom={1}
>
{history.find(
(item) => item.type === 'error' && item.text?.includes(initError),
)?.text ? (
<Text color={Colors.AccentRed}>
{
history.find(
(item) =>
item.type === 'error' && item.text?.includes(initError),
)?.text
}
</Text>
) : (
<>
<Text color={Colors.AccentRed}>
{
history.find(
(item) =>
item.type === 'error' && item.text?.includes(initError),
)?.text
}
Initialization Error: {initError}
</Text>
) : (
<>
<Text color={Colors.AccentRed}>
Initialization Error: {initError}
</Text>
<Text color={Colors.AccentRed}>
{' '}
Please check API key and configuration.
</Text>
</>
)}
</Box>
)}
<Text color={Colors.AccentRed}>
{' '}
Please check API key and configuration.
</Text>
</>
)}
</Box>
)}
<Box flexDirection="column">
<HistoryDisplay history={history} onSubmit={submitQuery} />

View File

@ -11,11 +11,6 @@ import { IndividualToolCallDisplay, ToolCallStatus } from '../../types.js';
import { DiffRenderer } from './DiffRenderer.js';
import { FileDiff, ToolResultDisplay } from '../../../tools/tools.js';
import { Colors } from '../../colors.js';
import {
ToolCallConfirmationDetails,
ToolEditConfirmationDetails,
ToolExecuteConfirmationDetails,
} from '@gemini-code/server';
export const ToolMessage: React.FC<IndividualToolCallDisplay> = ({
callId,
@ -23,12 +18,7 @@ export const ToolMessage: React.FC<IndividualToolCallDisplay> = ({
description,
resultDisplay,
status,
confirmationDetails,
}) => {
// Explicitly type the props to help the type checker
const typedConfirmationDetails = confirmationDetails as
| ToolCallConfirmationDetails
| undefined;
const typedResultDisplay = resultDisplay as ToolResultDisplay | undefined;
let color = Colors.SubtleComment;
@ -78,30 +68,6 @@ export const ToolMessage: React.FC<IndividualToolCallDisplay> = ({
: ` - ${description}`}
</Text>
</Box>
{status === ToolCallStatus.Confirming && typedConfirmationDetails && (
<Box flexDirection="column" marginLeft={2}>
{/* Display diff for edit/write */}
{'fileDiff' in typedConfirmationDetails && (
<DiffRenderer
diffContent={
(typedConfirmationDetails as ToolEditConfirmationDetails)
.fileDiff
}
/>
)}
{/* Display command for execute */}
{'command' in typedConfirmationDetails && (
<Text color={Colors.AccentYellow}>
Command:{' '}
{
(typedConfirmationDetails as ToolExecuteConfirmationDetails)
.command
}
</Text>
)}
{/* <ConfirmInput onConfirm={handleConfirm} isFocused={isFocused} /> */}
</Box>
)}
{status === ToolCallStatus.Success && typedResultDisplay && (
<Box flexDirection="column" marginLeft={2}>
{typeof typedResultDisplay === 'string' ? (

View File

@ -17,8 +17,18 @@ import {
Config,
ToolCallConfirmationDetails,
ToolCallResponseInfo,
ServerToolCallConfirmationDetails,
ToolConfirmationOutcome,
ToolResultDisplay,
ToolEditConfirmationDetails,
ToolExecuteConfirmationDetails,
} from '@gemini-code/server';
import type { Chat, PartListUnion, FunctionDeclaration } from '@google/genai';
import {
type Chat,
type PartListUnion,
type FunctionDeclaration,
type Part,
} from '@google/genai';
import {
HistoryItem,
IndividualToolCallDisplay,
@ -286,36 +296,24 @@ export const useGeminiStream = (
}),
);
} else if (event.type === ServerGeminiEventType.ToolCallResponse) {
updateFunctionResponseUI(event.value);
const status = event.value.error
? ToolCallStatus.Error
: ToolCallStatus.Success;
updateFunctionResponseUI(event.value, status);
} else if (
event.type === ServerGeminiEventType.ToolCallConfirmation
) {
setHistory((prevHistory) =>
prevHistory.map((item) => {
if (
item.id === currentToolGroupId &&
item.type === 'tool_group'
) {
return {
...item,
tools: item.tools.map((tool) =>
tool.callId === event.value.request.callId
? {
...tool,
status: ToolCallStatus.Confirming,
confirmationDetails: event.value.details,
}
: tool,
),
};
}
return item;
}),
const confirmationDetails = wireConfirmationSubmission(event.value);
updateConfirmingFunctionStatusUI(
event.value.request.callId,
confirmationDetails,
);
setStreamingState(StreamingState.WaitingForConfirmation);
return;
}
}
setStreamingState(StreamingState.Idle);
} catch (error: unknown) {
if (!isNodeError(error) || error.name !== 'AbortError') {
console.error('Error processing stream or executing tool:', error);
@ -328,16 +326,40 @@ export const useGeminiStream = (
getNextMessageId(userMessageTimestamp),
);
}
setStreamingState(StreamingState.Idle);
} finally {
abortControllerRef.current = null;
// Only set to Idle if not waiting for confirmation.
// Passthrough commands handle their own Idle transition.
if (streamingState !== StreamingState.WaitingForConfirmation) {
setStreamingState(StreamingState.Idle);
}
}
function updateFunctionResponseUI(toolResponse: ToolCallResponseInfo) {
function updateConfirmingFunctionStatusUI(
callId: string,
confirmationDetails: ToolCallConfirmationDetails | undefined,
) {
setHistory((prevHistory) =>
prevHistory.map((item) => {
if (item.id === currentToolGroupId && item.type === 'tool_group') {
return {
...item,
tools: item.tools.map((tool) =>
tool.callId === callId
? {
...tool,
status: ToolCallStatus.Confirming,
confirmationDetails,
}
: tool,
),
};
}
return item;
}),
);
}
function updateFunctionResponseUI(
toolResponse: ToolCallResponseInfo,
status: ToolCallStatus,
) {
setHistory((prevHistory) =>
prevHistory.map((item) => {
if (item.id === currentToolGroupId && item.type === 'tool_group') {
@ -347,10 +369,7 @@ export const useGeminiStream = (
if (tool.callId === toolResponse.callId) {
return {
...tool,
// TODO: Do we surface the error here?
status: toolResponse.error
? ToolCallStatus.Error
: ToolCallStatus.Success,
status,
resultDisplay: toolResponse.resultDisplay,
};
} else {
@ -363,6 +382,82 @@ export const useGeminiStream = (
}),
);
}
function wireConfirmationSubmission(
confirmationDetails: ServerToolCallConfirmationDetails,
): ToolCallConfirmationDetails {
const originalConfirmationDetails = confirmationDetails.details;
const request = confirmationDetails.request;
const resubmittingConfirm = async (
outcome: ToolConfirmationOutcome,
) => {
originalConfirmationDetails.onConfirm(outcome);
// Reset streaming state since confirmation has been chosen.
setStreamingState(StreamingState.Idle);
if (outcome === ToolConfirmationOutcome.Cancel) {
let resultDisplay: ToolResultDisplay | undefined;
if ('fileDiff' in originalConfirmationDetails) {
resultDisplay = {
fileDiff: (
originalConfirmationDetails as ToolEditConfirmationDetails
).fileDiff,
};
} else {
resultDisplay = `~~${(originalConfirmationDetails as ToolExecuteConfirmationDetails).command}~~`;
}
const functionResponse: Part = {
functionResponse: {
id: request.callId,
name: request.name,
response: { error: 'User rejected function call.' },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay,
error: undefined,
};
updateFunctionResponseUI(responseInfo, ToolCallStatus.Error);
await submitQuery(functionResponse);
} else {
const tool = toolRegistry.getTool(request.name);
if (!tool) {
throw new Error(
`Tool "${request.name}" not found or is not registered.`,
);
}
const result = await tool.execute(request.args);
const functionResponse: Part = {
functionResponse: {
name: request.name,
id: request.callId,
response: { output: result.llmContent },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay: result.returnDisplay,
error: undefined,
};
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
await submitQuery(functionResponse);
}
};
return {
...originalConfirmationDetails,
onConfirm: resubmittingConfirm,
};
}
},
// Dependencies need careful review - including updateGeminiMessage
[

View File

@ -130,83 +130,78 @@ export class Turn {
yield event;
}
}
}
// Execute pending tool calls
const toolPromises = this.pendingToolCalls.map(
async (pendingToolCall): Promise<ServerToolExecutionOutcome> => {
const tool = this.availableTools.get(pendingToolCall.name);
if (!tool) {
return {
...pendingToolCall,
error: new Error(
`Tool "${pendingToolCall.name}" not found or not provided to Turn.`,
),
confirmationDetails: undefined,
};
}
try {
const confirmationDetails = await tool.shouldConfirmExecute(
pendingToolCall.args,
);
if (confirmationDetails) {
return { ...pendingToolCall, confirmationDetails };
} else {
const result = await tool.execute(pendingToolCall.args);
return {
...pendingToolCall,
result,
confirmationDetails: undefined,
};
}
} catch (execError: unknown) {
return {
...pendingToolCall,
error: new Error(
`Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`,
),
confirmationDetails: undefined,
};
}
},
);
const outcomes = await Promise.all(toolPromises);
// Process outcomes and prepare function responses
this.pendingToolCalls = []; // Clear pending calls for this turn
for (let i = 0; i < outcomes.length; i++) {
const outcome = outcomes[i];
if (outcome.confirmationDetails) {
this.confirmationDetails.push(outcome.confirmationDetails);
const serverConfirmationetails: ServerToolCallConfirmationDetails = {
request: {
callId: outcome.callId,
name: outcome.name,
args: outcome.args,
},
details: outcome.confirmationDetails,
// Execute pending tool calls
const toolPromises = this.pendingToolCalls.map(
async (pendingToolCall): Promise<ServerToolExecutionOutcome> => {
const tool = this.availableTools.get(pendingToolCall.name);
if (!tool) {
return {
...pendingToolCall,
error: new Error(
`Tool "${pendingToolCall.name}" not found or not provided to Turn.`,
),
confirmationDetails: undefined,
};
yield {
type: GeminiEventType.ToolCallConfirmation,
value: serverConfirmationetails,
};
} else {
const responsePart = this.buildFunctionResponse(outcome);
this.fnResponses.push(responsePart);
const responseInfo: ToolCallResponseInfo = {
callId: outcome.callId,
responsePart,
resultDisplay: outcome.result?.returnDisplay,
error: outcome.error,
};
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
}
}
// If there were function responses, the caller (GeminiService) will loop
// and call run() again with these responses.
// If no function responses, the turn ends here.
try {
const confirmationDetails = await tool.shouldConfirmExecute(
pendingToolCall.args,
);
if (confirmationDetails) {
return { ...pendingToolCall, confirmationDetails };
} else {
const result = await tool.execute(pendingToolCall.args);
return {
...pendingToolCall,
result,
confirmationDetails: undefined,
};
}
} catch (execError: unknown) {
return {
...pendingToolCall,
error: new Error(
`Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`,
),
confirmationDetails: undefined,
};
}
},
);
const outcomes = await Promise.all(toolPromises);
// Process outcomes and prepare function responses
this.pendingToolCalls = []; // Clear pending calls for this turn
for (const outcome of outcomes) {
if (outcome.confirmationDetails) {
this.confirmationDetails.push(outcome.confirmationDetails);
const serverConfirmationetails: ServerToolCallConfirmationDetails = {
request: {
callId: outcome.callId,
name: outcome.name,
args: outcome.args,
},
details: outcome.confirmationDetails,
};
yield {
type: GeminiEventType.ToolCallConfirmation,
value: serverConfirmationetails,
};
} else {
const responsePart = this.buildFunctionResponse(outcome);
this.fnResponses.push(responsePart);
const responseInfo: ToolCallResponseInfo = {
callId: outcome.callId,
responsePart,
resultDisplay: outcome.result?.returnDisplay,
error: outcome.error,
};
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
}
}
}