diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts index 5ffa5383..a13a7d36 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts @@ -26,6 +26,7 @@ interface HandleAtCommandParams { addItem: UseHistoryManagerReturn['addItem']; setDebugMessage: React.Dispatch>; messageId: number; + signal: AbortSignal; } interface HandleAtCommandResult { @@ -90,6 +91,7 @@ export async function handleAtCommand({ addItem, setDebugMessage, messageId: userMessageTimestamp, + signal, }: HandleAtCommandParams): Promise { const trimmedQuery = query.trim(); const parsedCommand = parseAtCommand(trimmedQuery); @@ -163,7 +165,7 @@ export async function handleAtCommand({ let toolCallDisplay: IndividualToolCallDisplay; try { - const result = await readManyFilesTool.execute(toolArgs); + const result = await readManyFilesTool.execute(toolArgs, signal); const fileContent = result.llmContent || ''; toolCallDisplay = { diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 3f8cee40..e86ae0b9 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -89,7 +89,7 @@ export const useGeminiStream = ( }, [config, addItem]); useInput((_input, key) => { - if (streamingState === StreamingState.Responding && key.escape) { + if (streamingState !== StreamingState.Idle && key.escape) { abortControllerRef.current?.abort(); } }); @@ -104,6 +104,9 @@ export const useGeminiStream = ( setShowHelp(false); + abortControllerRef.current ??= new AbortController(); + const signal = abortControllerRef.current.signal; + if (typeof query === 'string') { const trimmedQuery = query.trim(); setDebugMessage(`User query: '${trimmedQuery}'`); @@ -120,6 +123,7 @@ export const useGeminiStream = ( addItem, setDebugMessage, messageId: userMessageTimestamp, + signal, }); if (!atCommandResult.shouldProceed) return; queryToSendToGemini = atCommandResult.processedQuery; @@ -165,9 +169,6 @@ export const useGeminiStream = ( const chat = chatSessionRef.current; try { - abortControllerRef.current = new AbortController(); - const signal = abortControllerRef.current.signal; - const stream = client.sendMessageStream( chat, queryToSendToGemini, @@ -294,7 +295,26 @@ export const useGeminiStream = ( } else if (event.type === ServerGeminiEventType.UserCancelled) { // Flush out existing pending history item. if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); + // If the pending item is a tool_group, update statuses to Canceled + if (pendingHistoryItemRef.current.type === 'tool_group') { + const updatedTools = pendingHistoryItemRef.current.tools.map( + (tool) => { + if ( + tool.status === ToolCallStatus.Pending || + tool.status === ToolCallStatus.Confirming || + tool.status === ToolCallStatus.Executing + ) { + return { ...tool, status: ToolCallStatus.Canceled }; + } + return tool; + }, + ); + const pendingHistoryItem = pendingHistoryItemRef.current; + pendingHistoryItem.tools = updatedTools; + addItem(pendingHistoryItem, userMessageTimestamp); + } else { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } setPendingHistoryItem(null); } addItem( @@ -412,6 +432,59 @@ export const useGeminiStream = ( } if (outcome === ToolConfirmationOutcome.Cancel) { + declineToolExecution( + 'User rejected function call.', + ToolCallStatus.Error, + ); + } else { + const tool = toolRegistry.getTool(request.name); + if (!tool) { + throw new Error( + `Tool "${request.name}" not found or is not registered.`, + ); + } + + try { + abortControllerRef.current = new AbortController(); + const result = await tool.execute( + request.args, + abortControllerRef.current.signal, + ); + + if (abortControllerRef.current.signal.aborted) { + declineToolExecution( + result.llmContent, + ToolCallStatus.Canceled, + ); + return; + } + + const functionResponse: Part = { + functionResponse: { + name: request.name, + id: request.callId, + response: { output: result.llmContent }, + }, + }; + + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay: result.returnDisplay, + error: undefined, + }; + updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + setStreamingState(StreamingState.Idle); + await submitQuery(functionResponse); + } finally { + abortControllerRef.current = null; + } + } + + function declineToolExecution( + declineMessage: string, + status: ToolCallStatus, + ) { let resultDisplay: ToolResultDisplay | undefined; if ('fileDiff' in originalConfirmationDetails) { resultDisplay = { @@ -426,43 +499,19 @@ export const useGeminiStream = ( functionResponse: { id: request.callId, name: request.name, - response: { error: 'User rejected function call.' }, + response: { error: declineMessage }, }, }; const responseInfo: ToolCallResponseInfo = { callId: request.callId, responsePart: functionResponse, resultDisplay, - error: new Error('User rejected function call.'), - }; - // Update UI to show cancellation/error - updateFunctionResponseUI(responseInfo, ToolCallStatus.Error); - setStreamingState(StreamingState.Idle); - } else { - const tool = toolRegistry.getTool(request.name); - if (!tool) { - throw new Error( - `Tool "${request.name}" not found or is not registered.`, - ); - } - const result = await tool.execute(request.args); - const functionResponse: Part = { - functionResponse: { - name: request.name, - id: request.callId, - response: { output: result.llmContent }, - }, + error: new Error(declineMessage), }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + // Update UI to show cancellation/error + updateFunctionResponseUI(responseInfo, status); setStreamingState(StreamingState.Idle); - await submitQuery(functionResponse); } }; diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts index 904e944c..46af465a 100644 --- a/packages/server/src/core/client.ts +++ b/packages/server/src/core/client.ts @@ -64,10 +64,13 @@ export class GeminiClient { .getTool('read_many_files') as ReadManyFilesTool; if (readManyFilesTool) { // Read all files in the target directory - const result = await readManyFilesTool.execute({ - paths: ['**/*'], // Read everything recursively - useDefaultExcludes: true, // Use default excludes - }); + const result = await readManyFilesTool.execute( + { + paths: ['**/*'], // Read everything recursively + useDefaultExcludes: true, // Use default excludes + }, + AbortSignal.timeout(30000), + ); if (result.llmContent) { initialParts.push({ text: `\n--- Full File Context ---\n${result.llmContent}`, diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index 7d8bf7b6..62219938 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -36,7 +36,10 @@ export interface ServerTool { name: string; schema: FunctionDeclaration; // The execute method signature might differ slightly or be wrapped - execute(params: Record): Promise; + execute( + params: Record, + signal?: AbortSignal, + ): Promise; shouldConfirmExecute( params: Record, ): Promise; @@ -153,7 +156,7 @@ export class Turn { if (confirmationDetails) { return { ...pendingToolCall, confirmationDetails }; } - const result = await tool.execute(pendingToolCall.args); + const result = await tool.execute(pendingToolCall.args, signal); return { ...pendingToolCall, result, @@ -199,7 +202,11 @@ export class Turn { resultDisplay: outcome.result?.returnDisplay, error: outcome.error, }; - yield { type: GeminiEventType.ToolCallResponse, value: responseInfo }; + + // If aborted we're already yielding the user cancellations elsewhere. + if (!signal?.aborted) { + yield { type: GeminiEventType.ToolCallResponse, value: responseInfo }; + } } } diff --git a/packages/server/src/tools/edit.ts b/packages/server/src/tools/edit.ts index c40b9e44..fd57d97d 100644 --- a/packages/server/src/tools/edit.ts +++ b/packages/server/src/tools/edit.ts @@ -333,7 +333,10 @@ Expectation for parameters: * @param params Parameters for the edit operation * @returns Result of the edit operation */ - async execute(params: EditToolParams): Promise { + async execute( + params: EditToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/glob.ts b/packages/server/src/tools/glob.ts index 9e7df0e8..b1b9d0cf 100644 --- a/packages/server/src/tools/glob.ts +++ b/packages/server/src/tools/glob.ts @@ -138,7 +138,10 @@ export class GlobTool extends BaseTool { /** * Executes the glob search with the given parameters */ - async execute(params: GlobToolParams): Promise { + async execute( + params: GlobToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/grep.ts b/packages/server/src/tools/grep.ts index e3253ecf..54391832 100644 --- a/packages/server/src/tools/grep.ts +++ b/packages/server/src/tools/grep.ts @@ -166,7 +166,10 @@ export class GrepTool extends BaseTool { * @param params Parameters for the grep search * @returns Result of the grep search */ - async execute(params: GrepToolParams): Promise { + async execute( + params: GrepToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { console.error( diff --git a/packages/server/src/tools/ls.ts b/packages/server/src/tools/ls.ts index 01da5121..fea95187 100644 --- a/packages/server/src/tools/ls.ts +++ b/packages/server/src/tools/ls.ts @@ -184,7 +184,10 @@ export class LSTool extends BaseTool { * @param params Parameters for the LS operation * @returns Result of the LS operation */ - async execute(params: LSToolParams): Promise { + async execute( + params: LSToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return this.errorResult( diff --git a/packages/server/src/tools/read-file.ts b/packages/server/src/tools/read-file.ts index 598b4691..de09161d 100644 --- a/packages/server/src/tools/read-file.ts +++ b/packages/server/src/tools/read-file.ts @@ -193,7 +193,10 @@ export class ReadFileTool extends BaseTool { * @param params Parameters for the file reading * @returns Result with file contents */ - async execute(params: ReadFileToolParams): Promise { + async execute( + params: ReadFileToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/read-many-files.ts b/packages/server/src/tools/read-many-files.ts index 0b4b090d..44882e44 100644 --- a/packages/server/src/tools/read-many-files.ts +++ b/packages/server/src/tools/read-many-files.ts @@ -237,7 +237,10 @@ Default excludes apply to common non-text files and large dependency directories return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace('{filePath}', 'path/to/file.ext')}".`; } - async execute(params: ReadManyFilesParams): Promise { + async execute( + params: ReadManyFilesParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/shell.ts b/packages/server/src/tools/shell.ts index fd8a6b1a..7851b76a 100644 --- a/packages/server/src/tools/shell.ts +++ b/packages/server/src/tools/shell.ts @@ -118,7 +118,10 @@ export class ShellTool extends BaseTool { return confirmationDetails; } - async execute(params: ShellToolParams): Promise { + async execute( + params: ShellToolParams, + abortSignal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { @@ -174,18 +177,38 @@ export class ShellTool extends BaseTool { }); let code: number | null = null; - let signal: NodeJS.Signals | null = null; - shell.on( - 'close', - (_code: number | null, _signal: NodeJS.Signals | null) => { - code = _code; - signal = _signal; - }, - ); + let processSignal: NodeJS.Signals | null = null; + const closeHandler = ( + _code: number | null, + _signal: NodeJS.Signals | null, + ) => { + code = _code; + processSignal = _signal; + }; + shell.on('close', closeHandler); + + const abortHandler = () => { + if (shell.pid) { + try { + // Kill the entire process group + process.kill(-shell.pid, 'SIGTERM'); + } catch (_e) { + // Fallback to killing the main process if group kill fails + try { + shell.kill('SIGKILL'); // or 'SIGTERM' + } catch (_killError) { + // Ignore errors if the process is already dead + } + } + } + }; + abortSignal.addEventListener('abort', abortHandler); // wait for the shell to exit await new Promise((resolve) => shell.on('close', resolve)); + abortSignal.removeEventListener('abort', abortHandler); + // parse pids (pgrep output) from temporary file and remove it const backgroundPIDs: number[] = []; if (fs.existsSync(tempFilePath)) { @@ -205,19 +228,26 @@ export class ShellTool extends BaseTool { } fs.unlinkSync(tempFilePath); } else { - console.error('missing pgrep output'); + if (!abortSignal.aborted) { + console.error('missing pgrep output'); + } } - const llmContent = [ - `Command: ${params.command}`, - `Directory: ${params.directory || '(root)'}`, - `Stdout: ${stdout || '(empty)'}`, - `Stderr: ${stderr || '(empty)'}`, - `Error: ${error ?? '(none)'}`, - `Exit Code: ${code ?? '(none)'}`, - `Signal: ${signal ?? '(none)'}`, - `Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`, - ].join('\n'); + let llmContent = ''; + if (abortSignal.aborted) { + llmContent = 'Command did not complete, it was cancelled by the user'; + } else { + llmContent = [ + `Command: ${params.command}`, + `Directory: ${params.directory || '(root)'}`, + `Stdout: ${stdout || '(empty)'}`, + `Stderr: ${stderr || '(empty)'}`, + `Error: ${error ?? '(none)'}`, + `Exit Code: ${code ?? '(none)'}`, + `Signal: ${processSignal ?? '(none)'}`, + `Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`, + ].join('\n'); + } const returnDisplay = this.config.getDebugMode() ? llmContent : output; diff --git a/packages/server/src/tools/terminal.ts b/packages/server/src/tools/terminal.ts index 7320cfb2..af558fb0 100644 --- a/packages/server/src/tools/terminal.ts +++ b/packages/server/src/tools/terminal.ts @@ -265,7 +265,10 @@ Use this tool for running build steps (\`npm install\`, \`make\`), linters (\`es return confirmationDetails; } - async execute(params: TerminalToolParams): Promise { + async execute( + params: TerminalToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateToolParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/tools.ts b/packages/server/src/tools/tools.ts index ac04450d..7bb05a95 100644 --- a/packages/server/src/tools/tools.ts +++ b/packages/server/src/tools/tools.ts @@ -64,7 +64,7 @@ export interface Tool< * @param params Parameters for the tool execution * @returns Result of the tool execution */ - execute(params: TParams): Promise; + execute(params: TParams, signal: AbortSignal): Promise; } /** @@ -141,9 +141,10 @@ export abstract class BaseTool< * Abstract method to execute the tool with the given parameters * Must be implemented by derived classes * @param params Parameters for the tool execution + * @param signal AbortSignal for tool cancellation * @returns Result of the tool execution */ - abstract execute(params: TParams): Promise; + abstract execute(params: TParams, signal: AbortSignal): Promise; } export interface ToolResult { diff --git a/packages/server/src/tools/web-fetch.ts b/packages/server/src/tools/web-fetch.ts index 12584231..62ca2162 100644 --- a/packages/server/src/tools/web-fetch.ts +++ b/packages/server/src/tools/web-fetch.ts @@ -70,7 +70,10 @@ export class WebFetchTool extends BaseTool { return `Fetching content from ${displayUrl}`; } - async execute(params: WebFetchToolParams): Promise { + async execute( + params: WebFetchToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return { diff --git a/packages/server/src/tools/write-file.ts b/packages/server/src/tools/write-file.ts index c9a47296..1f4c0d94 100644 --- a/packages/server/src/tools/write-file.ts +++ b/packages/server/src/tools/write-file.ts @@ -150,7 +150,10 @@ export class WriteFileTool extends BaseTool { return confirmationDetails; } - async execute(params: WriteFileToolParams): Promise { + async execute( + params: WriteFileToolParams, + _signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return {