feat: Add client-initiated tool call handling (#1292)

This commit is contained in:
Abhi 2025-06-22 01:35:36 -04:00 committed by GitHub
parent 5cf8dc4f07
commit c9950b3cb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 363 additions and 136 deletions

View File

@ -89,6 +89,7 @@ export async function runNonInteractive(
callId, callId,
name: fc.name as string, name: fc.name as string,
args: (fc.args ?? {}) as Record<string, unknown>, args: (fc.args ?? {}) as Record<string, unknown>,
isClientInitiated: false,
}; };
const toolResponse = await executeToolCall( const toolResponse = await executeToolCall(

View File

@ -362,6 +362,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
shellModeActive, shellModeActive,
getPreferredEditor, getPreferredEditor,
onAuthError, onAuthError,
performMemoryRefresh,
); );
pendingHistoryItems.push(...pendingGeminiHistoryItems); pendingHistoryItems.push(...pendingGeminiHistoryItems);
const { elapsedTime, currentLoadingPhrase } = const { elapsedTime, currentLoadingPhrase } =

View File

@ -371,6 +371,7 @@ describe('useGeminiStream', () => {
props.shellModeActive, props.shellModeActive,
() => 'vscode' as EditorType, () => 'vscode' as EditorType,
() => {}, () => {},
() => Promise.resolve(),
); );
}, },
{ {
@ -389,6 +390,7 @@ describe('useGeminiStream', () => {
>, >,
shellModeActive: false, shellModeActive: false,
loadedSettings: mockLoadedSettings, loadedSettings: mockLoadedSettings,
toolCalls: initialToolCalls,
}, },
}, },
); );
@ -404,7 +406,12 @@ describe('useGeminiStream', () => {
it('should not submit tool responses if not all tool calls are completed', () => { it('should not submit tool responses if not all tool calls are completed', () => {
const toolCalls: TrackedToolCall[] = [ const toolCalls: TrackedToolCall[] = [
{ {
request: { callId: 'call1', name: 'tool1', args: {} }, request: {
callId: 'call1',
name: 'tool1',
args: {},
isClientInitiated: false,
},
status: 'success', status: 'success',
responseSubmittedToGemini: false, responseSubmittedToGemini: false,
response: { response: {
@ -452,133 +459,138 @@ describe('useGeminiStream', () => {
const toolCall2ResponseParts: PartListUnion = [ const toolCall2ResponseParts: PartListUnion = [
{ text: 'tool 2 final response' }, { text: 'tool 2 final response' },
]; ];
const completedToolCalls: TrackedToolCall[] = [
// Simplified toolCalls to ensure the filter logic is the focus
const simplifiedToolCalls: TrackedToolCall[] = [
{ {
request: { callId: 'call1', name: 'tool1', args: {} }, request: {
callId: 'call1',
name: 'tool1',
args: {},
isClientInitiated: false,
},
status: 'success', status: 'success',
responseSubmittedToGemini: false, responseSubmittedToGemini: false,
response: { response: { callId: 'call1', responseParts: toolCall1ResponseParts },
callId: 'call1',
responseParts: toolCall1ResponseParts,
error: undefined,
resultDisplay: 'Tool 1 success display',
},
tool: {
name: 'tool1',
description: 'desc',
getDescription: vi.fn(),
} as any,
startTime: Date.now(),
endTime: Date.now(),
} as TrackedCompletedToolCall, } as TrackedCompletedToolCall,
{ {
request: { callId: 'call2', name: 'tool2', args: {} }, request: {
status: 'cancelled',
responseSubmittedToGemini: false,
response: {
callId: 'call2', callId: 'call2',
responseParts: toolCall2ResponseParts,
error: undefined,
resultDisplay: 'Tool 2 cancelled display',
},
tool: {
name: 'tool2', name: 'tool2',
description: 'desc', args: {},
getDescription: vi.fn(), isClientInitiated: false,
} as any, },
startTime: Date.now(), status: 'error',
endTime: Date.now(), responseSubmittedToGemini: false,
reason: 'test cancellation', response: { callId: 'call2', responseParts: toolCall2ResponseParts },
} as TrackedCancelledToolCall, } as TrackedCompletedToolCall, // Treat error as a form of completion for submission
]; ];
const { // 1. On the first render, there are no tool calls.
rerender, mockUseReactToolScheduler.mockReturnValue([
[],
mockScheduleToolCalls,
mockMarkToolsAsSubmitted, mockMarkToolsAsSubmitted,
mockSendMessageStream: localMockSendMessageStream, ]);
client, const { rerender } = renderHook(() =>
} = renderTestHook(simplifiedToolCalls); useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockSetShowHelp,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
),
);
// 2. Before the second render, change the mock to return the completed tools.
mockUseReactToolScheduler.mockReturnValue([
completedToolCalls,
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
]);
// 3. Trigger a re-render. The hook will now receive the completed tools, causing the effect to run.
act(() => { act(() => {
rerender({ rerender();
client,
history: [],
addItem: mockAddItem,
setShowHelp: mockSetShowHelp,
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand:
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
shellModeActive: false,
loadedSettings: mockLoadedSettings,
});
}); });
await waitFor(() => { await waitFor(() => {
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0); expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(1);
expect(localMockSendMessageStream).toHaveBeenCalledTimes(0); expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
}); });
const expectedMergedResponse = mergePartListUnions([ const expectedMergedResponse = mergePartListUnions([
toolCall1ResponseParts, toolCall1ResponseParts,
toolCall2ResponseParts, toolCall2ResponseParts,
]); ]);
expect(localMockSendMessageStream).toHaveBeenCalledWith( expect(mockSendMessageStream).toHaveBeenCalledWith(
expectedMergedResponse, expectedMergedResponse,
expect.any(AbortSignal), expect.any(AbortSignal),
); );
}); });
it('should handle all tool calls being cancelled', async () => { it('should handle all tool calls being cancelled', async () => {
const toolCalls: TrackedToolCall[] = [ const cancelledToolCalls: TrackedToolCall[] = [
{ {
request: { callId: '1', name: 'testTool', args: {} }, request: {
status: 'cancelled',
response: {
callId: '1', callId: '1',
responseParts: [{ text: 'cancelled' }],
error: undefined,
resultDisplay: 'Tool 1 cancelled display',
},
responseSubmittedToGemini: false,
tool: {
name: 'testTool', name: 'testTool',
description: 'desc', args: {},
getDescription: vi.fn(), isClientInitiated: false,
} as any, },
}, status: 'cancelled',
response: { callId: '1', responseParts: [{ text: 'cancelled' }] },
responseSubmittedToGemini: false,
} as TrackedCancelledToolCall,
]; ];
const client = new MockedGeminiClientClass(mockConfig); const client = new MockedGeminiClientClass(mockConfig);
const { mockMarkToolsAsSubmitted, rerender } = renderTestHook(
toolCalls, // 1. First render: no tool calls.
client, mockUseReactToolScheduler.mockReturnValue([
[],
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
]);
const { rerender } = renderHook(() =>
useGeminiStream(
client,
[],
mockAddItem,
mockSetShowHelp,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
),
); );
// 2. Second render: tool calls are now cancelled.
mockUseReactToolScheduler.mockReturnValue([
cancelledToolCalls,
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
]);
// 3. Trigger the re-render.
act(() => { act(() => {
rerender({ rerender();
client,
history: [],
addItem: mockAddItem,
setShowHelp: mockSetShowHelp,
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand:
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
shellModeActive: false,
loadedSettings: mockLoadedSettings,
});
}); });
await waitFor(() => { await waitFor(() => {
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0); expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']);
expect(client.addHistory).toHaveBeenCalledTimes(2);
expect(client.addHistory).toHaveBeenCalledWith({ expect(client.addHistory).toHaveBeenCalledWith({
role: 'user', role: 'user',
parts: [{ text: 'cancelled' }], parts: [{ text: 'cancelled' }],
}); });
// Ensure we do NOT call back to the API
expect(mockSendMessageStream).not.toHaveBeenCalled();
}); });
}); });
@ -708,7 +720,6 @@ describe('useGeminiStream', () => {
loadedSettings: mockLoadedSettings, loadedSettings: mockLoadedSettings,
// This is the key part of the test: update the toolCalls array // This is the key part of the test: update the toolCalls array
// to simulate the tool finishing. // to simulate the tool finishing.
// @ts-expect-error - we are adding a property to the props object
toolCalls: completedToolCalls, toolCalls: completedToolCalls,
}); });
}); });
@ -874,4 +885,145 @@ describe('useGeminiStream', () => {
expect(abortSpy).not.toHaveBeenCalled(); expect(abortSpy).not.toHaveBeenCalled();
}); });
}); });
describe('Client-Initiated Tool Calls', () => {
it('should execute a client-initiated tool without sending a response to Gemini', async () => {
const clientToolRequest = {
shouldScheduleTool: true,
toolName: 'save_memory',
toolArgs: { fact: 'test fact' },
};
mockHandleSlashCommand.mockResolvedValue(clientToolRequest);
const completedToolCall: TrackedCompletedToolCall = {
request: {
callId: 'client-call-1',
name: clientToolRequest.toolName,
args: clientToolRequest.toolArgs,
isClientInitiated: true,
},
status: 'success',
responseSubmittedToGemini: false,
response: {
callId: 'client-call-1',
responseParts: [{ text: 'Memory saved' }],
resultDisplay: 'Success: Memory saved',
error: undefined,
},
tool: {
name: clientToolRequest.toolName,
description: 'Saves memory',
getDescription: vi.fn(),
} as any,
};
// 1. Initial render state: no tool calls
mockUseReactToolScheduler.mockReturnValue([
[],
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
]);
const { result, rerender } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockSetShowHelp,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
() => Promise.resolve(),
),
);
// --- User runs the slash command ---
await act(async () => {
await result.current.submitQuery('/memory add "test fact"');
});
// The command handler schedules the tool. Now we simulate the tool completing.
// 2. Before the next render, set the mock to return the completed tool.
mockUseReactToolScheduler.mockReturnValue([
[completedToolCall],
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
]);
// 3. Trigger a re-render to process the completed tool.
act(() => {
rerender();
});
// --- Assert the outcome ---
await waitFor(() => {
// The tool should be marked as submitted locally
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith([
'client-call-1',
]);
// Crucially, no message should be sent to the Gemini API
expect(mockSendMessageStream).not.toHaveBeenCalled();
});
});
});
describe('Memory Refresh on save_memory', () => {
it('should call performMemoryRefresh when a save_memory tool call completes successfully', async () => {
const mockPerformMemoryRefresh = vi.fn();
const completedToolCall: TrackedCompletedToolCall = {
request: {
callId: 'save-mem-call-1',
name: 'save_memory',
args: { fact: 'test' },
isClientInitiated: true,
},
status: 'success',
responseSubmittedToGemini: false,
response: {
callId: 'save-mem-call-1',
responseParts: [{ text: 'Memory saved' }],
resultDisplay: 'Success: Memory saved',
error: undefined,
},
tool: {
name: 'save_memory',
description: 'Saves memory',
getDescription: vi.fn(),
} as any,
};
mockUseReactToolScheduler.mockReturnValue([
[completedToolCall],
mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
]);
const { rerender } = renderHook(() =>
useGeminiStream(
new MockedGeminiClientClass(mockConfig),
[],
mockAddItem,
mockSetShowHelp,
mockConfig,
mockOnDebugMessage,
mockHandleSlashCommand,
false,
() => 'vscode' as EditorType,
() => {},
mockPerformMemoryRefresh,
),
);
act(() => {
rerender();
});
await waitFor(() => {
expect(mockPerformMemoryRefresh).toHaveBeenCalledTimes(1);
});
});
});
}); });

