Enable tools to cancel active execution.
- Plumbed abort signals through to tools - Updated the shell tool to properly cancel active requests by killing the entire child process tree of the underlying shell process and then report that the shell itself was canceled. Fixes https://b.corp.google.com/issues/416829935
This commit is contained in:
parent
090198a7d6
commit
6b518dc9e4
|
@ -26,6 +26,7 @@ interface HandleAtCommandParams {
|
|||
addItem: UseHistoryManagerReturn['addItem'];
|
||||
setDebugMessage: React.Dispatch<React.SetStateAction<string>>;
|
||||
messageId: number;
|
||||
signal: AbortSignal;
|
||||
}
|
||||
|
||||
interface HandleAtCommandResult {
|
||||
|
@ -90,6 +91,7 @@ export async function handleAtCommand({
|
|||
addItem,
|
||||
setDebugMessage,
|
||||
messageId: userMessageTimestamp,
|
||||
signal,
|
||||
}: HandleAtCommandParams): Promise<HandleAtCommandResult> {
|
||||
const trimmedQuery = query.trim();
|
||||
const parsedCommand = parseAtCommand(trimmedQuery);
|
||||
|
@ -163,7 +165,7 @@ export async function handleAtCommand({
|
|||
let toolCallDisplay: IndividualToolCallDisplay;
|
||||
|
||||
try {
|
||||
const result = await readManyFilesTool.execute(toolArgs);
|
||||
const result = await readManyFilesTool.execute(toolArgs, signal);
|
||||
const fileContent = result.llmContent || '';
|
||||
|
||||
toolCallDisplay = {
|
||||
|
|
|
@ -89,7 +89,7 @@ export const useGeminiStream = (
|
|||
}, [config, addItem]);
|
||||
|
||||
useInput((_input, key) => {
|
||||
if (streamingState === StreamingState.Responding && key.escape) {
|
||||
if (streamingState !== StreamingState.Idle && key.escape) {
|
||||
abortControllerRef.current?.abort();
|
||||
}
|
||||
});
|
||||
|
@ -104,6 +104,9 @@ export const useGeminiStream = (
|
|||
|
||||
setShowHelp(false);
|
||||
|
||||
abortControllerRef.current ??= new AbortController();
|
||||
const signal = abortControllerRef.current.signal;
|
||||
|
||||
if (typeof query === 'string') {
|
||||
const trimmedQuery = query.trim();
|
||||
setDebugMessage(`User query: '${trimmedQuery}'`);
|
||||
|
@ -120,6 +123,7 @@ export const useGeminiStream = (
|
|||
addItem,
|
||||
setDebugMessage,
|
||||
messageId: userMessageTimestamp,
|
||||
signal,
|
||||
});
|
||||
if (!atCommandResult.shouldProceed) return;
|
||||
queryToSendToGemini = atCommandResult.processedQuery;
|
||||
|
@ -165,9 +169,6 @@ export const useGeminiStream = (
|
|||
const chat = chatSessionRef.current;
|
||||
|
||||
try {
|
||||
abortControllerRef.current = new AbortController();
|
||||
const signal = abortControllerRef.current.signal;
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
chat,
|
||||
queryToSendToGemini,
|
||||
|
@ -294,7 +295,26 @@ export const useGeminiStream = (
|
|||
} else if (event.type === ServerGeminiEventType.UserCancelled) {
|
||||
// Flush out existing pending history item.
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
// If the pending item is a tool_group, update statuses to Canceled
|
||||
if (pendingHistoryItemRef.current.type === 'tool_group') {
|
||||
const updatedTools = pendingHistoryItemRef.current.tools.map(
|
||||
(tool) => {
|
||||
if (
|
||||
tool.status === ToolCallStatus.Pending ||
|
||||
tool.status === ToolCallStatus.Confirming ||
|
||||
tool.status === ToolCallStatus.Executing
|
||||
) {
|
||||
return { ...tool, status: ToolCallStatus.Canceled };
|
||||
}
|
||||
return tool;
|
||||
},
|
||||
);
|
||||
const pendingHistoryItem = pendingHistoryItemRef.current;
|
||||
pendingHistoryItem.tools = updatedTools;
|
||||
addItem(pendingHistoryItem, userMessageTimestamp);
|
||||
} else {
|
||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
}
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
addItem(
|
||||
|
@ -412,6 +432,59 @@ export const useGeminiStream = (
|
|||
}
|
||||
|
||||
if (outcome === ToolConfirmationOutcome.Cancel) {
|
||||
declineToolExecution(
|
||||
'User rejected function call.',
|
||||
ToolCallStatus.Error,
|
||||
);
|
||||
} else {
|
||||
const tool = toolRegistry.getTool(request.name);
|
||||
if (!tool) {
|
||||
throw new Error(
|
||||
`Tool "${request.name}" not found or is not registered.`,
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
abortControllerRef.current = new AbortController();
|
||||
const result = await tool.execute(
|
||||
request.args,
|
||||
abortControllerRef.current.signal,
|
||||
);
|
||||
|
||||
if (abortControllerRef.current.signal.aborted) {
|
||||
declineToolExecution(
|
||||
result.llmContent,
|
||||
ToolCallStatus.Canceled,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
name: request.name,
|
||||
id: request.callId,
|
||||
response: { output: result.llmContent },
|
||||
},
|
||||
};
|
||||
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: request.callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay: result.returnDisplay,
|
||||
error: undefined,
|
||||
};
|
||||
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
|
||||
setStreamingState(StreamingState.Idle);
|
||||
await submitQuery(functionResponse);
|
||||
} finally {
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
}
|
||||
|
||||
function declineToolExecution(
|
||||
declineMessage: string,
|
||||
status: ToolCallStatus,
|
||||
) {
|
||||
let resultDisplay: ToolResultDisplay | undefined;
|
||||
if ('fileDiff' in originalConfirmationDetails) {
|
||||
resultDisplay = {
|
||||
|
@ -426,43 +499,19 @@ export const useGeminiStream = (
|
|||
functionResponse: {
|
||||
id: request.callId,
|
||||
name: request.name,
|
||||
response: { error: 'User rejected function call.' },
|
||||
response: { error: declineMessage },
|
||||
},
|
||||
};
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: request.callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay,
|
||||
error: new Error('User rejected function call.'),
|
||||
};
|
||||
// Update UI to show cancellation/error
|
||||
updateFunctionResponseUI(responseInfo, ToolCallStatus.Error);
|
||||
setStreamingState(StreamingState.Idle);
|
||||
} else {
|
||||
const tool = toolRegistry.getTool(request.name);
|
||||
if (!tool) {
|
||||
throw new Error(
|
||||
`Tool "${request.name}" not found or is not registered.`,
|
||||
);
|
||||
}
|
||||
const result = await tool.execute(request.args);
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
name: request.name,
|
||||
id: request.callId,
|
||||
response: { output: result.llmContent },
|
||||
},
|
||||
error: new Error(declineMessage),
|
||||
};
|
||||
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: request.callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay: result.returnDisplay,
|
||||
error: undefined,
|
||||
};
|
||||
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
|
||||
// Update UI to show cancellation/error
|
||||
updateFunctionResponseUI(responseInfo, status);
|
||||
setStreamingState(StreamingState.Idle);
|
||||
await submitQuery(functionResponse);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -64,10 +64,13 @@ export class GeminiClient {
|
|||
.getTool('read_many_files') as ReadManyFilesTool;
|
||||
if (readManyFilesTool) {
|
||||
// Read all files in the target directory
|
||||
const result = await readManyFilesTool.execute({
|
||||
paths: ['**/*'], // Read everything recursively
|
||||
useDefaultExcludes: true, // Use default excludes
|
||||
});
|
||||
const result = await readManyFilesTool.execute(
|
||||
{
|
||||
paths: ['**/*'], // Read everything recursively
|
||||
useDefaultExcludes: true, // Use default excludes
|
||||
},
|
||||
AbortSignal.timeout(30000),
|
||||
);
|
||||
if (result.llmContent) {
|
||||
initialParts.push({
|
||||
text: `\n--- Full File Context ---\n${result.llmContent}`,
|
||||
|
|
|
@ -36,7 +36,10 @@ export interface ServerTool {
|
|||
name: string;
|
||||
schema: FunctionDeclaration;
|
||||
// The execute method signature might differ slightly or be wrapped
|
||||
execute(params: Record<string, unknown>): Promise<ToolResult>;
|
||||
execute(
|
||||
params: Record<string, unknown>,
|
||||
signal?: AbortSignal,
|
||||
): Promise<ToolResult>;
|
||||
shouldConfirmExecute(
|
||||
params: Record<string, unknown>,
|
||||
): Promise<ToolCallConfirmationDetails | false>;
|
||||
|
@ -153,7 +156,7 @@ export class Turn {
|
|||
if (confirmationDetails) {
|
||||
return { ...pendingToolCall, confirmationDetails };
|
||||
}
|
||||
const result = await tool.execute(pendingToolCall.args);
|
||||
const result = await tool.execute(pendingToolCall.args, signal);
|
||||
return {
|
||||
...pendingToolCall,
|
||||
result,
|
||||
|
@ -199,7 +202,11 @@ export class Turn {
|
|||
resultDisplay: outcome.result?.returnDisplay,
|
||||
error: outcome.error,
|
||||
};
|
||||
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
|
||||
|
||||
// If aborted we're already yielding the user cancellations elsewhere.
|
||||
if (!signal?.aborted) {
|
||||
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -333,7 +333,10 @@ Expectation for parameters:
|
|||
* @param params Parameters for the edit operation
|
||||
* @returns Result of the edit operation
|
||||
*/
|
||||
async execute(params: EditToolParams): Promise<ToolResult> {
|
||||
async execute(
|
||||
params: EditToolParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
|
|
|
@ -138,7 +138,10 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
|
|||
/**
|
||||
* Executes the glob search with the given parameters
|
||||
*/
|
||||
async execute(params: GlobToolParams): Promise<ToolResult> {
|
||||
async execute(
|
||||
params: GlobToolParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
|
|
|
@ -166,7 +166,10 @@ export class GrepTool extends BaseTool<GrepToolParams, ToolResult> {
|
|||
* @param params Parameters for the grep search
|
||||
* @returns Result of the grep search
|
||||
*/
|
||||
async execute(params: GrepToolParams): Promise<ToolResult> {
|
||||
async execute(
|
||||
params: GrepToolParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
console.error(
|
||||
|
|
|
@ -184,7 +184,10 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
|
|||
* @param params Parameters for the LS operation
|
||||
* @returns Result of the LS operation
|
||||
*/
|
||||
async execute(params: LSToolParams): Promise<ToolResult> {
|
||||
async execute(
|
||||
params: LSToolParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return this.errorResult(
|
||||
|
|
|
@ -193,7 +193,10 @@ export class ReadFileTool extends BaseTool<ReadFileToolParams, ToolResult> {
|
|||
* @param params Parameters for the file reading
|
||||
* @returns Result with file contents
|
||||
*/
|
||||
async execute(params: ReadFileToolParams): Promise<ToolResult> {
|
||||
async execute(
|
||||
params: ReadFileToolParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
|
|
|
@ -237,7 +237,10 @@ Default excludes apply to common non-text files and large dependency directories
|
|||
return `Will attempt to read and concatenate files ${pathDesc}. ${excludeDesc}. File encoding: ${DEFAULT_ENCODING}. Separator: "${DEFAULT_OUTPUT_SEPARATOR_FORMAT.replace('{filePath}', 'path/to/file.ext')}".`;
|
||||
}
|
||||
|
||||
async execute(params: ReadManyFilesParams): Promise<ToolResult> {
|
||||
async execute(
|
||||
params: ReadManyFilesParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
|
|
|
@ -118,7 +118,10 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
|||
return confirmationDetails;
|
||||
}
|
||||
|
||||
async execute(params: ShellToolParams): Promise<ToolResult> {
|
||||
async execute(
|
||||
params: ShellToolParams,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
|
@ -174,18 +177,38 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
|||
});
|
||||
|
||||
let code: number | null = null;
|
||||
let signal: NodeJS.Signals | null = null;
|
||||
shell.on(
|
||||
'close',
|
||||
(_code: number | null, _signal: NodeJS.Signals | null) => {
|
||||
code = _code;
|
||||
signal = _signal;
|
||||
},
|
||||
);
|
||||
let processSignal: NodeJS.Signals | null = null;
|
||||
const closeHandler = (
|
||||
_code: number | null,
|
||||
_signal: NodeJS.Signals | null,
|
||||
) => {
|
||||
code = _code;
|
||||
processSignal = _signal;
|
||||
};
|
||||
shell.on('close', closeHandler);
|
||||
|
||||
const abortHandler = () => {
|
||||
if (shell.pid) {
|
||||
try {
|
||||
// Kill the entire process group
|
||||
process.kill(-shell.pid, 'SIGTERM');
|
||||
} catch (_e) {
|
||||
// Fallback to killing the main process if group kill fails
|
||||
try {
|
||||
shell.kill('SIGKILL'); // or 'SIGTERM'
|
||||
} catch (_killError) {
|
||||
// Ignore errors if the process is already dead
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
abortSignal.addEventListener('abort', abortHandler);
|
||||
|
||||
// wait for the shell to exit
|
||||
await new Promise((resolve) => shell.on('close', resolve));
|
||||
|
||||
abortSignal.removeEventListener('abort', abortHandler);
|
||||
|
||||
// parse pids (pgrep output) from temporary file and remove it
|
||||
const backgroundPIDs: number[] = [];
|
||||
if (fs.existsSync(tempFilePath)) {
|
||||
|
@ -205,19 +228,26 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
|||
}
|
||||
fs.unlinkSync(tempFilePath);
|
||||
} else {
|
||||
console.error('missing pgrep output');
|
||||
if (!abortSignal.aborted) {
|
||||
console.error('missing pgrep output');
|
||||
}
|
||||
}
|
||||
|
||||
const llmContent = [
|
||||
`Command: ${params.command}`,
|
||||
`Directory: ${params.directory || '(root)'}`,
|
||||
`Stdout: ${stdout || '(empty)'}`,
|
||||
`Stderr: ${stderr || '(empty)'}`,
|
||||
`Error: ${error ?? '(none)'}`,
|
||||
`Exit Code: ${code ?? '(none)'}`,
|
||||
`Signal: ${signal ?? '(none)'}`,
|
||||
`Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`,
|
||||
].join('\n');
|
||||
let llmContent = '';
|
||||
if (abortSignal.aborted) {
|
||||
llmContent = 'Command did not complete, it was cancelled by the user';
|
||||
} else {
|
||||
llmContent = [
|
||||
`Command: ${params.command}`,
|
||||
`Directory: ${params.directory || '(root)'}`,
|
||||
`Stdout: ${stdout || '(empty)'}`,
|
||||
`Stderr: ${stderr || '(empty)'}`,
|
||||
`Error: ${error ?? '(none)'}`,
|
||||
`Exit Code: ${code ?? '(none)'}`,
|
||||
`Signal: ${processSignal ?? '(none)'}`,
|
||||
`Background PIDs: ${backgroundPIDs.length ? backgroundPIDs.join(', ') : '(none)'}`,
|
||||
].join('\n');
|
||||
}
|
||||
|
||||
const returnDisplay = this.config.getDebugMode() ? llmContent : output;
|
||||
|
||||
|
|
|
@ -265,7 +265,10 @@ Use this tool for running build steps (\`npm install\`, \`make\`), linters (\`es
|
|||
return confirmationDetails;
|
||||
}
|
||||
|
||||
async execute(params: TerminalToolParams): Promise<ToolResult> {
|
||||
async execute(
|
||||
params: TerminalToolParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateToolParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
|
|
|
@ -64,7 +64,7 @@ export interface Tool<
|
|||
* @param params Parameters for the tool execution
|
||||
* @returns Result of the tool execution
|
||||
*/
|
||||
execute(params: TParams): Promise<TResult>;
|
||||
execute(params: TParams, signal: AbortSignal): Promise<TResult>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -141,9 +141,10 @@ export abstract class BaseTool<
|
|||
* Abstract method to execute the tool with the given parameters
|
||||
* Must be implemented by derived classes
|
||||
* @param params Parameters for the tool execution
|
||||
* @param signal AbortSignal for tool cancellation
|
||||
* @returns Result of the tool execution
|
||||
*/
|
||||
abstract execute(params: TParams): Promise<TResult>;
|
||||
abstract execute(params: TParams, signal: AbortSignal): Promise<TResult>;
|
||||
}
|
||||
|
||||
export interface ToolResult {
|
||||
|
|
|
@ -70,7 +70,10 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
|
|||
return `Fetching content from ${displayUrl}`;
|
||||
}
|
||||
|
||||
async execute(params: WebFetchToolParams): Promise<ToolResult> {
|
||||
async execute(
|
||||
params: WebFetchToolParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
|
|
|
@ -150,7 +150,10 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
|||
return confirmationDetails;
|
||||
}
|
||||
|
||||
async execute(params: WriteFileToolParams): Promise<ToolResult> {
|
||||
async execute(
|
||||
params: WriteFileToolParams,
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolResult> {
|
||||
const validationError = this.validateParams(params);
|
||||
if (validationError) {
|
||||
return {
|
||||
|
|
Loading…
Reference in New Issue