Enable tools to cancel active execution.

- Plumbed abort signals through to tools
- Updated the shell tool to properly cancel active requests by killing the entire child process tree of the underlying shell process and then report that the shell itself was canceled.

Fixes https://b.corp.google.com/issues/416829935
This commit is contained in:
Taylor Mullen 2025-05-09 23:29:02 -07:00 committed by N. Taylor Mullen
parent 090198a7d6
commit 6b518dc9e4
15 changed files with 191 additions and 72 deletions

View File

@ -26,6 +26,7 @@ interface HandleAtCommandParams {
addItem: UseHistoryManagerReturn['addItem'];
setDebugMessage: React.Dispatch<React.SetStateAction<string>>;
messageId: number;
signal: AbortSignal;
}
interface HandleAtCommandResult {
@ -90,6 +91,7 @@ export async function handleAtCommand({
addItem,
setDebugMessage,
messageId: userMessageTimestamp,
signal,
}: HandleAtCommandParams): Promise<HandleAtCommandResult> {
const trimmedQuery = query.trim();
const parsedCommand = parseAtCommand(trimmedQuery);
@ -163,7 +165,7 @@ export async function handleAtCommand({
let toolCallDisplay: IndividualToolCallDisplay;
try {
const result = await readManyFilesTool.execute(toolArgs);
const result = await readManyFilesTool.execute(toolArgs, signal);
const fileContent = result.llmContent || '';
toolCallDisplay = {

View File

@ -89,7 +89,7 @@ export const useGeminiStream = (
}, [config, addItem]);
useInput((_input, key) => {
if (streamingState === StreamingState.Responding && key.escape) {
if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort();
}
});
@ -104,6 +104,9 @@ export const useGeminiStream = (
setShowHelp(false);
abortControllerRef.current ??= new AbortController();
const signal = abortControllerRef.current.signal;
if (typeof query === 'string') {
const trimmedQuery = query.trim();
setDebugMessage(`User query: '${trimmedQuery}'`);
@ -120,6 +123,7 @@ export const useGeminiStream = (
addItem,
setDebugMessage,
messageId: userMessageTimestamp,
signal,
});
if (!atCommandResult.shouldProceed) return;
queryToSendToGemini = atCommandResult.processedQuery;
@ -165,9 +169,6 @@ export const useGeminiStream = (
const chat = chatSessionRef.current;
try {
abortControllerRef.current = new AbortController();
const signal = abortControllerRef.current.signal;
const stream = client.sendMessageStream(
chat,
queryToSendToGemini,
@ -294,7 +295,26 @@ export const useGeminiStream = (
} else if (event.type === ServerGeminiEventType.UserCancelled) {
// Flush out existing pending history item.
if (pendingHistoryItemRef.current) {
// If the pending item is a tool_group, update statuses to Canceled
if (pendingHistoryItemRef.current.type === 'tool_group') {
const updatedTools = pendingHistoryItemRef.current.tools.map(
(tool) => {
if (
tool.status === ToolCallStatus.Pending ||
tool.status === ToolCallStatus.Confirming ||
tool.status === ToolCallStatus.Executing
) {
return { ...tool, status: ToolCallStatus.Canceled };
}
return tool;
},
);
const pendingHistoryItem = pendingHistoryItemRef.current;
pendingHistoryItem.tools = updatedTools;
addItem(pendingHistoryItem, userMessageTimestamp);
} else {
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
}
setPendingHistoryItem(null);
}
addItem(
@ -412,32 +432,10 @@ export const useGeminiStream = (
}
if (outcome === ToolConfirmationOutcome.Cancel) {
let resultDisplay: ToolResultDisplay | undefined;
if ('fileDiff' in originalConfirmationDetails) {
resultDisplay = {
fileDiff: (
originalConfirmationDetails as ToolEditConfirmationDetails
).fileDiff,
};
} else {
resultDisplay = `~~${(originalConfirmationDetails as ToolExecuteConfirmationDetails).command}~~`;
}
const functionResponse: Part = {
functionResponse: {
id: request.callId,
name: request.name,
response: { error: 'User rejected function call.' },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay,
error: new Error('User rejected function call.'),
};
// Update UI to show cancellation/error
updateFunctionResponseUI(responseInfo, ToolCallStatus.Error);
setStreamingState(StreamingState.Idle);
declineToolExecution(
'User rejected function call.',
ToolCallStatus.Error,
);
} else {
const tool = toolRegistry.getTool(request.name);
if (!tool) {
@ -445,7 +443,22 @@ export const useGeminiStream = (
`Tool "${request.name}" not found or is not registered.`,
);
}
const result = await tool.execute(request.args);
try {
abortControllerRef.current = new AbortController();
const result = await tool.execute(
request.args,
abortControllerRef.current.signal,
);
if (abortControllerRef.current.signal.aborted) {
declineToolExecution(
result.llmContent,
ToolCallStatus.Canceled,
);
return;
}
const functionResponse: Part = {
functionResponse: {
name: request.name,
@ -463,6 +476,42 @@ export const useGeminiStream = (
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
setStreamingState(StreamingState.Idle);
await submitQuery(functionResponse);
} finally {
abortControllerRef.current = null;
}
}
function declineToolExecution(
declineMessage: string,
status: ToolCallStatus,
) {
let resultDisplay: ToolResultDisplay | undefined;
if ('fileDiff' in originalConfirmationDetails) {
resultDisplay = {
fileDiff: (
originalConfirmationDetails as ToolEditConfirmationDetails
).fileDiff,
};
} else {
resultDisplay = `~~${(originalConfirmationDetails as ToolExecuteConfirmationDetails).command}~~`;
}
const functionResponse: Part = {
functionResponse: {
id: request.callId,
name: request.name,
response: { error: declineMessage },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay,
error: new Error(declineMessage),
};
// Update UI to show cancellation/error
updateFunctionResponseUI(responseInfo, status);
setStreamingState(StreamingState.Idle);
}
};

View File

@ -64,10 +64,13 @@ export class GeminiClient {
.getTool('read_many_files') as ReadManyFilesTool;
if (readManyFilesTool) {
// Read all files in the target directory
const result = await readManyFilesTool.execute({
const result = await readManyFilesTool.execute(
{
paths: ['**/*'], // Read everything recursively
useDefaultExcludes: true, // Use default excludes
});
},
AbortSignal.timeout(30000),
);
if (result.llmContent) {
initialParts.push({
text: `\n--- Full File Context ---\n${result.llmContent}`,

View File

@ -36,7 +36,10 @@ export interface ServerTool {
name: string;
schema: FunctionDeclaration;
// The execute method signature might differ slightly or be wrapped
execute(params: Record<string, unknown>): Promise<ToolResult>;
execute(
params: Record<string, unknown>,
signal?: AbortSignal,
): Promise<ToolResult>;
shouldConfirmExecute(
params: Record<string, unknown>,
): Promise<ToolCallConfirmationDetails | false>;
@ -153,7 +156,7 @@ export class Turn {
if (confirmationDetails) {
return { ...pendingToolCall, confirmationDetails };
}
const result = await tool.execute(pendingToolCall.args);
const result = await tool.execute(pendingToolCall.args, signal);
return {
...pendingToolCall,
result,
@ -199,9 +202,13 @@ export class Turn {
resultDisplay: outcome.result?.returnDisplay,
error: outcome.error,
};
// If aborted we're already yielding the user cancellations elsewhere.
if (!signal?.aborted) {
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
}
}
}
private handlePendingFunctionCall(
fnCall: FunctionCall,

View File

@ -333,7 +333,10 @@ Expectation for parameters:
* @param params Parameters for the edit operation
* @returns Result of the edit operation
*/
async execute(params: EditToolParams): Promise<ToolResult> {
async execute(
params: EditToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {

View File

@ -138,7 +138,10 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
/**
* Executes the glob search with the given parameters
*/
async execute(params: GlobToolParams): Promise<ToolResult> {
async execute(
params: GlobToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {

View File

@ -166,7 +166,10 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
* @param params Parameters for the grep search
* @returns Result of the grep search
*/
async execute(params: GrepToolParams): Promise<ToolResult> {
async execute(
params: GrepToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
console.error(

View File

@ -184,7 +184,10 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
* @param params Parameters for the LS operation
* @returns Result of the LS operation
*/
async execute(params: LSToolParams): Promise<ToolResult> {
async execute(
params: LSToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return this.errorResult(

View File

@ -193,7 +193,10 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
* @param params Parameters for the file reading
* @returns Result with file contents
*/
async execute(params: ReadFileToolParams): Promise<ToolResult> {
async execute(
params: ReadFileToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {

View File

@ -237,7 +237,10 @@ Default excludes apply to common non-text files and large dependency directories
return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace('{filePath}', 'path/to/file.ext')}".`;
}
async execute(params: ReadManyFilesParams): Promise<ToolResult> {
async execute(
params: ReadManyFilesParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {

View File

@ -118,7 +118,10 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
return confirmationDetails;
}
async execute(params: ShellToolParams): Promise<ToolResult> {
async execute(
params: ShellToolParams,
abortSignal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {
@ -174,18 +177,38 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
});
let code: number | null = null;
let signal: NodeJS.Signals | null = null;
shell.on(
'close',
(_code: number | null, _signal: NodeJS.Signals | null) => {
let processSignal: NodeJS.Signals | null = null;
const closeHandler = (
_code: number | null,
_signal: NodeJS.Signals | null,
) => {
code = _code;
signal = _signal;
},
);
processSignal = _signal;
};
shell.on('close', closeHandler);
const abortHandler = () => {
if (shell.pid) {
try {
// Kill the entire process group
process.kill(-shell.pid, 'SIGTERM');
} catch (_e) {
// Fallback to killing the main process if group kill fails
try {
shell.kill('SIGKILL'); // or 'SIGTERM'
} catch (_killError) {
// Ignore errors if the process is already dead
}
}
}
};
abortSignal.addEventListener('abort', abortHandler);
// wait for the shell to exit
await new Promise((resolve) => shell.on('close', resolve));
abortSignal.removeEventListener('abort', abortHandler);
// parse pids (pgrep output) from temporary file and remove it
const backgroundPIDs: number[] = [];
if (fs.existsSync(tempFilePath)) {
@ -205,19 +228,26 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
}
fs.unlinkSync(tempFilePath);
} else {
if (!abortSignal.aborted) {
console.error('missing pgrep output');
}
}
const llmContent = [
let llmContent = '';
if (abortSignal.aborted) {
llmContent = 'Command did not complete, it was cancelled by the user';
} else {
llmContent = [
`Command: ${params.command}`,
`Directory: ${params.directory || '(root)'}`,
`Stdout: ${stdout || '(empty)'}`,
`Stderr: ${stderr || '(empty)'}`,
`Error: ${error ?? '(none)'}`,
`Exit Code: ${code ?? '(none)'}`,
`Signal: ${signal ?? '(none)'}`,
`Signal: ${processSignal ?? '(none)'}`,
`Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`,
].join('\n');
}
const returnDisplay = this.config.getDebugMode() ? llmContent : output;

View File

@ -265,7 +265,10 @@ Use this tool for running build steps (\`npm install\`, \`make\`), linters (\`es
return confirmationDetails;
}
async execute(params: TerminalToolParams): Promise<ToolResult> {
async execute(
params: TerminalToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return {

View File

@ -64,7 +64,7 @@ export interface Tool<
* @param params Parameters for the tool execution
* @returns Result of the tool execution
*/
execute(params: TParams): Promise<TResult>;
execute(params: TParams, signal: AbortSignal): Promise<TResult>;
}
/**
@ -141,9 +141,10 @@ export abstract class BaseTool<
* Abstract method to execute the tool with the given parameters
* Must be implemented by derived classes
* @param params Parameters for the tool execution
* @param signal AbortSignal for tool cancellation
* @returns Result of the tool execution
*/
abstract execute(params: TParams): Promise<TResult>;
abstract execute(params: TParams, signal: AbortSignal): Promise<TResult>;
}
export interface ToolResult {

View File

@ -70,7 +70,10 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
return `Fetching content from ${displayUrl}`;
}
async execute(params: WebFetchToolParams): Promise<ToolResult> {
async execute(
params: WebFetchToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {

View File

@ -150,7 +150,10 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
return confirmationDetails;
}
async execute(params: WriteFileToolParams): Promise<ToolResult> {
async execute(
params: WriteFileToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateParams(params);
if (validationError) {
return {