View File

@ -89,6 +89,7 @@ export const useGeminiStream = (
shellModeActive: boolean, shellModeActive: boolean,
getPreferredEditor: () => EditorType | undefined, getPreferredEditor: () => EditorType | undefined,
onAuthError: () => void, onAuthError: () => void,
performMemoryRefresh: () => Promise<void>,
) => { ) => {
const [initError, setInitError] = useState<string | null>(null); const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null); const abortControllerRef = useRef<AbortController | null>(null);
@ -97,6 +98,7 @@ export const useGeminiStream = (
const [thought, setThought] = useState<ThoughtSummary | null>(null); const [thought, setThought] = useState<ThoughtSummary | null>(null);
const [pendingHistoryItemRef, setPendingHistoryItem] = const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null); useStateAndRef<HistoryItemWithoutId | null>(null);
const processedMemoryToolsRef = useRef<Set<string>>(new Set());
const logger = useLogger(); const logger = useLogger();
const { startNewTurn, addUsage } = useSessionStats(); const { startNewTurn, addUsage } = useSessionStats();
const gitService = useMemo(() => { const gitService = useMemo(() => {
@ -234,6 +236,7 @@ export const useGeminiStream = (
callId: `${toolName}-${Date.now()}-${Math.random().toString(16).slice(2)}`, callId: `${toolName}-${Date.now()}-${Math.random().toString(16).slice(2)}`,
name: toolName, name: toolName,
args: toolArgs, args: toolArgs,
isClientInitiated: true,
}; };
scheduleToolCalls([toolCallRequest], abortSignal); scheduleToolCalls([toolCallRequest], abortSignal);
} }
@ -566,38 +569,77 @@ export const useGeminiStream = (
* is not already generating a response. * is not already generating a response.
*/ */
useEffect(() => { useEffect(() => {
if (isResponding) { const run = async () => {
return; if (isResponding) {
} return;
}
const completedAndReadyToSubmitTools = toolCalls.filter( const completedAndReadyToSubmitTools = toolCalls.filter(
( (
tc: TrackedToolCall, tc: TrackedToolCall,
): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => { ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => {
const isTerminalState = const isTerminalState =
tc.status === 'success' || tc.status === 'success' ||
tc.status === 'error' || tc.status === 'error' ||
tc.status === 'cancelled'; tc.status === 'cancelled';
if (isTerminalState) { if (isTerminalState) {
const completedOrCancelledCall = tc as const completedOrCancelledCall = tc as
| TrackedCompletedToolCall | TrackedCompletedToolCall
| TrackedCancelledToolCall; | TrackedCancelledToolCall;
return ( return (
!completedOrCancelledCall.responseSubmittedToGemini && !completedOrCancelledCall.responseSubmittedToGemini &&
completedOrCancelledCall.response?.responseParts !== undefined completedOrCancelledCall.response?.responseParts !== undefined
); );
} }
return false; return false;
}, },
); );
// Finalize any client-initiated tools as soon as they are done.
const clientTools = completedAndReadyToSubmitTools.filter(
(t) => t.request.isClientInitiated,
);
if (clientTools.length > 0) {
markToolsAsSubmitted(clientTools.map((t) => t.request.callId));
}
// Identify new, successful save_memory calls that we haven't processed yet.
const newSuccessfulMemorySaves = completedAndReadyToSubmitTools.filter(
(t) =>
t.request.name === 'save_memory' &&
t.status === 'success' &&
!processedMemoryToolsRef.current.has(t.request.callId),
);
if (newSuccessfulMemorySaves.length > 0) {
// Perform the refresh only if there are new ones.
void performMemoryRefresh();
// Mark them as processed so we don't do this again on the next render.
newSuccessfulMemorySaves.forEach((t) =>
processedMemoryToolsRef.current.add(t.request.callId),
);
}
// Only proceed with submitting to Gemini if ALL tools are complete.
const allToolsAreComplete =
toolCalls.length > 0 &&
toolCalls.length === completedAndReadyToSubmitTools.length;
if (!allToolsAreComplete) {
return;
}
const geminiTools = completedAndReadyToSubmitTools.filter(
(t) => !t.request.isClientInitiated,
);
if (geminiTools.length === 0) {
return;
}
if (
completedAndReadyToSubmitTools.length > 0 &&
completedAndReadyToSubmitTools.length === toolCalls.length
) {
// If all the tools were cancelled, don't submit a response to Gemini. // If all the tools were cancelled, don't submit a response to Gemini.
const allToolsCancelled = completedAndReadyToSubmitTools.every( const allToolsCancelled = geminiTools.every(
(tc) => tc.status === 'cancelled', (tc) => tc.status === 'cancelled',
); );
@ -605,7 +647,7 @@ export const useGeminiStream = (
if (geminiClient) { if (geminiClient) {
// We need to manually add the function responses to the history // We need to manually add the function responses to the history
// so the model knows the tools were cancelled. // so the model knows the tools were cancelled.
const responsesToAdd = completedAndReadyToSubmitTools.flatMap( const responsesToAdd = geminiTools.flatMap(
(toolCall) => toolCall.response.responseParts, (toolCall) => toolCall.response.responseParts,
); );
for (const response of responsesToAdd) { for (const response of responsesToAdd) {
@ -624,18 +666,17 @@ export const useGeminiStream = (
} }
} }
const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map( const callIdsToMarkAsSubmitted = geminiTools.map(
(toolCall) => toolCall.request.callId, (toolCall) => toolCall.request.callId,
); );
markToolsAsSubmitted(callIdsToMarkAsSubmitted); markToolsAsSubmitted(callIdsToMarkAsSubmitted);
return; return;
} }
const responsesToSend: PartListUnion[] = const responsesToSend: PartListUnion[] = geminiTools.map(
completedAndReadyToSubmitTools.map( (toolCall) => toolCall.response.responseParts,
(toolCall) => toolCall.response.responseParts, );
); const callIdsToMarkAsSubmitted = geminiTools.map(
const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map(
(toolCall) => toolCall.request.callId, (toolCall) => toolCall.request.callId,
); );
@ -643,7 +684,8 @@ export const useGeminiStream = (
submitQuery(mergePartListUnions(responsesToSend), { submitQuery(mergePartListUnions(responsesToSend), {
isContinuation: true, isContinuation: true,
}); });
} };
void run();
}, [ }, [
toolCalls, toolCalls,
isResponding, isResponding,
@ -651,6 +693,7 @@ export const useGeminiStream = (
markToolsAsSubmitted, markToolsAsSubmitted,
addItem, addItem,
geminiClient, geminiClient,
performMemoryRefresh,
]); ]);
const pendingHistoryItems = [ const pendingHistoryItems = [

View File

@ -88,7 +88,12 @@ describe('CoreToolScheduler', () => {
}); });
const abortController = new AbortController(); const abortController = new AbortController();
const request = { callId: '1', name: 'mockTool', args: {} }; const request = {
callId: '1',
name: 'mockTool',
args: {},
isClientInitiated: false,
};
abortController.abort(); abortController.abort();
await scheduler.schedule([request], abortController.signal); await scheduler.schedule([request], abortController.signal);

View File

@ -62,6 +62,7 @@ describe('executeToolCall', () => {
callId: 'call1', callId: 'call1',
name: 'testTool', name: 'testTool',
args: { param1: 'value1' }, args: { param1: 'value1' },
isClientInitiated: false,
}; };
const toolResult: ToolResult = { const toolResult: ToolResult = {
llmContent: 'Tool executed successfully', llmContent: 'Tool executed successfully',
@ -99,6 +100,7 @@ describe('executeToolCall', () => {
callId: 'call2', callId: 'call2',
name: 'nonExistentTool', name: 'nonExistentTool',
args: {}, args: {},
isClientInitiated: false,
}; };
vi.mocked(mockToolRegistry.getTool).mockReturnValue(undefined); vi.mocked(mockToolRegistry.getTool).mockReturnValue(undefined);
@ -133,6 +135,7 @@ describe('executeToolCall', () => {
callId: 'call3', callId: 'call3',
name: 'testTool', name: 'testTool',
args: { param1: 'value1' }, args: { param1: 'value1' },
isClientInitiated: false,
}; };
const executionError = new Error('Tool execution failed'); const executionError = new Error('Tool execution failed');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
@ -164,6 +167,7 @@ describe('executeToolCall', () => {
callId: 'call4', callId: 'call4',
name: 'testTool', name: 'testTool',
args: { param1: 'value1' }, args: { param1: 'value1' },
isClientInitiated: false,
}; };
const cancellationError = new Error('Operation cancelled'); const cancellationError = new Error('Operation cancelled');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool); vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
@ -206,6 +210,7 @@ describe('executeToolCall', () => {
callId: 'call5', callId: 'call5',
name: 'testTool', name: 'testTool',
args: {}, args: {},
isClientInitiated: false,
}; };
const imageDataPart: Part = { const imageDataPart: Part = {
inlineData: { mimeType: 'image/png', data: 'base64data' }, inlineData: { mimeType: 'image/png', data: 'base64data' },

View File

@ -132,8 +132,13 @@ describe('Turn', () => {
const mockResponseStream = (async function* () { const mockResponseStream = (async function* () {
yield { yield {
functionCalls: [ functionCalls: [
{ id: 'fc1', name: 'tool1', args: { arg1: 'val1' } }, {
{ name: 'tool2', args: { arg2: 'val2' } }, // No ID id: 'fc1',
name: 'tool1',
args: { arg1: 'val1' },
isClientInitiated: false,
},
{ name: 'tool2', args: { arg2: 'val2' }, isClientInitiated: false }, // No ID
], ],
} as unknown as GenerateContentResponse; } as unknown as GenerateContentResponse;
})(); })();
@ -156,6 +161,7 @@ describe('Turn', () => {
callId: 'fc1', callId: 'fc1',
name: 'tool1', name: 'tool1',
args: { arg1: 'val1' }, args: { arg1: 'val1' },
isClientInitiated: false,
}), }),
); );
expect(turn.pendingToolCalls[0]).toEqual(event1.value); expect(turn.pendingToolCalls[0]).toEqual(event1.value);
@ -163,7 +169,11 @@ describe('Turn', () => {
const event2 = events[1] as ServerGeminiToolCallRequestEvent; const event2 = events[1] as ServerGeminiToolCallRequestEvent;
expect(event2.type).toBe(GeminiEventType.ToolCallRequest); expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
expect(event2.value).toEqual( expect(event2.value).toEqual(
expect.objectContaining({ name: 'tool2', args: { arg2: 'val2' } }), expect.objectContaining({
name: 'tool2',
args: { arg2: 'val2' },
isClientInitiated: false,
}),
); );
expect(event2.value.callId).toEqual( expect(event2.value.callId).toEqual(
expect.stringMatching(/^tool2-\d{13}-\w{10,}$/), expect.stringMatching(/^tool2-\d{13}-\w{10,}$/),
@ -301,6 +311,7 @@ describe('Turn', () => {
callId: 'fc1', callId: 'fc1',
name: 'undefined_tool_name', name: 'undefined_tool_name',
args: { arg1: 'val1' }, args: { arg1: 'val1' },
isClientInitiated: false,
}), }),
); );
expect(turn.pendingToolCalls[0]).toEqual(event1.value); expect(turn.pendingToolCalls[0]).toEqual(event1.value);
@ -308,7 +319,12 @@ describe('Turn', () => {
const event2 = events[1] as ServerGeminiToolCallRequestEvent; const event2 = events[1] as ServerGeminiToolCallRequestEvent;
expect(event2.type).toBe(GeminiEventType.ToolCallRequest); expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
expect(event2.value).toEqual( expect(event2.value).toEqual(
expect.objectContaining({ callId: 'fc2', name: 'tool2', args: {} }), expect.objectContaining({
callId: 'fc2',
name: 'tool2',
args: {},
isClientInitiated: false,
}),
); );
expect(turn.pendingToolCalls[1]).toEqual(event2.value); expect(turn.pendingToolCalls[1]).toEqual(event2.value);
@ -319,6 +335,7 @@ describe('Turn', () => {
callId: 'fc3', callId: 'fc3',
name: 'undefined_tool_name', name: 'undefined_tool_name',
args: {}, args: {},
isClientInitiated: false,
}), }),
); );
expect(turn.pendingToolCalls[2]).toEqual(event3.value); expect(turn.pendingToolCalls[2]).toEqual(event3.value);

View File

@ -57,6 +57,7 @@ export interface ToolCallRequestInfo {
callId: string; callId: string;
name: string; name: string;
args: Record<string, unknown>; args: Record<string, unknown>;
isClientInitiated: boolean;
} }
export interface ToolCallResponseInfo { export interface ToolCallResponseInfo {
@ -139,11 +140,7 @@ export type ServerGeminiStreamEvent =
// A turn manages the agentic loop turn within the server context. // A turn manages the agentic loop turn within the server context.
export class Turn { export class Turn {
readonly pendingToolCalls: Array<{ readonly pendingToolCalls: ToolCallRequestInfo[];
callId: string;
name: string;
args: Record<string, unknown>;
}>;
private debugResponses: GenerateContentResponse[]; private debugResponses: GenerateContentResponse[];
private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null; private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null;
@ -254,11 +251,17 @@ export class Turn {
const name = fnCall.name || 'undefined_tool_name'; const name = fnCall.name || 'undefined_tool_name';
const args = (fnCall.args || {}) as Record<string, unknown>; const args = (fnCall.args || {}) as Record<string, unknown>;
this.pendingToolCalls.push({ callId, name, args }); const toolCallRequest: ToolCallRequestInfo = {
callId,
name,
args,
isClientInitiated: false,
};
this.pendingToolCalls.push(toolCallRequest);
// Yield a request for the tool call, not the pending/confirming status // Yield a request for the tool call, not the pending/confirming status
const value: ToolCallRequestInfo = { callId, name, args }; return { type: GeminiEventType.ToolCallRequest, value: toolCallRequest };
return { type: GeminiEventType.ToolCallRequest, value };
} }
getDebugResponses(): GenerateContentResponse[] { getDebugResponses(): GenerateContentResponse[] {