fix(tool-scheduler): Correctly pipe cancellation signal to tool calls (#852)
This commit is contained in:
parent
7868ef8229
commit
f2ea78d0e4
|
@ -18,6 +18,7 @@ import {
|
||||||
import { Config } from '@gemini-cli/core';
|
import { Config } from '@gemini-cli/core';
|
||||||
import { Part, PartListUnion } from '@google/genai';
|
import { Part, PartListUnion } from '@google/genai';
|
||||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||||
|
import { Dispatch, SetStateAction } from 'react';
|
||||||
|
|
||||||
// --- MOCKS ---
|
// --- MOCKS ---
|
||||||
const mockSendMessageStream = vi
|
const mockSendMessageStream = vi
|
||||||
|
@ -309,16 +310,41 @@ describe('useGeminiStream', () => {
|
||||||
|
|
||||||
const client = geminiClient || mockConfig.getGeminiClient();
|
const client = geminiClient || mockConfig.getGeminiClient();
|
||||||
|
|
||||||
const { result, rerender } = renderHook(() =>
|
const { result, rerender } = renderHook(
|
||||||
|
(props: {
|
||||||
|
client: any;
|
||||||
|
addItem: UseHistoryManagerReturn['addItem'];
|
||||||
|
setShowHelp: Dispatch<SetStateAction<boolean>>;
|
||||||
|
config: Config;
|
||||||
|
onDebugMessage: (message: string) => void;
|
||||||
|
handleSlashCommand: (
|
||||||
|
command: PartListUnion,
|
||||||
|
) =>
|
||||||
|
| import('./slashCommandProcessor.js').SlashCommandActionReturn
|
||||||
|
| boolean;
|
||||||
|
shellModeActive: boolean;
|
||||||
|
}) =>
|
||||||
useGeminiStream(
|
useGeminiStream(
|
||||||
client,
|
props.client,
|
||||||
mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
props.addItem,
|
||||||
mockSetShowHelp,
|
props.setShowHelp,
|
||||||
mockConfig,
|
props.config,
|
||||||
mockOnDebugMessage,
|
props.onDebugMessage,
|
||||||
mockHandleSlashCommand,
|
props.handleSlashCommand,
|
||||||
false, // shellModeActive
|
props.shellModeActive,
|
||||||
),
|
),
|
||||||
|
{
|
||||||
|
initialProps: {
|
||||||
|
client,
|
||||||
|
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||||
|
setShowHelp: mockSetShowHelp,
|
||||||
|
config: mockConfig,
|
||||||
|
onDebugMessage: mockOnDebugMessage,
|
||||||
|
handleSlashCommand:
|
||||||
|
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
|
||||||
|
shellModeActive: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
);
|
);
|
||||||
return {
|
return {
|
||||||
result,
|
result,
|
||||||
|
@ -326,7 +352,6 @@ describe('useGeminiStream', () => {
|
||||||
mockMarkToolsAsSubmitted,
|
mockMarkToolsAsSubmitted,
|
||||||
mockSendMessageStream,
|
mockSendMessageStream,
|
||||||
client,
|
client,
|
||||||
// mockFilter removed
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -423,24 +448,29 @@ describe('useGeminiStream', () => {
|
||||||
} as TrackedCancelledToolCall,
|
} as TrackedCancelledToolCall,
|
||||||
];
|
];
|
||||||
|
|
||||||
const hookResult = await act(async () =>
|
|
||||||
renderTestHook(simplifiedToolCalls),
|
|
||||||
);
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
|
rerender,
|
||||||
mockMarkToolsAsSubmitted,
|
mockMarkToolsAsSubmitted,
|
||||||
mockSendMessageStream: localMockSendMessageStream,
|
mockSendMessageStream: localMockSendMessageStream,
|
||||||
} = hookResult!;
|
client,
|
||||||
|
} = renderTestHook(simplifiedToolCalls);
|
||||||
|
|
||||||
// It seems the initial render + effect run should be enough.
|
act(() => {
|
||||||
// If rerender was for a specific state change, it might still be needed.
|
rerender({
|
||||||
// For now, let's test if the initial effect run (covered by the first act) is sufficient.
|
client,
|
||||||
// If not, we can add back: await act(async () => { rerender({}); });
|
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||||
|
setShowHelp: mockSetShowHelp,
|
||||||
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['call1', 'call2']);
|
config: mockConfig,
|
||||||
|
onDebugMessage: mockOnDebugMessage,
|
||||||
|
handleSlashCommand:
|
||||||
|
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
|
||||||
|
shellModeActive: false,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
await waitFor(() => {
|
await waitFor(() => {
|
||||||
expect(localMockSendMessageStream).toHaveBeenCalledTimes(1);
|
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
|
||||||
|
expect(localMockSendMessageStream).toHaveBeenCalledTimes(0);
|
||||||
});
|
});
|
||||||
|
|
||||||
const expectedMergedResponse = mergePartListUnions([
|
const expectedMergedResponse = mergePartListUnions([
|
||||||
|
@ -479,12 +509,21 @@ describe('useGeminiStream', () => {
|
||||||
client,
|
client,
|
||||||
);
|
);
|
||||||
|
|
||||||
await act(async () => {
|
act(() => {
|
||||||
rerender({} as any);
|
rerender({
|
||||||
|
client,
|
||||||
|
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||||
|
setShowHelp: mockSetShowHelp,
|
||||||
|
config: mockConfig,
|
||||||
|
onDebugMessage: mockOnDebugMessage,
|
||||||
|
handleSlashCommand:
|
||||||
|
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
|
||||||
|
shellModeActive: false,
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
await waitFor(() => {
|
await waitFor(() => {
|
||||||
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']);
|
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
|
||||||
expect(client.addHistory).toHaveBeenCalledTimes(2);
|
expect(client.addHistory).toHaveBeenCalledTimes(2);
|
||||||
expect(client.addHistory).toHaveBeenCalledWith({
|
expect(client.addHistory).toHaveBeenCalledWith({
|
||||||
role: 'user',
|
role: 'user',
|
||||||
|
|
|
@ -83,12 +83,8 @@ export const useGeminiStream = (
|
||||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||||
const logger = useLogger();
|
const logger = useLogger();
|
||||||
|
|
||||||
const [
|
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
|
||||||
toolCalls,
|
useReactToolScheduler(
|
||||||
scheduleToolCalls,
|
|
||||||
cancelAllToolCalls,
|
|
||||||
markToolsAsSubmitted,
|
|
||||||
] = useReactToolScheduler(
|
|
||||||
(completedToolCallsFromScheduler) => {
|
(completedToolCallsFromScheduler) => {
|
||||||
// This onComplete is called when ALL scheduled tools for a given batch are done.
|
// This onComplete is called when ALL scheduled tools for a given batch are done.
|
||||||
if (completedToolCallsFromScheduler.length > 0) {
|
if (completedToolCallsFromScheduler.length > 0) {
|
||||||
|
@ -143,10 +139,15 @@ export const useGeminiStream = (
|
||||||
return StreamingState.Idle;
|
return StreamingState.Idle;
|
||||||
}, [isResponding, toolCalls]);
|
}, [isResponding, toolCalls]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (streamingState === StreamingState.Idle) {
|
||||||
|
abortControllerRef.current = null;
|
||||||
|
}
|
||||||
|
}, [streamingState]);
|
||||||
|
|
||||||
useInput((_input, key) => {
|
useInput((_input, key) => {
|
||||||
if (streamingState !== StreamingState.Idle && key.escape) {
|
if (streamingState !== StreamingState.Idle && key.escape) {
|
||||||
abortControllerRef.current?.abort();
|
abortControllerRef.current?.abort();
|
||||||
cancelAllToolCalls(); // Also cancel any pending/executing tool calls
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -191,7 +192,7 @@ export const useGeminiStream = (
|
||||||
name: toolName,
|
name: toolName,
|
||||||
args: toolArgs,
|
args: toolArgs,
|
||||||
};
|
};
|
||||||
scheduleToolCalls([toolCallRequest]);
|
scheduleToolCalls([toolCallRequest], abortSignal);
|
||||||
}
|
}
|
||||||
return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool
|
return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool
|
||||||
}
|
}
|
||||||
|
@ -330,9 +331,8 @@ export const useGeminiStream = (
|
||||||
userMessageTimestamp,
|
userMessageTimestamp,
|
||||||
);
|
);
|
||||||
setIsResponding(false);
|
setIsResponding(false);
|
||||||
cancelAllToolCalls();
|
|
||||||
},
|
},
|
||||||
[addItem, pendingHistoryItemRef, setPendingHistoryItem, cancelAllToolCalls],
|
[addItem, pendingHistoryItemRef, setPendingHistoryItem],
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleErrorEvent = useCallback(
|
const handleErrorEvent = useCallback(
|
||||||
|
@ -365,6 +365,7 @@ export const useGeminiStream = (
|
||||||
async (
|
async (
|
||||||
stream: AsyncIterable<GeminiEvent>,
|
stream: AsyncIterable<GeminiEvent>,
|
||||||
userMessageTimestamp: number,
|
userMessageTimestamp: number,
|
||||||
|
signal: AbortSignal,
|
||||||
): Promise<StreamProcessingStatus> => {
|
): Promise<StreamProcessingStatus> => {
|
||||||
let geminiMessageBuffer = '';
|
let geminiMessageBuffer = '';
|
||||||
const toolCallRequests: ToolCallRequestInfo[] = [];
|
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||||
|
@ -401,7 +402,7 @@ export const useGeminiStream = (
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (toolCallRequests.length > 0) {
|
if (toolCallRequests.length > 0) {
|
||||||
scheduleToolCalls(toolCallRequests);
|
scheduleToolCalls(toolCallRequests, signal);
|
||||||
}
|
}
|
||||||
return StreamProcessingStatus.Completed;
|
return StreamProcessingStatus.Completed;
|
||||||
},
|
},
|
||||||
|
@ -453,6 +454,7 @@ export const useGeminiStream = (
|
||||||
const processingStatus = await processGeminiStreamEvents(
|
const processingStatus = await processGeminiStreamEvents(
|
||||||
stream,
|
stream,
|
||||||
userMessageTimestamp,
|
userMessageTimestamp,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
|
|
||||||
if (processingStatus === StreamProcessingStatus.UserCancelled) {
|
if (processingStatus === StreamProcessingStatus.UserCancelled) {
|
||||||
|
@ -476,7 +478,6 @@ export const useGeminiStream = (
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
abortControllerRef.current = null; // Always reset
|
|
||||||
setIsResponding(false);
|
setIsResponding(false);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
|
@ -32,8 +32,8 @@ import {
|
||||||
|
|
||||||
export type ScheduleFn = (
|
export type ScheduleFn = (
|
||||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||||
|
signal: AbortSignal,
|
||||||
) => void;
|
) => void;
|
||||||
export type CancelFn = (reason?: string) => void;
|
|
||||||
export type MarkToolsAsSubmittedFn = (callIds: string[]) => void;
|
export type MarkToolsAsSubmittedFn = (callIds: string[]) => void;
|
||||||
|
|
||||||
export type TrackedScheduledToolCall = ScheduledToolCall & {
|
export type TrackedScheduledToolCall = ScheduledToolCall & {
|
||||||
|
@ -69,7 +69,7 @@ export function useReactToolScheduler(
|
||||||
setPendingHistoryItem: React.Dispatch<
|
setPendingHistoryItem: React.Dispatch<
|
||||||
React.SetStateAction<HistoryItemWithoutId | null>
|
React.SetStateAction<HistoryItemWithoutId | null>
|
||||||
>,
|
>,
|
||||||
): [TrackedToolCall[], ScheduleFn, CancelFn, MarkToolsAsSubmittedFn] {
|
): [TrackedToolCall[], ScheduleFn, MarkToolsAsSubmittedFn] {
|
||||||
const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
|
const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
|
||||||
TrackedToolCall[]
|
TrackedToolCall[]
|
||||||
>([]);
|
>([]);
|
||||||
|
@ -172,15 +172,11 @@ export function useReactToolScheduler(
|
||||||
);
|
);
|
||||||
|
|
||||||
const schedule: ScheduleFn = useCallback(
|
const schedule: ScheduleFn = useCallback(
|
||||||
async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => {
|
async (
|
||||||
scheduler.schedule(request);
|
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||||
},
|
signal: AbortSignal,
|
||||||
[scheduler],
|
) => {
|
||||||
);
|
scheduler.schedule(request, signal);
|
||||||
|
|
||||||
const cancel: CancelFn = useCallback(
|
|
||||||
(reason: string = 'unspecified') => {
|
|
||||||
scheduler.cancelAll(reason);
|
|
||||||
},
|
},
|
||||||
[scheduler],
|
[scheduler],
|
||||||
);
|
);
|
||||||
|
@ -198,7 +194,7 @@ export function useReactToolScheduler(
|
||||||
[],
|
[],
|
||||||
);
|
);
|
||||||
|
|
||||||
return [toolCallsForDisplay, schedule, cancel, markToolsAsSubmitted];
|
return [toolCallsForDisplay, schedule, markToolsAsSubmitted];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -137,7 +137,7 @@ describe('useReactToolScheduler in YOLO Mode', () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
act(() => {
|
act(() => {
|
||||||
schedule(request);
|
schedule(request, new AbortController().signal);
|
||||||
});
|
});
|
||||||
|
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
|
@ -290,7 +290,7 @@ describe('useReactToolScheduler', () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
act(() => {
|
act(() => {
|
||||||
schedule(request);
|
schedule(request, new AbortController().signal);
|
||||||
});
|
});
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await vi.runAllTimersAsync();
|
await vi.runAllTimersAsync();
|
||||||
|
@ -337,7 +337,7 @@ describe('useReactToolScheduler', () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
act(() => {
|
act(() => {
|
||||||
schedule(request);
|
schedule(request, new AbortController().signal);
|
||||||
});
|
});
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await vi.runAllTimersAsync();
|
await vi.runAllTimersAsync();
|
||||||
|
@ -374,7 +374,7 @@ describe('useReactToolScheduler', () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
act(() => {
|
act(() => {
|
||||||
schedule(request);
|
schedule(request, new AbortController().signal);
|
||||||
});
|
});
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await vi.runAllTimersAsync();
|
await vi.runAllTimersAsync();
|
||||||
|
@ -410,7 +410,7 @@ describe('useReactToolScheduler', () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
act(() => {
|
act(() => {
|
||||||
schedule(request);
|
schedule(request, new AbortController().signal);
|
||||||
});
|
});
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await vi.runAllTimersAsync();
|
await vi.runAllTimersAsync();
|
||||||
|
@ -451,7 +451,7 @@ describe('useReactToolScheduler', () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
act(() => {
|
act(() => {
|
||||||
schedule(request);
|
schedule(request, new AbortController().signal);
|
||||||
});
|
});
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await vi.runAllTimersAsync();
|
await vi.runAllTimersAsync();
|
||||||
|
@ -507,7 +507,7 @@ describe('useReactToolScheduler', () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
act(() => {
|
act(() => {
|
||||||
schedule(request);
|
schedule(request, new AbortController().signal);
|
||||||
});
|
});
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await vi.runAllTimersAsync();
|
await vi.runAllTimersAsync();
|
||||||
|
@ -579,7 +579,7 @@ describe('useReactToolScheduler', () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
act(() => {
|
act(() => {
|
||||||
schedule(request);
|
schedule(request, new AbortController().signal);
|
||||||
});
|
});
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await vi.runAllTimersAsync();
|
await vi.runAllTimersAsync();
|
||||||
|
@ -634,102 +634,6 @@ describe('useReactToolScheduler', () => {
|
||||||
expect(result.current[0]).toEqual([]);
|
expect(result.current[0]).toEqual([]);
|
||||||
});
|
});
|
||||||
|
|
||||||
it.skip('should cancel tool calls before execution (e.g. when status is scheduled)', async () => {
|
|
||||||
mockToolRegistry.getTool.mockReturnValue(mockTool);
|
|
||||||
(mockTool.shouldConfirmExecute as Mock).mockResolvedValue(null);
|
|
||||||
(mockTool.execute as Mock).mockReturnValue(new Promise(() => {}));
|
|
||||||
|
|
||||||
const { result } = renderScheduler();
|
|
||||||
const schedule = result.current[1];
|
|
||||||
const cancel = result.current[2];
|
|
||||||
const request: ToolCallRequestInfo = {
|
|
||||||
callId: 'cancelCall',
|
|
||||||
name: 'mockTool',
|
|
||||||
args: {},
|
|
||||||
};
|
|
||||||
|
|
||||||
act(() => {
|
|
||||||
schedule(request);
|
|
||||||
});
|
|
||||||
await act(async () => {
|
|
||||||
await vi.runAllTimersAsync();
|
|
||||||
});
|
|
||||||
|
|
||||||
act(() => {
|
|
||||||
cancel();
|
|
||||||
});
|
|
||||||
await act(async () => {
|
|
||||||
await vi.runAllTimersAsync();
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(onComplete).toHaveBeenCalledWith([
|
|
||||||
expect.objectContaining({
|
|
||||||
status: 'cancelled',
|
|
||||||
request,
|
|
||||||
response: expect.objectContaining({
|
|
||||||
responseParts: expect.arrayContaining([
|
|
||||||
expect.objectContaining({
|
|
||||||
functionResponse: expect.objectContaining({
|
|
||||||
response: expect.objectContaining({
|
|
||||||
error:
|
|
||||||
'[Operation Cancelled] Reason: User cancelled before execution',
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
]),
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
]);
|
|
||||||
expect(mockTool.execute).not.toHaveBeenCalled();
|
|
||||||
expect(result.current[0]).toEqual([]);
|
|
||||||
});
|
|
||||||
|
|
||||||
it.skip('should cancel tool calls that are awaiting approval', async () => {
|
|
||||||
mockToolRegistry.getTool.mockReturnValue(mockToolRequiresConfirmation);
|
|
||||||
const { result } = renderScheduler();
|
|
||||||
const schedule = result.current[1];
|
|
||||||
const cancelFn = result.current[2];
|
|
||||||
const request: ToolCallRequestInfo = {
|
|
||||||
callId: 'cancelApprovalCall',
|
|
||||||
name: 'mockToolRequiresConfirmation',
|
|
||||||
args: {},
|
|
||||||
};
|
|
||||||
|
|
||||||
act(() => {
|
|
||||||
schedule(request);
|
|
||||||
});
|
|
||||||
await act(async () => {
|
|
||||||
await vi.runAllTimersAsync();
|
|
||||||
});
|
|
||||||
|
|
||||||
act(() => {
|
|
||||||
cancelFn();
|
|
||||||
});
|
|
||||||
await act(async () => {
|
|
||||||
await vi.runAllTimersAsync();
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(onComplete).toHaveBeenCalledWith([
|
|
||||||
expect.objectContaining({
|
|
||||||
status: 'cancelled',
|
|
||||||
request,
|
|
||||||
response: expect.objectContaining({
|
|
||||||
responseParts: expect.arrayContaining([
|
|
||||||
expect.objectContaining({
|
|
||||||
functionResponse: expect.objectContaining({
|
|
||||||
response: expect.objectContaining({
|
|
||||||
error:
|
|
||||||
'[Operation Cancelled] Reason: User cancelled during approval',
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
]),
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
]);
|
|
||||||
expect(result.current[0]).toEqual([]);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should schedule and execute multiple tool calls', async () => {
|
it('should schedule and execute multiple tool calls', async () => {
|
||||||
const tool1 = {
|
const tool1 = {
|
||||||
...mockTool,
|
...mockTool,
|
||||||
|
@ -766,7 +670,7 @@ describe('useReactToolScheduler', () => {
|
||||||
];
|
];
|
||||||
|
|
||||||
act(() => {
|
act(() => {
|
||||||
schedule(requests);
|
schedule(requests, new AbortController().signal);
|
||||||
});
|
});
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await vi.runAllTimersAsync();
|
await vi.runAllTimersAsync();
|
||||||
|
@ -848,13 +752,13 @@ describe('useReactToolScheduler', () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
act(() => {
|
act(() => {
|
||||||
schedule(request1);
|
schedule(request1, new AbortController().signal);
|
||||||
});
|
});
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await vi.runAllTimersAsync();
|
await vi.runAllTimersAsync();
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(() => schedule(request2)).toThrow(
|
expect(() => schedule(request2, new AbortController().signal)).toThrow(
|
||||||
'Cannot schedule tool calls while other tool calls are running',
|
'Cannot schedule tool calls while other tool calls are running',
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,110 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { describe, it, expect } from 'vitest';
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||||
import { convertToFunctionResponse } from './coreToolScheduler.js';
|
import { describe, it, expect, vi } from 'vitest';
|
||||||
|
import {
|
||||||
|
CoreToolScheduler,
|
||||||
|
ToolCall,
|
||||||
|
ValidatingToolCall,
|
||||||
|
} from './coreToolScheduler.js';
|
||||||
|
import {
|
||||||
|
BaseTool,
|
||||||
|
ToolCallConfirmationDetails,
|
||||||
|
ToolConfirmationOutcome,
|
||||||
|
ToolResult,
|
||||||
|
} from '../index.js';
|
||||||
import { Part, PartListUnion } from '@google/genai';
|
import { Part, PartListUnion } from '@google/genai';
|
||||||
|
import { convertToFunctionResponse } from './coreToolScheduler.js';
|
||||||
|
|
||||||
|
class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
|
||||||
|
shouldConfirm = false;
|
||||||
|
executeFn = vi.fn();
|
||||||
|
|
||||||
|
constructor(name = 'mockTool') {
|
||||||
|
super(name, name, 'A mock tool', {});
|
||||||
|
}
|
||||||
|
|
||||||
|
async shouldConfirmExecute(
|
||||||
|
_params: Record<string, unknown>,
|
||||||
|
_abortSignal: AbortSignal,
|
||||||
|
): Promise<ToolCallConfirmationDetails | false> {
|
||||||
|
if (this.shouldConfirm) {
|
||||||
|
return {
|
||||||
|
type: 'exec',
|
||||||
|
title: 'Confirm Mock Tool',
|
||||||
|
command: 'do_thing',
|
||||||
|
rootCommand: 'do_thing',
|
||||||
|
onConfirm: async () => {},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
async execute(
|
||||||
|
params: Record<string, unknown>,
|
||||||
|
_abortSignal: AbortSignal,
|
||||||
|
): Promise<ToolResult> {
|
||||||
|
this.executeFn(params);
|
||||||
|
return { llmContent: 'Tool executed', returnDisplay: 'Tool executed' };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('CoreToolScheduler', () => {
|
||||||
|
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
|
||||||
|
const mockTool = new MockTool();
|
||||||
|
mockTool.shouldConfirm = true;
|
||||||
|
const toolRegistry = {
|
||||||
|
getTool: () => mockTool,
|
||||||
|
getFunctionDeclarations: () => [],
|
||||||
|
tools: new Map(),
|
||||||
|
discovery: {} as any,
|
||||||
|
config: {} as any,
|
||||||
|
registerTool: () => {},
|
||||||
|
getToolByName: () => mockTool,
|
||||||
|
getToolByDisplayName: () => mockTool,
|
||||||
|
getTools: () => [],
|
||||||
|
discoverTools: async () => {},
|
||||||
|
getAllTools: () => [],
|
||||||
|
getToolsByServer: () => [],
|
||||||
|
};
|
||||||
|
|
||||||
|
const onAllToolCallsComplete = vi.fn();
|
||||||
|
const onToolCallsUpdate = vi.fn();
|
||||||
|
|
||||||
|
const scheduler = new CoreToolScheduler({
|
||||||
|
toolRegistry: Promise.resolve(toolRegistry as any),
|
||||||
|
onAllToolCallsComplete,
|
||||||
|
onToolCallsUpdate,
|
||||||
|
});
|
||||||
|
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const request = { callId: '1', name: 'mockTool', args: {} };
|
||||||
|
|
||||||
|
abortController.abort();
|
||||||
|
await scheduler.schedule([request], abortController.signal);
|
||||||
|
|
||||||
|
const _waitingCall = onToolCallsUpdate.mock
|
||||||
|
.calls[1][0][0] as ValidatingToolCall;
|
||||||
|
const confirmationDetails = await mockTool.shouldConfirmExecute(
|
||||||
|
{},
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
if (confirmationDetails) {
|
||||||
|
await scheduler.handleConfirmationResponse(
|
||||||
|
'1',
|
||||||
|
confirmationDetails.onConfirm,
|
||||||
|
ToolConfirmationOutcome.ProceedOnce,
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||||
|
const completedCalls = onAllToolCallsComplete.mock
|
||||||
|
.calls[0][0] as ToolCall[];
|
||||||
|
expect(completedCalls[0].status).toBe('cancelled');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('convertToFunctionResponse', () => {
|
describe('convertToFunctionResponse', () => {
|
||||||
const toolName = 'testTool';
|
const toolName = 'testTool';
|
||||||
|
|
|
@ -208,7 +208,6 @@ interface CoreToolSchedulerOptions {
|
||||||
export class CoreToolScheduler {
|
export class CoreToolScheduler {
|
||||||
private toolRegistry: Promise<ToolRegistry>;
|
private toolRegistry: Promise<ToolRegistry>;
|
||||||
private toolCalls: ToolCall[] = [];
|
private toolCalls: ToolCall[] = [];
|
||||||
private abortController: AbortController;
|
|
||||||
private outputUpdateHandler?: OutputUpdateHandler;
|
private outputUpdateHandler?: OutputUpdateHandler;
|
||||||
private onAllToolCallsComplete?: AllToolCallsCompleteHandler;
|
private onAllToolCallsComplete?: AllToolCallsCompleteHandler;
|
||||||
private onToolCallsUpdate?: ToolCallsUpdateHandler;
|
private onToolCallsUpdate?: ToolCallsUpdateHandler;
|
||||||
|
@ -220,7 +219,6 @@ export class CoreToolScheduler {
|
||||||
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
|
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
|
||||||
this.onToolCallsUpdate = options.onToolCallsUpdate;
|
this.onToolCallsUpdate = options.onToolCallsUpdate;
|
||||||
this.approvalMode = options.approvalMode ?? ApprovalMode.DEFAULT;
|
this.approvalMode = options.approvalMode ?? ApprovalMode.DEFAULT;
|
||||||
this.abortController = new AbortController();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private setStatusInternal(
|
private setStatusInternal(
|
||||||
|
@ -379,6 +377,7 @@ export class CoreToolScheduler {
|
||||||
|
|
||||||
async schedule(
|
async schedule(
|
||||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||||
|
signal: AbortSignal,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
if (this.isRunning()) {
|
if (this.isRunning()) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
|
@ -426,7 +425,7 @@ export class CoreToolScheduler {
|
||||||
} else {
|
} else {
|
||||||
const confirmationDetails = await toolInstance.shouldConfirmExecute(
|
const confirmationDetails = await toolInstance.shouldConfirmExecute(
|
||||||
reqInfo.args,
|
reqInfo.args,
|
||||||
this.abortController.signal,
|
signal,
|
||||||
);
|
);
|
||||||
|
|
||||||
if (confirmationDetails) {
|
if (confirmationDetails) {
|
||||||
|
@ -438,6 +437,7 @@ export class CoreToolScheduler {
|
||||||
reqInfo.callId,
|
reqInfo.callId,
|
||||||
originalOnConfirm,
|
originalOnConfirm,
|
||||||
outcome,
|
outcome,
|
||||||
|
signal,
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
this.setStatusInternal(
|
this.setStatusInternal(
|
||||||
|
@ -460,7 +460,7 @@ export class CoreToolScheduler {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
this.attemptExecutionOfScheduledCalls();
|
this.attemptExecutionOfScheduledCalls(signal);
|
||||||
this.checkAndNotifyCompletion();
|
this.checkAndNotifyCompletion();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -468,6 +468,7 @@ export class CoreToolScheduler {
|
||||||
callId: string,
|
callId: string,
|
||||||
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
|
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
|
||||||
outcome: ToolConfirmationOutcome,
|
outcome: ToolConfirmationOutcome,
|
||||||
|
signal: AbortSignal,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const toolCall = this.toolCalls.find(
|
const toolCall = this.toolCalls.find(
|
||||||
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
|
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
|
||||||
|
@ -477,7 +478,7 @@ export class CoreToolScheduler {
|
||||||
await originalOnConfirm(outcome);
|
await originalOnConfirm(outcome);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (outcome === ToolConfirmationOutcome.Cancel) {
|
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
|
||||||
this.setStatusInternal(
|
this.setStatusInternal(
|
||||||
callId,
|
callId,
|
||||||
'cancelled',
|
'cancelled',
|
||||||
|
@ -497,7 +498,7 @@ export class CoreToolScheduler {
|
||||||
|
|
||||||
const modifyResults = await editTool.onModify(
|
const modifyResults = await editTool.onModify(
|
||||||
waitingToolCall.request.args as unknown as EditToolParams,
|
waitingToolCall.request.args as unknown as EditToolParams,
|
||||||
this.abortController.signal,
|
signal,
|
||||||
outcome,
|
outcome,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -513,10 +514,10 @@ export class CoreToolScheduler {
|
||||||
} else {
|
} else {
|
||||||
this.setStatusInternal(callId, 'scheduled');
|
this.setStatusInternal(callId, 'scheduled');
|
||||||
}
|
}
|
||||||
this.attemptExecutionOfScheduledCalls();
|
this.attemptExecutionOfScheduledCalls(signal);
|
||||||
}
|
}
|
||||||
|
|
||||||
private attemptExecutionOfScheduledCalls(): void {
|
private attemptExecutionOfScheduledCalls(signal: AbortSignal): void {
|
||||||
const allCallsFinalOrScheduled = this.toolCalls.every(
|
const allCallsFinalOrScheduled = this.toolCalls.every(
|
||||||
(call) =>
|
(call) =>
|
||||||
call.status === 'scheduled' ||
|
call.status === 'scheduled' ||
|
||||||
|
@ -553,17 +554,13 @@ export class CoreToolScheduler {
|
||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
scheduledCall.tool
|
scheduledCall.tool
|
||||||
.execute(
|
.execute(scheduledCall.request.args, signal, liveOutputCallback)
|
||||||
scheduledCall.request.args,
|
|
||||||
this.abortController.signal,
|
|
||||||
liveOutputCallback,
|
|
||||||
)
|
|
||||||
.then((toolResult: ToolResult) => {
|
.then((toolResult: ToolResult) => {
|
||||||
if (this.abortController.signal.aborted) {
|
if (signal.aborted) {
|
||||||
this.setStatusInternal(
|
this.setStatusInternal(
|
||||||
callId,
|
callId,
|
||||||
'cancelled',
|
'cancelled',
|
||||||
this.abortController.signal.reason || 'Execution aborted.',
|
'User cancelled tool execution.',
|
||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -613,29 +610,10 @@ export class CoreToolScheduler {
|
||||||
if (this.onAllToolCallsComplete) {
|
if (this.onAllToolCallsComplete) {
|
||||||
this.onAllToolCallsComplete(completedCalls);
|
this.onAllToolCallsComplete(completedCalls);
|
||||||
}
|
}
|
||||||
this.abortController = new AbortController();
|
|
||||||
this.notifyToolCallsUpdate();
|
this.notifyToolCallsUpdate();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cancelAll(reason: string = 'User initiated cancellation.'): void {
|
|
||||||
if (!this.abortController.signal.aborted) {
|
|
||||||
this.abortController.abort(reason);
|
|
||||||
}
|
|
||||||
this.abortController = new AbortController();
|
|
||||||
|
|
||||||
const callsToCancel = [...this.toolCalls];
|
|
||||||
callsToCancel.forEach((call) => {
|
|
||||||
if (
|
|
||||||
call.status !== 'error' &&
|
|
||||||
call.status !== 'success' &&
|
|
||||||
call.status !== 'cancelled'
|
|
||||||
) {
|
|
||||||
this.setStatusInternal(call.request.callId, 'cancelled', reason);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private notifyToolCallsUpdate(): void {
|
private notifyToolCallsUpdate(): void {
|
||||||
if (this.onToolCallsUpdate) {
|
if (this.onToolCallsUpdate) {
|
||||||
this.onToolCallsUpdate([...this.toolCalls]);
|
this.onToolCallsUpdate([...this.toolCalls]);
|
||||||
|
|
|
@ -162,6 +162,13 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (abortSignal.aborted) {
|
||||||
|
return {
|
||||||
|
llmContent: 'Command was cancelled by user before it could start.',
|
||||||
|
returnDisplay: 'Command cancelled by user.',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// wrap command to append subprocess pids (via pgrep) to temporary file
|
// wrap command to append subprocess pids (via pgrep) to temporary file
|
||||||
const tempFileName = `shell_pgrep_${crypto.randomBytes(6).toString('hex')}.tmp`;
|
const tempFileName = `shell_pgrep_${crypto.randomBytes(6).toString('hex')}.tmp`;
|
||||||
const tempFilePath = path.join(os.tmpdir(), tempFileName);
|
const tempFilePath = path.join(os.tmpdir(), tempFileName);
|
||||||
|
|
Loading…
Reference in New Issue