fix(tool-scheduler): Correctly pipe cancellation signal to tool calls (#852)

This commit is contained in:
N. Taylor Mullen 2025-06-08 15:42:49 -07:00 committed by GitHub
parent 7868ef8229
commit f2ea78d0e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 235 additions and 209 deletions

View File

@ -18,6 +18,7 @@ import {
import { Config } from '@gemini-cli/core'; import { Config } from '@gemini-cli/core';
import { Part, PartListUnion } from '@google/genai'; import { Part, PartListUnion } from '@google/genai';
import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { Dispatch, SetStateAction } from 'react';
// --- MOCKS --- // --- MOCKS ---
const mockSendMessageStream = vi const mockSendMessageStream = vi
@ -309,16 +310,41 @@ describe('useGeminiStream', () => {
const client = geminiClient || mockConfig.getGeminiClient(); const client = geminiClient || mockConfig.getGeminiClient();
const { result, rerender } = renderHook(() => const { result, rerender } = renderHook(
(props: {
client: any;
addItem: UseHistoryManagerReturn['addItem'];
setShowHelp: Dispatch<SetStateAction<boolean>>;
config: Config;
onDebugMessage: (message: string) => void;
handleSlashCommand: (
command: PartListUnion,
) =>
| import('./slashCommandProcessor.js').SlashCommandActionReturn
| boolean;
shellModeActive: boolean;
}) =>
useGeminiStream( useGeminiStream(
client, props.client,
mockAddItem as unknown as UseHistoryManagerReturn['addItem'], props.addItem,
mockSetShowHelp, props.setShowHelp,
mockConfig, props.config,
mockOnDebugMessage, props.onDebugMessage,
mockHandleSlashCommand, props.handleSlashCommand,
false, // shellModeActive props.shellModeActive,
), ),
{
initialProps: {
client,
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
setShowHelp: mockSetShowHelp,
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand:
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
shellModeActive: false,
},
},
); );
return { return {
result, result,
@ -326,7 +352,6 @@ describe('useGeminiStream', () => {
mockMarkToolsAsSubmitted, mockMarkToolsAsSubmitted,
mockSendMessageStream, mockSendMessageStream,
client, client,
// mockFilter removed
}; };
}; };
@ -423,24 +448,29 @@ describe('useGeminiStream', () => {
} as TrackedCancelledToolCall, } as TrackedCancelledToolCall,
]; ];
const hookResult = await act(async () =>
renderTestHook(simplifiedToolCalls),
);
const { const {
rerender,
mockMarkToolsAsSubmitted, mockMarkToolsAsSubmitted,
mockSendMessageStream: localMockSendMessageStream, mockSendMessageStream: localMockSendMessageStream,
} = hookResult!; client,
} = renderTestHook(simplifiedToolCalls);
// It seems the initial render + effect run should be enough. act(() => {
// If rerender was for a specific state change, it might still be needed. rerender({
// For now, let's test if the initial effect run (covered by the first act) is sufficient. client,
// If not, we can add back: await act(async () => { rerender({}); }); addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
setShowHelp: mockSetShowHelp,
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['call1', 'call2']); config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand:
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
shellModeActive: false,
});
});
await waitFor(() => { await waitFor(() => {
expect(localMockSendMessageStream).toHaveBeenCalledTimes(1); expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
expect(localMockSendMessageStream).toHaveBeenCalledTimes(0);
}); });
const expectedMergedResponse = mergePartListUnions([ const expectedMergedResponse = mergePartListUnions([
@ -479,12 +509,21 @@ describe('useGeminiStream', () => {
client, client,
); );
await act(async () => { act(() => {
rerender({} as any); rerender({
client,
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
setShowHelp: mockSetShowHelp,
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand:
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
shellModeActive: false,
});
}); });
await waitFor(() => { await waitFor(() => {
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']); expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
expect(client.addHistory).toHaveBeenCalledTimes(2); expect(client.addHistory).toHaveBeenCalledTimes(2);
expect(client.addHistory).toHaveBeenCalledWith({ expect(client.addHistory).toHaveBeenCalledWith({
role: 'user', role: 'user',

View File

@ -83,12 +83,8 @@ export const useGeminiStream = (
useStateAndRef<HistoryItemWithoutId | null>(null); useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger(); const logger = useLogger();
const [ const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
toolCalls, useReactToolScheduler(
scheduleToolCalls,
cancelAllToolCalls,
markToolsAsSubmitted,
] = useReactToolScheduler(
(completedToolCallsFromScheduler) => { (completedToolCallsFromScheduler) => {
// This onComplete is called when ALL scheduled tools for a given batch are done. // This onComplete is called when ALL scheduled tools for a given batch are done.
if (completedToolCallsFromScheduler.length > 0) { if (completedToolCallsFromScheduler.length > 0) {
@ -143,10 +139,15 @@ export const useGeminiStream = (
return StreamingState.Idle; return StreamingState.Idle;
}, [isResponding, toolCalls]); }, [isResponding, toolCalls]);
useEffect(() => {
if (streamingState === StreamingState.Idle) {
abortControllerRef.current = null;
}
}, [streamingState]);
useInput((_input, key) => { useInput((_input, key) => {
if (streamingState !== StreamingState.Idle && key.escape) { if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort(); abortControllerRef.current?.abort();
cancelAllToolCalls(); // Also cancel any pending/executing tool calls
} }
}); });
@ -191,7 +192,7 @@ export const useGeminiStream = (
name: toolName, name: toolName,
args: toolArgs, args: toolArgs,
}; };
scheduleToolCalls([toolCallRequest]); scheduleToolCalls([toolCallRequest], abortSignal);
} }
return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool
} }
@ -330,9 +331,8 @@ export const useGeminiStream = (
userMessageTimestamp, userMessageTimestamp,
); );
setIsResponding(false); setIsResponding(false);
cancelAllToolCalls();
}, },
[addItem, pendingHistoryItemRef, setPendingHistoryItem, cancelAllToolCalls], [addItem, pendingHistoryItemRef, setPendingHistoryItem],
); );
const handleErrorEvent = useCallback( const handleErrorEvent = useCallback(
@ -365,6 +365,7 @@ export const useGeminiStream = (
async ( async (
stream: AsyncIterable<GeminiEvent>, stream: AsyncIterable<GeminiEvent>,
userMessageTimestamp: number, userMessageTimestamp: number,
signal: AbortSignal,
): Promise<StreamProcessingStatus> => { ): Promise<StreamProcessingStatus> => {
let geminiMessageBuffer = ''; let geminiMessageBuffer = '';
const toolCallRequests: ToolCallRequestInfo[] = []; const toolCallRequests: ToolCallRequestInfo[] = [];
@ -401,7 +402,7 @@ export const useGeminiStream = (
} }
} }
if (toolCallRequests.length > 0) { if (toolCallRequests.length > 0) {
scheduleToolCalls(toolCallRequests); scheduleToolCalls(toolCallRequests, signal);
} }
return StreamProcessingStatus.Completed; return StreamProcessingStatus.Completed;
}, },
@ -453,6 +454,7 @@ export const useGeminiStream = (
const processingStatus = await processGeminiStreamEvents( const processingStatus = await processGeminiStreamEvents(
stream, stream,
userMessageTimestamp, userMessageTimestamp,
abortSignal,
); );
if (processingStatus === StreamProcessingStatus.UserCancelled) { if (processingStatus === StreamProcessingStatus.UserCancelled) {
@ -476,7 +478,6 @@ export const useGeminiStream = (
); );
} }
} finally { } finally {
abortControllerRef.current = null; // Always reset
setIsResponding(false); setIsResponding(false);
} }
}, },

View File

@ -32,8 +32,8 @@ import {
export type ScheduleFn = ( export type ScheduleFn = (
request: ToolCallRequestInfo | ToolCallRequestInfo[], request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal,
) => void; ) => void;
export type CancelFn = (reason?: string) => void;
export type MarkToolsAsSubmittedFn = (callIds: string[]) => void; export type MarkToolsAsSubmittedFn = (callIds: string[]) => void;
export type TrackedScheduledToolCall = ScheduledToolCall & { export type TrackedScheduledToolCall = ScheduledToolCall & {
@ -69,7 +69,7 @@ export function useReactToolScheduler(
setPendingHistoryItem: React.Dispatch< setPendingHistoryItem: React.Dispatch<
React.SetStateAction<HistoryItemWithoutId | null> React.SetStateAction<HistoryItemWithoutId | null>
>, >,
): [TrackedToolCall[], ScheduleFn, CancelFn, MarkToolsAsSubmittedFn] { ): [TrackedToolCall[], ScheduleFn, MarkToolsAsSubmittedFn] {
const [toolCallsForDisplay, setToolCallsForDisplay] = useState< const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
TrackedToolCall[] TrackedToolCall[]
>([]); >([]);
@ -172,15 +172,11 @@ export function useReactToolScheduler(
); );
const schedule: ScheduleFn = useCallback( const schedule: ScheduleFn = useCallback(
async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => { async (
scheduler.schedule(request); request: ToolCallRequestInfo | ToolCallRequestInfo[],
}, signal: AbortSignal,
[scheduler], ) => {
); scheduler.schedule(request, signal);
const cancel: CancelFn = useCallback(
(reason: string = 'unspecified') => {
scheduler.cancelAll(reason);
}, },
[scheduler], [scheduler],
); );
@ -198,7 +194,7 @@ export function useReactToolScheduler(
[], [],
); );
return [toolCallsForDisplay, schedule, cancel, markToolsAsSubmitted]; return [toolCallsForDisplay, schedule, markToolsAsSubmitted];
} }
/** /**

View File

@ -137,7 +137,7 @@ describe('useReactToolScheduler in YOLO Mode', () => {
}; };
act(() => { act(() => {
schedule(request); schedule(request, new AbortController().signal);
}); });
await act(async () => { await act(async () => {
@ -290,7 +290,7 @@ describe('useReactToolScheduler', () => {
}; };
act(() => { act(() => {
schedule(request); schedule(request, new AbortController().signal);
}); });
await act(async () => { await act(async () => {
await vi.runAllTimersAsync(); await vi.runAllTimersAsync();
@ -337,7 +337,7 @@ describe('useReactToolScheduler', () => {
}; };
act(() => { act(() => {
schedule(request); schedule(request, new AbortController().signal);
}); });
await act(async () => { await act(async () => {
await vi.runAllTimersAsync(); await vi.runAllTimersAsync();
@ -374,7 +374,7 @@ describe('useReactToolScheduler', () => {
}; };
act(() => { act(() => {
schedule(request); schedule(request, new AbortController().signal);
}); });
await act(async () => { await act(async () => {
await vi.runAllTimersAsync(); await vi.runAllTimersAsync();
@ -410,7 +410,7 @@ describe('useReactToolScheduler', () => {
}; };
act(() => { act(() => {
schedule(request); schedule(request, new AbortController().signal);
}); });
await act(async () => { await act(async () => {
await vi.runAllTimersAsync(); await vi.runAllTimersAsync();
@ -451,7 +451,7 @@ describe('useReactToolScheduler', () => {
}; };
act(() => { act(() => {
schedule(request); schedule(request, new AbortController().signal);
}); });
await act(async () => { await act(async () => {
await vi.runAllTimersAsync(); await vi.runAllTimersAsync();
@ -507,7 +507,7 @@ describe('useReactToolScheduler', () => {
}; };
act(() => { act(() => {
schedule(request); schedule(request, new AbortController().signal);
}); });
await act(async () => { await act(async () => {
await vi.runAllTimersAsync(); await vi.runAllTimersAsync();
@ -579,7 +579,7 @@ describe('useReactToolScheduler', () => {
}; };
act(() => { act(() => {
schedule(request); schedule(request, new AbortController().signal);
}); });
await act(async () => { await act(async () => {
await vi.runAllTimersAsync(); await vi.runAllTimersAsync();
@ -634,102 +634,6 @@ describe('useReactToolScheduler', () => {
expect(result.current[0]).toEqual([]); expect(result.current[0]).toEqual([]);
}); });
it.skip('should cancel tool calls before execution (e.g. when status is scheduled)', async () => {
mockToolRegistry.getTool.mockReturnValue(mockTool);
(mockTool.shouldConfirmExecute as Mock).mockResolvedValue(null);
(mockTool.execute as Mock).mockReturnValue(new Promise(() => {}));
const { result } = renderScheduler();
const schedule = result.current[1];
const cancel = result.current[2];
const request: ToolCallRequestInfo = {
callId: 'cancelCall',
name: 'mockTool',
args: {},
};
act(() => {
schedule(request);
});
await act(async () => {
await vi.runAllTimersAsync();
});
act(() => {
cancel();
});
await act(async () => {
await vi.runAllTimersAsync();
});
expect(onComplete).toHaveBeenCalledWith([
expect.objectContaining({
status: 'cancelled',
request,
response: expect.objectContaining({
responseParts: expect.arrayContaining([
expect.objectContaining({
functionResponse: expect.objectContaining({
response: expect.objectContaining({
error:
'[Operation Cancelled] Reason: User cancelled before execution',
}),
}),
}),
]),
}),
}),
]);
expect(mockTool.execute).not.toHaveBeenCalled();
expect(result.current[0]).toEqual([]);
});
it.skip('should cancel tool calls that are awaiting approval', async () => {
mockToolRegistry.getTool.mockReturnValue(mockToolRequiresConfirmation);
const { result } = renderScheduler();
const schedule = result.current[1];
const cancelFn = result.current[2];
const request: ToolCallRequestInfo = {
callId: 'cancelApprovalCall',
name: 'mockToolRequiresConfirmation',
args: {},
};
act(() => {
schedule(request);
});
await act(async () => {
await vi.runAllTimersAsync();
});
act(() => {
cancelFn();
});
await act(async () => {
await vi.runAllTimersAsync();
});
expect(onComplete).toHaveBeenCalledWith([
expect.objectContaining({
status: 'cancelled',
request,
response: expect.objectContaining({
responseParts: expect.arrayContaining([
expect.objectContaining({
functionResponse: expect.objectContaining({
response: expect.objectContaining({
error:
'[Operation Cancelled] Reason: User cancelled during approval',
}),
}),
}),
]),
}),
}),
]);
expect(result.current[0]).toEqual([]);
});
it('should schedule and execute multiple tool calls', async () => { it('should schedule and execute multiple tool calls', async () => {
const tool1 = { const tool1 = {
...mockTool, ...mockTool,
@ -766,7 +670,7 @@ describe('useReactToolScheduler', () => {
]; ];
act(() => { act(() => {
schedule(requests); schedule(requests, new AbortController().signal);
}); });
await act(async () => { await act(async () => {
await vi.runAllTimersAsync(); await vi.runAllTimersAsync();
@ -848,13 +752,13 @@ describe('useReactToolScheduler', () => {
}; };
act(() => { act(() => {
schedule(request1); schedule(request1, new AbortController().signal);
}); });
await act(async () => { await act(async () => {
await vi.runAllTimersAsync(); await vi.runAllTimersAsync();
}); });
expect(() => schedule(request2)).toThrow( expect(() => schedule(request2, new AbortController().signal)).toThrow(
'Cannot schedule tool calls while other tool calls are running', 'Cannot schedule tool calls while other tool calls are running',
); );

View File

@ -4,9 +4,110 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { describe, it, expect } from 'vitest'; /* eslint-disable @typescript-eslint/no-explicit-any */
import { convertToFunctionResponse } from './coreToolScheduler.js'; import { describe, it, expect, vi } from 'vitest';
import {
CoreToolScheduler,
ToolCall,
ValidatingToolCall,
} from './coreToolScheduler.js';
import {
BaseTool,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
ToolResult,
} from '../index.js';
import { Part, PartListUnion } from '@google/genai'; import { Part, PartListUnion } from '@google/genai';
import { convertToFunctionResponse } from './coreToolScheduler.js';
class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
shouldConfirm = false;
executeFn = vi.fn();
constructor(name = 'mockTool') {
super(name, name, 'A mock tool', {});
}
async shouldConfirmExecute(
_params: Record<string, unknown>,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.shouldConfirm) {
return {
type: 'exec',
title: 'Confirm Mock Tool',
command: 'do_thing',
rootCommand: 'do_thing',
onConfirm: async () => {},
};
}
return false;
}
async execute(
params: Record<string, unknown>,
_abortSignal: AbortSignal,
): Promise<ToolResult> {
this.executeFn(params);
return { llmContent: 'Tool executed', returnDisplay: 'Tool executed' };
}
}
describe('CoreToolScheduler', () => {
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
const mockTool = new MockTool();
mockTool.shouldConfirm = true;
const toolRegistry = {
getTool: () => mockTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {} as any,
config: {} as any,
registerTool: () => {},
getToolByName: () => mockTool,
getToolByDisplayName: () => mockTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
};
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
const scheduler = new CoreToolScheduler({
toolRegistry: Promise.resolve(toolRegistry as any),
onAllToolCallsComplete,
onToolCallsUpdate,
});
const abortController = new AbortController();
const request = { callId: '1', name: 'mockTool', args: {} };
abortController.abort();
await scheduler.schedule([request], abortController.signal);
const _waitingCall = onToolCallsUpdate.mock
.calls[1][0][0] as ValidatingToolCall;
const confirmationDetails = await mockTool.shouldConfirmExecute(
{},
abortController.signal,
);
if (confirmationDetails) {
await scheduler.handleConfirmationResponse(
'1',
confirmationDetails.onConfirm,
ToolConfirmationOutcome.ProceedOnce,
abortController.signal,
);
}
expect(onAllToolCallsComplete).toHaveBeenCalled();
const completedCalls = onAllToolCallsComplete.mock
.calls[0][0] as ToolCall[];
expect(completedCalls[0].status).toBe('cancelled');
});
});
describe('convertToFunctionResponse', () => { describe('convertToFunctionResponse', () => {
const toolName = 'testTool'; const toolName = 'testTool';

View File

@ -208,7 +208,6 @@ interface CoreToolSchedulerOptions {
export class CoreToolScheduler { export class CoreToolScheduler {
private toolRegistry: Promise<ToolRegistry>; private toolRegistry: Promise<ToolRegistry>;
private toolCalls: ToolCall[] = []; private toolCalls: ToolCall[] = [];
private abortController: AbortController;
private outputUpdateHandler?: OutputUpdateHandler; private outputUpdateHandler?: OutputUpdateHandler;
private onAllToolCallsComplete?: AllToolCallsCompleteHandler; private onAllToolCallsComplete?: AllToolCallsCompleteHandler;
private onToolCallsUpdate?: ToolCallsUpdateHandler; private onToolCallsUpdate?: ToolCallsUpdateHandler;
@ -220,7 +219,6 @@ export class CoreToolScheduler {
this.onAllToolCallsComplete = options.onAllToolCallsComplete; this.onAllToolCallsComplete = options.onAllToolCallsComplete;
this.onToolCallsUpdate = options.onToolCallsUpdate; this.onToolCallsUpdate = options.onToolCallsUpdate;
this.approvalMode = options.approvalMode ?? ApprovalMode.DEFAULT; this.approvalMode = options.approvalMode ?? ApprovalMode.DEFAULT;
this.abortController = new AbortController();
} }
private setStatusInternal( private setStatusInternal(
@ -379,6 +377,7 @@ export class CoreToolScheduler {
async schedule( async schedule(
request: ToolCallRequestInfo | ToolCallRequestInfo[], request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal,
): Promise<void> { ): Promise<void> {
if (this.isRunning()) { if (this.isRunning()) {
throw new Error( throw new Error(
@ -426,7 +425,7 @@ export class CoreToolScheduler {
} else { } else {
const confirmationDetails = await toolInstance.shouldConfirmExecute( const confirmationDetails = await toolInstance.shouldConfirmExecute(
reqInfo.args, reqInfo.args,
this.abortController.signal, signal,
); );
if (confirmationDetails) { if (confirmationDetails) {
@ -438,6 +437,7 @@ export class CoreToolScheduler {
reqInfo.callId, reqInfo.callId,
originalOnConfirm, originalOnConfirm,
outcome, outcome,
signal,
), ),
}; };
this.setStatusInternal( this.setStatusInternal(
@ -460,7 +460,7 @@ export class CoreToolScheduler {
); );
} }
} }
this.attemptExecutionOfScheduledCalls(); this.attemptExecutionOfScheduledCalls(signal);
this.checkAndNotifyCompletion(); this.checkAndNotifyCompletion();
} }
@ -468,6 +468,7 @@ export class CoreToolScheduler {
callId: string, callId: string,
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>, originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
outcome: ToolConfirmationOutcome, outcome: ToolConfirmationOutcome,
signal: AbortSignal,
): Promise<void> { ): Promise<void> {
const toolCall = this.toolCalls.find( const toolCall = this.toolCalls.find(
(c) => c.request.callId === callId && c.status === 'awaiting_approval', (c) => c.request.callId === callId && c.status === 'awaiting_approval',
@ -477,7 +478,7 @@ export class CoreToolScheduler {
await originalOnConfirm(outcome); await originalOnConfirm(outcome);
} }
if (outcome === ToolConfirmationOutcome.Cancel) { if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
this.setStatusInternal( this.setStatusInternal(
callId, callId,
'cancelled', 'cancelled',
@ -497,7 +498,7 @@ export class CoreToolScheduler {
const modifyResults = await editTool.onModify( const modifyResults = await editTool.onModify(
waitingToolCall.request.args as unknown as EditToolParams, waitingToolCall.request.args as unknown as EditToolParams,
this.abortController.signal, signal,
outcome, outcome,
); );
@ -513,10 +514,10 @@ export class CoreToolScheduler {
} else { } else {
this.setStatusInternal(callId, 'scheduled'); this.setStatusInternal(callId, 'scheduled');
} }
this.attemptExecutionOfScheduledCalls(); this.attemptExecutionOfScheduledCalls(signal);
} }
private attemptExecutionOfScheduledCalls(): void { private attemptExecutionOfScheduledCalls(signal: AbortSignal): void {
const allCallsFinalOrScheduled = this.toolCalls.every( const allCallsFinalOrScheduled = this.toolCalls.every(
(call) => (call) =>
call.status === 'scheduled' || call.status === 'scheduled' ||
@ -553,17 +554,13 @@ export class CoreToolScheduler {
: undefined; : undefined;
scheduledCall.tool scheduledCall.tool
.execute( .execute(scheduledCall.request.args, signal, liveOutputCallback)
scheduledCall.request.args,
this.abortController.signal,
liveOutputCallback,
)
.then((toolResult: ToolResult) => { .then((toolResult: ToolResult) => {
if (this.abortController.signal.aborted) { if (signal.aborted) {
this.setStatusInternal( this.setStatusInternal(
callId, callId,
'cancelled', 'cancelled',
this.abortController.signal.reason || 'Execution aborted.', 'User cancelled tool execution.',
); );
return; return;
} }
@ -613,29 +610,10 @@ export class CoreToolScheduler {
if (this.onAllToolCallsComplete) { if (this.onAllToolCallsComplete) {
this.onAllToolCallsComplete(completedCalls); this.onAllToolCallsComplete(completedCalls);
} }
this.abortController = new AbortController();
this.notifyToolCallsUpdate(); this.notifyToolCallsUpdate();
} }
} }
cancelAll(reason: string = 'User initiated cancellation.'): void {
if (!this.abortController.signal.aborted) {
this.abortController.abort(reason);
}
this.abortController = new AbortController();
const callsToCancel = [...this.toolCalls];
callsToCancel.forEach((call) => {
if (
call.status !== 'error' &&
call.status !== 'success' &&
call.status !== 'cancelled'
) {
this.setStatusInternal(call.request.callId, 'cancelled', reason);
}
});
}
private notifyToolCallsUpdate(): void { private notifyToolCallsUpdate(): void {
if (this.onToolCallsUpdate) { if (this.onToolCallsUpdate) {
this.onToolCallsUpdate([...this.toolCalls]); this.onToolCallsUpdate([...this.toolCalls]);

View File

@ -162,6 +162,13 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
}; };
} }
if (abortSignal.aborted) {
return {
llmContent: 'Command was cancelled by user before it could start.',
returnDisplay: 'Command cancelled by user.',
};
}
// wrap command to append subprocess pids (via pgrep) to temporary file // wrap command to append subprocess pids (via pgrep) to temporary file
const tempFileName = `shell_pgrep_${crypto.randomBytes(6).toString('hex')}.tmp`; const tempFileName = `shell_pgrep_${crypto.randomBytes(6).toString('hex')}.tmp`;
const tempFilePath = path.join(os.tmpdir(), tempFileName); const tempFilePath = path.join(os.tmpdir(), tempFileName);