From 6b518dc9e4c601c0108768932dc1450c036075fd Mon Sep 17 00:00:00 2001 From: Taylor Mullen Date: Fri, 9 May 2025 23:29:02 -0700 Subject: [PATCH] Enable tools to cancel active execution. - Plumbed abort signals through to tools - Updated the shell tool to properly cancel active requests by killing the entire child process tree of the underlying shell process and then report that the shell itself was canceled. Fixes https://b.corp.google.com/issues/416829935 --- .../cli/src/ui/hooks/atCommandProcessor.ts | 4 +- packages/cli/src/ui/hooks/useGeminiStream.ts | 115 +++++++++++++----- packages/server/src/core/client.ts | 11 +- packages/server/src/core/turn.ts | 13 +- packages/server/src/tools/edit.ts | 5 +- packages/server/src/tools/glob.ts | 5 +- packages/server/src/tools/grep.ts | 5 +- packages/server/src/tools/ls.ts | 5 +- packages/server/src/tools/read-file.ts | 5 +- packages/server/src/tools/read-many-files.ts | 5 +- packages/server/src/tools/shell.ts | 70 ++++++++--- packages/server/src/tools/terminal.ts | 5 +- packages/server/src/tools/tools.ts | 5 +- packages/server/src/tools/web-fetch.ts | 5 +- packages/server/src/tools/write-file.ts | 5 +- 15 files changed, 191 insertions(+), 72 deletions(-) diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts index 5ffa5383..a13a7d36 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts @@ -26,6 +26,7 @@ interface HandleAtCommandParams { addItem: UseHistoryManagerReturn['addItem']; setDebugMessage: React.Dispatch>; messageId: number; + signal: AbortSignal; } interface HandleAtCommandResult { @@ -90,6 +91,7 @@ export async function handleAtCommand({ addItem, setDebugMessage, messageId: userMessageTimestamp, + signal, }: HandleAtCommandParams): Promise { const trimmedQuery = query.trim(); const parsedCommand = parseAtCommand(trimmedQuery); @@ -163,7 +165,7 @@ export async function handleAtCommand({ let toolCallDisplay: IndividualToolCallDisplay; try { - const result = await readManyFilesTool.execute(toolArgs); + const result = await readManyFilesTool.execute(toolArgs, signal); const fileContent = result.llmContent || ''; toolCallDisplay = { diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 3f8cee40..e86ae0b9 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -89,7 +89,7 @@ export const useGeminiStream = ( }, [config, addItem]); useInput((_input, key) => { - if (streamingState === StreamingState.Responding && key.escape) { + if (streamingState !== StreamingState.Idle && key.escape) { abortControllerRef.current?.abort(); } }); @@ -104,6 +104,9 @@ export const useGeminiStream = ( setShowHelp(false); + abortControllerRef.current ??= new AbortController(); + const signal = abortControllerRef.current.signal; + if (typeof query === 'string') { const trimmedQuery = query.trim(); setDebugMessage(`User query: '${trimmedQuery}'`); @@ -120,6 +123,7 @@ export const useGeminiStream = ( addItem, setDebugMessage, messageId: userMessageTimestamp, + signal, }); if (!atCommandResult.shouldProceed) return; queryToSendToGemini = atCommandResult.processedQuery; @@ -165,9 +169,6 @@ export const useGeminiStream = ( const chat = chatSessionRef.current; try { - abortControllerRef.current = new AbortController(); - const signal = abortControllerRef.current.signal; - const stream = client.sendMessageStream( chat, queryToSendToGemini, @@ -294,7 +295,26 @@ export const useGeminiStream = ( } else if (event.type === ServerGeminiEventType.UserCancelled) { // Flush out existing pending history item. if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); + // If the pending item is a tool_group, update statuses to Canceled + if (pendingHistoryItemRef.current.type === 'tool_group') { + const updatedTools = pendingHistoryItemRef.current.tools.map( + (tool) => { + if ( + tool.status === ToolCallStatus.Pending || + tool.status === ToolCallStatus.Confirming || + tool.status === ToolCallStatus.Executing + ) { + return { ...tool, status: ToolCallStatus.Canceled }; + } + return tool; + }, + ); + const pendingHistoryItem = pendingHistoryItemRef.current; + pendingHistoryItem.tools = updatedTools; + addItem(pendingHistoryItem, userMessageTimestamp); + } else { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } setPendingHistoryItem(null); } addItem( @@ -412,6 +432,59 @@ export const useGeminiStream = ( } if (outcome === ToolConfirmationOutcome.Cancel) { + declineToolExecution( + 'User rejected function call.', + ToolCallStatus.Error, + ); + } else { + const tool = toolRegistry.getTool(request.name); + if (!tool) { + throw new Error( + `Tool "${request.name}" not found or is not registered.`, + ); + } + + try { + abortControllerRef.current = new AbortController(); + const result = await tool.execute( + request.args, + abortControllerRef.current.signal, + ); + + if (abortControllerRef.current.signal.aborted) { + declineToolExecution( + result.llmContent, + ToolCallStatus.Canceled, + ); + return; + } + + const functionResponse: Part = { + functionResponse: { + name: request.name, + id: request.callId, + response: { output: result.llmContent }, + }, + }; + + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay: result.returnDisplay, + error: undefined, + }; + updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + setStreamingState(StreamingState.Idle); + await submitQuery(functionResponse); + } finally { + abortControllerRef.current = null; + } + } + + function declineToolExecution( + declineMessage: string, + status: ToolCallStatus, + ) { let resultDisplay: ToolResultDisplay | undefined; if ('fileDiff' in originalConfirmationDetails) { resultDisplay = { @@ -426,43 +499,19 @@ export const useGeminiStream = ( functionResponse: { id: request.callId, name: request.name, - response: { error: 'User rejected function call.' }, + response: { error: declineMessage }, }, }; const responseInfo: ToolCallResponseInfo = { callId: request.callId, responsePart: functionResponse, resultDisplay, - error: new Error('User rejected function call.'), - }; - // Update UI to show cancellation/error - updateFunctionResponseUI(responseInfo, ToolCallStatus.Error); - setStreamingState(StreamingState.Idle); - } else { - const tool = toolRegistry.getTool(request.name); - if (!tool) { - throw new Error( - `Tool "${request.name}" not found or is not registered.`, - ); - } - const result = await tool.execute(request.args); - const functionResponse: Part = { - functionResponse: { - name: request.name, - id: request.callId, - response: { output: result.llmContent }, - }, + error: new Error(declineMessage), }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + // Update UI to show cancellation/error + updateFunctionResponseUI(responseInfo, status); setStreamingState(StreamingState.Idle); - await submitQuery(functionResponse); } }; diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts index 904e944c..46af465a 100644 --- a/packages/server/src/core/client.ts +++ b/packages/server/src/core/client.ts @@ -64,10 +64,13 @@ export class GeminiClient { .getTool('read_many_files') as ReadManyFilesTool; if (readManyFilesTool) { // Read all files in the target directory - const result = await readManyFilesTool.execute({ - paths: ['**/*'], // Read everything recursively - useDefaultExcludes: true, // Use default excludes - }); + const result = await readManyFilesTool.execute( + { + paths: ['**/*'], // Read everything recursively + useDefaultExcludes: true, // Use default excludes + }, + AbortSignal.timeout(30000), + ); if (result.llmContent) { initialParts.push({ text: `\n--- Full File Context ---\n${result.llmContent}`, diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index 7d8bf7b6..62219938 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -36,7 +36,10 @@ export interface ServerTool { name: string; schema: FunctionDeclaration; // The execute method signature might differ slightly or be wrapped - execute(params: Record): Promise; + execute( + params: Record, + signal?: AbortSignal, + ): Promise; shouldConfirmExecute( params: Record, ): Promise; @@ -153,7 +156,7 @@ export class Turn { if (confirmationDetails) { return { ...pendingToolCall, confirmationDetails }; } - const result = await tool.execute(pendingToolCall.args); + const result = await tool.execute(pendingToolCall.args, signal); return { ...pendingToolCall, result, @@ -199,7 +202,11 @@ export class Turn { resultDisplay: outcome.result?.returnDisplay, error: outcome.error, }; - yield { type: GeminiEventType.ToolCallResponse, value: responseInfo }; + + // If aborted we're already yielding the user cancellations elsewhere. + if (!signal?.aborted) { + yield { type: GeminiEventType.ToolCallResponse, value: responseInfo }; + } } } diff --git a/packages/server/src/tools/edit.ts b/packages/server/src/tools/edit.ts index c40b9e44..fd57d97d 100644 --- a/packages/server/src/tools/edit.ts +++ b/packages/server/src/tools/edit.ts @@ -333,7 +333,10 @@ Expectation for parameters: * @param params Parameters for the edit operation * @returns Result of the edit operation */ - async execute(params: EditToolParams): Promise { + async execute( + params: EditToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/glob.ts b/packages/server/src/tools/glob.ts index 9e7df0e8..b1b9d0cf 100644 --- a/packages/server/src/tools/glob.ts +++ b/packages/server/src/tools/glob.ts @@ -138,7 +138,10 @@ export class GlobTool extends BaseTool { /** * Executes the glob search with the given parameters */ - async execute(params: GlobToolParams): Promise { + async execute( + params: GlobToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/grep.ts b/packages/server/src/tools/grep.ts index e3253ecf..54391832 100644 --- a/packages/server/src/tools/grep.ts +++ b/packages/server/src/tools/grep.ts @@ -166,7 +166,10 @@ export class GrepTool extends BaseTool { * @param params Parameters for the grep search * @returns Result of the grep search */ - async execute(params: GrepToolParams): Promise { + async execute( + params: GrepToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { console.error( diff --git a/packages/server/src/tools/ls.ts b/packages/server/src/tools/ls.ts index 01da5121..fea95187 100644 --- a/packages/server/src/tools/ls.ts +++ b/packages/server/src/tools/ls.ts @@ -184,7 +184,10 @@ export class LSTool extends BaseTool { * @param params Parameters for the LS operation * @returns Result of the LS operation */ - async execute(params: LSToolParams): Promise { + async execute( + params: LSToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return this.errorResult( diff --git a/packages/server/src/tools/read-file.ts b/packages/server/src/tools/read-file.ts index 598b4691..de09161d 100644 --- a/packages/server/src/tools/read-file.ts +++ b/packages/server/src/tools/read-file.ts @@ -193,7 +193,10 @@ export class ReadFileTool extends BaseTool { * @param params Parameters for the file reading * @returns Result with file contents */ - async execute(params: ReadFileToolParams): Promise { + async execute( + params: ReadFileToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/read-many-files.ts b/packages/server/src/tools/read-many-files.ts index 0b4b090d..44882e44 100644 --- a/packages/server/src/tools/read-many-files.ts +++ b/packages/server/src/tools/read-many-files.ts @@ -237,7 +237,10 @@ Default excludes apply to common non-text files and large dependency directories return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace('{filePath}', 'path/to/file.ext')}".`; } - async execute(params: ReadManyFilesParams): Promise { + async execute( + params: ReadManyFilesParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/shell.ts b/packages/server/src/tools/shell.ts index fd8a6b1a..7851b76a 100644 --- a/packages/server/src/tools/shell.ts +++ b/packages/server/src/tools/shell.ts @@ -118,7 +118,10 @@ export class ShellTool extends BaseTool { return confirmationDetails; } - async execute(params: ShellToolParams): Promise { + async execute( + params: ShellToolParams, + abortSignal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { @@ -174,18 +177,38 @@ export class ShellTool extends BaseTool { }); let code: number | null = null; - let signal: NodeJS.Signals | null = null; - shell.on( - 'close', - (_code: number | null, _signal: NodeJS.Signals | null) => { - code = _code; - signal = _signal; - }, - ); + let processSignal: NodeJS.Signals | null = null; + const closeHandler = ( + _code: number | null, + _signal: NodeJS.Signals | null, + ) => { + code = _code; + processSignal = _signal; + }; + shell.on('close', closeHandler); + + const abortHandler = () => { + if (shell.pid) { + try { + // Kill the entire process group + process.kill(-shell.pid, 'SIGTERM'); + } catch (_e) { + // Fallback to killing the main process if group kill fails + try { + shell.kill('SIGKILL'); // or 'SIGTERM' + } catch (_killError) { + // Ignore errors if the process is already dead + } + } + } + }; + abortSignal.addEventListener('abort', abortHandler); // wait for the shell to exit await new Promise((resolve) => shell.on('close', resolve)); + abortSignal.removeEventListener('abort', abortHandler); + // parse pids (pgrep output) from temporary file and remove it const backgroundPIDs: number[] = []; if (fs.existsSync(tempFilePath)) { @@ -205,19 +228,26 @@ export class ShellTool extends BaseTool { } fs.unlinkSync(tempFilePath); } else { - console.error('missing pgrep output'); + if (!abortSignal.aborted) { + console.error('missing pgrep output'); + } } - const llmContent = [ - `Command: ${params.command}`, - `Directory: ${params.directory || '(root)'}`, - `Stdout: ${stdout || '(empty)'}`, - `Stderr: ${stderr || '(empty)'}`, - `Error: ${error ?? '(none)'}`, - `Exit Code: ${code ?? '(none)'}`, - `Signal: ${signal ?? '(none)'}`, - `Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`, - ].join('\n'); + let llmContent = ''; + if (abortSignal.aborted) { + llmContent = 'Command did not complete, it was cancelled by the user'; + } else { + llmContent = [ + `Command: ${params.command}`, + `Directory: ${params.directory || '(root)'}`, + `Stdout: ${stdout || '(empty)'}`, + `Stderr: ${stderr || '(empty)'}`, + `Error: ${error ?? '(none)'}`, + `Exit Code: ${code ?? '(none)'}`, + `Signal: ${processSignal ?? '(none)'}`, + `Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`, + ].join('\n'); + } const returnDisplay = this.config.getDebugMode() ? llmContent : output; diff --git a/packages/server/src/tools/terminal.ts b/packages/server/src/tools/terminal.ts index 7320cfb2..af558fb0 100644 --- a/packages/server/src/tools/terminal.ts +++ b/packages/server/src/tools/terminal.ts @@ -265,7 +265,10 @@ Use this tool for running build steps (\`npm install\`, \`make\`), linters (\`es return confirmationDetails; } - async execute(params: TerminalToolParams): Promise { + async execute( + params: TerminalToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/tools.ts b/packages/server/src/tools/tools.ts index ac04450d..7bb05a95 100644 --- a/packages/server/src/tools/tools.ts +++ b/packages/server/src/tools/tools.ts @@ -64,7 +64,7 @@ export interface Tool< * @param params Parameters for the tool execution * @returns Result of the tool execution */ - execute(params: TParams): Promise; + execute(params: TParams, signal: AbortSignal): Promise; } /** @@ -141,9 +141,10 @@ export abstract class BaseTool< * Abstract method to execute the tool with the given parameters * Must be implemented by derived classes * @param params Parameters for the tool execution + * @param signal AbortSignal for tool cancellation * @returns Result of the tool execution */ - abstract execute(params: TParams): Promise; + abstract execute(params: TParams, signal: AbortSignal): Promise; } export interface ToolResult { diff --git a/packages/server/src/tools/web-fetch.ts b/packages/server/src/tools/web-fetch.ts index 12584231..62ca2162 100644 --- a/packages/server/src/tools/web-fetch.ts +++ b/packages/server/src/tools/web-fetch.ts @@ -70,7 +70,10 @@ export class WebFetchTool extends BaseTool { return `Fetching content from ${displayUrl}`; } - async execute(params: WebFetchToolParams): Promise { + async execute( + params: WebFetchToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/write-file.ts b/packages/server/src/tools/write-file.ts index c9a47296..1f4c0d94 100644 --- a/packages/server/src/tools/write-file.ts +++ b/packages/server/src/tools/write-file.ts @@ -150,7 +150,10 @@ export class WriteFileTool extends BaseTool { return confirmationDetails; } - async execute(params: WriteFileToolParams): Promise { + async execute( + params: WriteFileToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return {