Enable tools to cancel active execution.

- Plumbed abort signals through to tools
- Updated the shell tool to properly cancel active requests by killing the entire child process tree of the underlying shell process and then report that the shell itself was canceled.

Fixes https://b.corp.google.com/issues/416829935
This commit is contained in:
Taylor Mullen 2025-05-09 23:29:02 -07:00 committed by N. Taylor Mullen
parent 090198a7d6
commit 6b518dc9e4
15 changed files with 191 additions and 72 deletions

View File

@ -26,6 +26,7 @@ interface HandleAtCommandParams {
addItem: UseHistoryManagerReturn['addItem']; addItem: UseHistoryManagerReturn['addItem'];
setDebugMessage: React.Dispatch<React.SetStateAction<string>>; setDebugMessage: React.Dispatch<React.SetStateAction<string>>;
messageId: number; messageId: number;
signal: AbortSignal;
} }
interface HandleAtCommandResult { interface HandleAtCommandResult {
@ -90,6 +91,7 @@ export async function handleAtCommand({
addItem, addItem,
setDebugMessage, setDebugMessage,
messageId: userMessageTimestamp, messageId: userMessageTimestamp,
signal,
}: HandleAtCommandParams): Promise<HandleAtCommandResult> { }: HandleAtCommandParams): Promise<HandleAtCommandResult> {
const trimmedQuery = query.trim(); const trimmedQuery = query.trim();
const parsedCommand = parseAtCommand(trimmedQuery); const parsedCommand = parseAtCommand(trimmedQuery);
@ -163,7 +165,7 @@ export async function handleAtCommand({
let toolCallDisplay: IndividualToolCallDisplay; let toolCallDisplay: IndividualToolCallDisplay;
try { try {
const result = await readManyFilesTool.execute(toolArgs); const result = await readManyFilesTool.execute(toolArgs, signal);
const fileContent = result.llmContent || ''; const fileContent = result.llmContent || '';
toolCallDisplay = { toolCallDisplay = {

View File

@ -89,7 +89,7 @@ export const useGeminiStream = (
}, [config, addItem]); }, [config, addItem]);
useInput((_input, key) => { useInput((_input, key) => {
if (streamingState === StreamingState.Responding && key.escape) { if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort(); abortControllerRef.current?.abort();
} }
}); });
@ -104,6 +104,9 @@ export const useGeminiStream = (
setShowHelp(false); setShowHelp(false);
abortControllerRef.current ??= new AbortController();
const signal = abortControllerRef.current.signal;
if (typeof query === 'string') { if (typeof query === 'string') {
const trimmedQuery = query.trim(); const trimmedQuery = query.trim();
setDebugMessage(`User query: '${trimmedQuery}'`); setDebugMessage(`User query: '${trimmedQuery}'`);
@ -120,6 +123,7 @@ export const useGeminiStream = (
addItem, addItem,
setDebugMessage, setDebugMessage,
messageId: userMessageTimestamp, messageId: userMessageTimestamp,
signal,
}); });
if (!atCommandResult.shouldProceed) return; if (!atCommandResult.shouldProceed) return;
queryToSendToGemini = atCommandResult.processedQuery; queryToSendToGemini = atCommandResult.processedQuery;
@ -165,9 +169,6 @@ export const useGeminiStream = (
const chat = chatSessionRef.current; const chat = chatSessionRef.current;
try { try {
abortControllerRef.current = new AbortController();
const signal = abortControllerRef.current.signal;
const stream = client.sendMessageStream( const stream = client.sendMessageStream(
chat, chat,
queryToSendToGemini, queryToSendToGemini,
@ -294,7 +295,26 @@ export const useGeminiStream = (
} else if (event.type === ServerGeminiEventType.UserCancelled) { } else if (event.type === ServerGeminiEventType.UserCancelled) {
// Flush out existing pending history item. // Flush out existing pending history item.
if (pendingHistoryItemRef.current) { if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp); // If the pending item is a tool_group, update statuses to Canceled
if (pendingHistoryItemRef.current.type === 'tool_group') {
const updatedTools = pendingHistoryItemRef.current.tools.map(
(tool) => {
if (
tool.status === ToolCallStatus.Pending ||
tool.status === ToolCallStatus.Confirming ||
tool.status === ToolCallStatus.Executing
) {
return { ...tool, status: ToolCallStatus.Canceled };
}
return tool;
},
);
const pendingHistoryItem = pendingHistoryItemRef.current;
pendingHistoryItem.tools = updatedTools;
addItem(pendingHistoryItem, userMessageTimestamp);
} else {
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
}
setPendingHistoryItem(null); setPendingHistoryItem(null);
} }
addItem( addItem(
@ -412,6 +432,59 @@ export const useGeminiStream = (
} }
if (outcome === ToolConfirmationOutcome.Cancel) { if (outcome === ToolConfirmationOutcome.Cancel) {
declineToolExecution(
'User rejected function call.',
ToolCallStatus.Error,
);
} else {
const tool = toolRegistry.getTool(request.name);
if (!tool) {
throw new Error(
`Tool "${request.name}" not found or is not registered.`,
);
}
try {
abortControllerRef.current = new AbortController();
const result = await tool.execute(
request.args,
abortControllerRef.current.signal,
);
if (abortControllerRef.current.signal.aborted) {
declineToolExecution(
result.llmContent,
ToolCallStatus.Canceled,
);
return;
}
const functionResponse: Part = {
functionResponse: {
name: request.name,
id: request.callId,
response: { output: result.llmContent },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay: result.returnDisplay,
error: undefined,
};
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
setStreamingState(StreamingState.Idle);
await submitQuery(functionResponse);
} finally {
abortControllerRef.current = null;
}
}
function declineToolExecution(
declineMessage: string,
status: ToolCallStatus,
) {
let resultDisplay: ToolResultDisplay | undefined; let resultDisplay: ToolResultDisplay | undefined;
if ('fileDiff' in originalConfirmationDetails) { if ('fileDiff' in originalConfirmationDetails) {
resultDisplay = { resultDisplay = {
@ -426,43 +499,19 @@ export const useGeminiStream = (
functionResponse: { functionResponse: {
id: request.callId, id: request.callId,
name: request.name, name: request.name,
response: { error: 'User rejected function call.' }, response: { error: declineMessage },
}, },
}; };
const responseInfo: ToolCallResponseInfo = { const responseInfo: ToolCallResponseInfo = {
callId: request.callId, callId: request.callId,
responsePart: functionResponse, responsePart: functionResponse,
resultDisplay, resultDisplay,
error: new Error('User rejected function call.'), error: new Error(declineMessage),
};
// Update UI to show cancellation/error
updateFunctionResponseUI(responseInfo, ToolCallStatus.Error);
setStreamingState(StreamingState.Idle);
} else {
const tool = toolRegistry.getTool(request.name);
if (!tool) {
throw new Error(
`Tool "${request.name}" not found or is not registered.`,
);
}
const result = await tool.execute(request.args);
const functionResponse: Part = {
functionResponse: {
name: request.name,
id: request.callId,
response: { output: result.llmContent },
},
}; };
const responseInfo: ToolCallResponseInfo = { // Update UI to show cancellation/error
callId: request.callId, updateFunctionResponseUI(responseInfo, status);
responsePart: functionResponse,
resultDisplay: result.returnDisplay,
error: undefined,
};
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
setStreamingState(StreamingState.Idle); setStreamingState(StreamingState.Idle);
await submitQuery(functionResponse);
} }
}; };

View File

@ -64,10 +64,13 @@ export class GeminiClient {
.getTool('read_many_files') as ReadManyFilesTool; .getTool('read_many_files') as ReadManyFilesTool;
if (readManyFilesTool) { if (readManyFilesTool) {
// Read all files in the target directory // Read all files in the target directory
const result = await readManyFilesTool.execute({ const result = await readManyFilesTool.execute(
paths: ['**/*'], // Read everything recursively {
useDefaultExcludes: true, // Use default excludes paths: ['**/*'], // Read everything recursively
}); useDefaultExcludes: true, // Use default excludes
},
AbortSignal.timeout(30000),
);
if (result.llmContent) { if (result.llmContent) {
initialParts.push({ initialParts.push({
text: `\n--- Full File Context ---\n${result.llmContent}`, text: `\n--- Full File Context ---\n${result.llmContent}`,

View File

@ -36,7 +36,10 @@ export interface ServerTool {
name: string; name: string;
schema: FunctionDeclaration; schema: FunctionDeclaration;
// The execute method signature might differ slightly or be wrapped // The execute method signature might differ slightly or be wrapped
execute(params: Record<string, unknown>): Promise<ToolResult>; execute(
params: Record<string, unknown>,
signal?: AbortSignal,
): Promise<ToolResult>;
shouldConfirmExecute( shouldConfirmExecute(
params: Record<string, unknown>, params: Record<string, unknown>,
): Promise<ToolCallConfirmationDetails | false>; ): Promise<ToolCallConfirmationDetails | false>;
@ -153,7 +156,7 @@ export class Turn {
if (confirmationDetails) { if (confirmationDetails) {
return { ...pendingToolCall, confirmationDetails }; return { ...pendingToolCall, confirmationDetails };
} }
const result = await tool.execute(pendingToolCall.args); const result = await tool.execute(pendingToolCall.args, signal);
return { return {
...pendingToolCall, ...pendingToolCall,
result, result,
@ -199,7 +202,11 @@ export class Turn {
resultDisplay: outcome.result?.returnDisplay, resultDisplay: outcome.result?.returnDisplay,
error: outcome.error, error: outcome.error,
}; };
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
// If aborted we're already yielding the user cancellations elsewhere.
if (!signal?.aborted) {
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
}
} }
} }

View File

@ -333,7 +333,10 @@ Expectation for parameters:
* @param params Parameters for the edit operation * @param params Parameters for the edit operation
* @returns Result of the edit operation * @returns Result of the edit operation
*/ */
async execute(params: EditToolParams): Promise<ToolResult> { async execute(
params: EditToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params); const validationError = this.validateParams(params);
if (validationError) { if (validationError) {
return { return {

View File

@ -138,7 +138,10 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
/** /**
* Executes the glob search with the given parameters * Executes the glob search with the given parameters
*/ */
async execute(params: GlobToolParams): Promise<ToolResult> { async execute(
params: GlobToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params); const validationError = this.validateToolParams(params);
if (validationError) { if (validationError) {
return { return {

View File

@ -166,7 +166,10 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
* @param params Parameters for the grep search * @param params Parameters for the grep search
* @returns Result of the grep search * @returns Result of the grep search
*/ */
async execute(params: GrepToolParams): Promise<ToolResult> { async execute(
params: GrepToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params); const validationError = this.validateToolParams(params);
if (validationError) { if (validationError) {
console.error( console.error(

View File

@ -184,7 +184,10 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
* @param params Parameters for the LS operation * @param params Parameters for the LS operation
* @returns Result of the LS operation * @returns Result of the LS operation
*/ */
async execute(params: LSToolParams): Promise<ToolResult> { async execute(
params: LSToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params); const validationError = this.validateToolParams(params);
if (validationError) { if (validationError) {
return this.errorResult( return this.errorResult(

View File

@ -193,7 +193,10 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
* @param params Parameters for the file reading * @param params Parameters for the file reading
* @returns Result with file contents * @returns Result with file contents
*/ */
async execute(params: ReadFileToolParams): Promise<ToolResult> { async execute(
params: ReadFileToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params); const validationError = this.validateToolParams(params);
if (validationError) { if (validationError) {
return { return {

View File

@ -237,7 +237,10 @@ Default excludes apply to common non-text files and large dependency directories
return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace('{filePath}', 'path/to/file.ext')}".`; return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace('{filePath}', 'path/to/file.ext')}".`;
} }
async execute(params: ReadManyFilesParams): Promise<ToolResult> { async execute(
params: ReadManyFilesParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params); const validationError = this.validateParams(params);
if (validationError) { if (validationError) {
return { return {

View File

@ -118,7 +118,10 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
return confirmationDetails; return confirmationDetails;
} }
async execute(params: ShellToolParams): Promise<ToolResult> { async execute(
params: ShellToolParams,
abortSignal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params); const validationError = this.validateToolParams(params);
if (validationError) { if (validationError) {
return { return {
@ -174,18 +177,38 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
}); });
let code: number | null = null; let code: number | null = null;
let signal: NodeJS.Signals | null = null; let processSignal: NodeJS.Signals | null = null;
shell.on( const closeHandler = (
'close', _code: number | null,
(_code: number | null, _signal: NodeJS.Signals | null) => { _signal: NodeJS.Signals | null,
code = _code; ) => {
signal = _signal; code = _code;
}, processSignal = _signal;
); };
shell.on('close', closeHandler);
const abortHandler = () => {
if (shell.pid) {
try {
// Kill the entire process group
process.kill(-shell.pid, 'SIGTERM');
} catch (_e) {
// Fallback to killing the main process if group kill fails
try {
shell.kill('SIGKILL'); // or 'SIGTERM'
} catch (_killError) {
// Ignore errors if the process is already dead
}
}
}
};
abortSignal.addEventListener('abort', abortHandler);
// wait for the shell to exit // wait for the shell to exit
await new Promise((resolve) => shell.on('close', resolve)); await new Promise((resolve) => shell.on('close', resolve));
abortSignal.removeEventListener('abort', abortHandler);
// parse pids (pgrep output) from temporary file and remove it // parse pids (pgrep output) from temporary file and remove it
const backgroundPIDs: number[] = []; const backgroundPIDs: number[] = [];
if (fs.existsSync(tempFilePath)) { if (fs.existsSync(tempFilePath)) {
@ -205,19 +228,26 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
} }
fs.unlinkSync(tempFilePath); fs.unlinkSync(tempFilePath);
} else { } else {
console.error('missing pgrep output'); if (!abortSignal.aborted) {
console.error('missing pgrep output');
}
} }
const llmContent = [ let llmContent = '';
`Command: ${params.command}`, if (abortSignal.aborted) {
`Directory: ${params.directory || '(root)'}`, llmContent = 'Command did not complete, it was cancelled by the user';
`Stdout: ${stdout || '(empty)'}`, } else {
`Stderr: ${stderr || '(empty)'}`, llmContent = [
`Error: ${error ?? '(none)'}`, `Command: ${params.command}`,
`Exit Code: ${code ?? '(none)'}`, `Directory: ${params.directory || '(root)'}`,
`Signal: ${signal ?? '(none)'}`, `Stdout: ${stdout || '(empty)'}`,
`Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`, `Stderr: ${stderr || '(empty)'}`,
].join('\n'); `Error: ${error ?? '(none)'}`,
`Exit Code: ${code ?? '(none)'}`,
`Signal: ${processSignal ?? '(none)'}`,
`Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`,
].join('\n');
}
const returnDisplay = this.config.getDebugMode() ? llmContent : output; const returnDisplay = this.config.getDebugMode() ? llmContent : output;

View File

@ -265,7 +265,10 @@ Use this tool for running build steps (\`npm install\`, \`make\`), linters (\`es
return confirmationDetails; return confirmationDetails;
} }
async execute(params: TerminalToolParams): Promise<ToolResult> { async execute(
params: TerminalToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params); const validationError = this.validateToolParams(params);
if (validationError) { if (validationError) {
return { return {

View File

@ -64,7 +64,7 @@ export interface Tool<
* @param params Parameters for the tool execution * @param params Parameters for the tool execution
* @returns Result of the tool execution * @returns Result of the tool execution
*/ */
execute(params: TParams): Promise<TResult>; execute(params: TParams, signal: AbortSignal): Promise<TResult>;
} }
/** /**
@ -141,9 +141,10 @@ export abstract class BaseTool<
* Abstract method to execute the tool with the given parameters * Abstract method to execute the tool with the given parameters
* Must be implemented by derived classes * Must be implemented by derived classes
* @param params Parameters for the tool execution * @param params Parameters for the tool execution
* @param signal AbortSignal for tool cancellation
* @returns Result of the tool execution * @returns Result of the tool execution
*/ */
abstract execute(params: TParams): Promise<TResult>; abstract execute(params: TParams, signal: AbortSignal): Promise<TResult>;
} }
export interface ToolResult { export interface ToolResult {

View File

@ -70,7 +70,10 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
return `Fetching content from ${displayUrl}`; return `Fetching content from ${displayUrl}`;
} }
async execute(params: WebFetchToolParams): Promise<ToolResult> { async execute(
params: WebFetchToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params); const validationError = this.validateParams(params);
if (validationError) { if (validationError) {
return { return {

View File

@ -150,7 +150,10 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
return confirmationDetails; return confirmationDetails;
} }
async execute(params: WriteFileToolParams): Promise<ToolResult> { async execute(
params: WriteFileToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params); const validationError = this.validateParams(params);
if (validationError) { if (validationError) {
return { return {