From 241c404573a8dd8c032dde5478a9bec95dd83a19 Mon Sep 17 00:00:00 2001 From: "N. Taylor Mullen" Date: Sun, 8 Jun 2025 11:14:45 -0700 Subject: [PATCH] fix(cli): correctly handle tool invocation cancellation (#844) --- .../cli/src/ui/hooks/useGeminiStream.test.tsx | 52 ++++++++++++++++++- packages/cli/src/ui/hooks/useGeminiStream.ts | 46 +++++++++++++++- packages/core/src/core/client.test.ts | 18 +++++++ packages/core/src/core/client.ts | 5 ++ packages/core/src/core/geminiChat.test.ts | 30 +++++++++++ packages/core/src/core/geminiChat.ts | 9 ++++ 6 files changed, 156 insertions(+), 4 deletions(-) diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index bd0f0520..1335eb8e 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -30,6 +30,7 @@ const MockedGeminiClientClass = vi.hoisted(() => // _config this.startChat = mockStartChat; this.sendMessageStream = mockSendMessageStream; + this.addHistory = vi.fn(); }), ); @@ -267,6 +268,7 @@ describe('useGeminiStream', () => { () => ({ getToolSchemaList: vi.fn(() => []) }) as any, ), getGeminiClient: mockGetGeminiClient, + addHistory: vi.fn(), } as unknown as Config; mockOnDebugMessage = vi.fn(); mockHandleSlashCommand = vi.fn().mockReturnValue(false); @@ -294,7 +296,10 @@ describe('useGeminiStream', () => { .mockReturnValue((async function* () {})()); }); - const renderTestHook = (initialToolCalls: TrackedToolCall[] = []) => { + const renderTestHook = ( + initialToolCalls: TrackedToolCall[] = [], + geminiClient?: any, + ) => { mockUseReactToolScheduler.mockReturnValue([ initialToolCalls, mockScheduleToolCalls, @@ -302,9 +307,11 @@ describe('useGeminiStream', () => { mockMarkToolsAsSubmitted, ]); + const client = geminiClient || mockConfig.getGeminiClient(); + const { result, rerender } = renderHook(() => useGeminiStream( - mockConfig.getGeminiClient(), + client, mockAddItem as unknown as UseHistoryManagerReturn['addItem'], mockSetShowHelp, mockConfig, @@ -318,6 +325,7 @@ describe('useGeminiStream', () => { rerender, mockMarkToolsAsSubmitted, mockSendMessageStream, + client, // mockFilter removed }; }; @@ -444,4 +452,44 @@ describe('useGeminiStream', () => { expect.any(AbortSignal), ); }); + + it('should handle all tool calls being cancelled', async () => { + const toolCalls: TrackedToolCall[] = [ + { + request: { callId: '1', name: 'testTool', args: {} }, + status: 'cancelled', + response: { + callId: '1', + responseParts: [{ text: 'cancelled' }], + error: undefined, + resultDisplay: 'Tool 1 cancelled display', + }, + responseSubmittedToGemini: false, + tool: { + name: 'testTool', + description: 'desc', + getDescription: vi.fn(), + } as any, + }, + ]; + + const client = new MockedGeminiClientClass(mockConfig); + const { mockMarkToolsAsSubmitted, rerender } = renderTestHook( + toolCalls, + client, + ); + + await act(async () => { + rerender({} as any); + }); + + await waitFor(() => { + expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']); + expect(client.addHistory).toHaveBeenCalledTimes(2); + expect(client.addHistory).toHaveBeenCalledWith({ + role: 'user', + parts: [{ text: 'cancelled' }], + }); + }); + }); }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 5e741547..3b3d01e0 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -19,7 +19,7 @@ import { ToolCallRequestInfo, logUserPrompt, } from '@gemini-cli/core'; -import { type PartListUnion } from '@google/genai'; +import { type Part, type PartListUnion } from '@google/genai'; import { StreamingState, HistoryItemWithoutId, @@ -531,6 +531,41 @@ export const useGeminiStream = ( completedAndReadyToSubmitTools.length > 0 && completedAndReadyToSubmitTools.length === toolCalls.length ) { + // If all the tools were cancelled, don't submit a response to Gemini. + const allToolsCancelled = completedAndReadyToSubmitTools.every( + (tc) => tc.status === 'cancelled', + ); + + if (allToolsCancelled) { + if (geminiClient) { + // We need to manually add the function responses to the history + // so the model knows the tools were cancelled. + const responsesToAdd = completedAndReadyToSubmitTools.flatMap( + (toolCall) => toolCall.response.responseParts, + ); + for (const response of responsesToAdd) { + let parts: Part[]; + if (Array.isArray(response)) { + parts = response; + } else if (typeof response === 'string') { + parts = [{ text: response }]; + } else { + parts = [response]; + } + geminiClient.addHistory({ + role: 'user', + parts, + }); + } + } + + const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map( + (toolCall) => toolCall.request.callId, + ); + markToolsAsSubmitted(callIdsToMarkAsSubmitted); + return; + } + const responsesToSend: PartListUnion[] = completedAndReadyToSubmitTools.map( (toolCall) => toolCall.response.responseParts, @@ -542,7 +577,14 @@ export const useGeminiStream = ( markToolsAsSubmitted(callIdsToMarkAsSubmitted); submitQuery(mergePartListUnions(responsesToSend)); } - }, [toolCalls, isResponding, submitQuery, markToolsAsSubmitted, addItem]); + }, [ + toolCalls, + isResponding, + submitQuery, + markToolsAsSubmitted, + addItem, + geminiClient, + ]); const pendingHistoryItems = [ pendingHistoryItemRef.current, diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 180d74bb..cbbbd113 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -219,4 +219,22 @@ describe('Gemini Client (client.ts)', () => { ); }); }); + + describe('addHistory', () => { + it('should call chat.addHistory with the provided content', async () => { + const mockChat = { + addHistory: vi.fn(), + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + client['chat'] = Promise.resolve(mockChat as any); + + const newContent = { + role: 'user', + parts: [{ text: 'New history item' }], + }; + await client.addHistory(newContent); + + expect(mockChat.addHistory).toHaveBeenCalledWith(newContent); + }); + }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 0f2a1b8a..8b921ab1 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -58,6 +58,11 @@ export class GeminiClient { this.chat = this.startChat(); } + async addHistory(content: Content) { + const chat = await this.chat; + chat.addHistory(content); + } + getChat(): Promise { return this.chat; } diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 6d18ebd9..dbed31b1 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -352,4 +352,34 @@ describe('GeminiChat', () => { expect(history[1].parts).toEqual([{ text: 'Visible text' }]); }); }); + + describe('addHistory', () => { + it('should add a new content item to the history', () => { + const newContent: Content = { + role: 'user', + parts: [{ text: 'A new message' }], + }; + chat.addHistory(newContent); + const history = chat.getHistory(); + expect(history.length).toBe(1); + expect(history[0]).toEqual(newContent); + }); + + it('should add multiple items correctly', () => { + const content1: Content = { + role: 'user', + parts: [{ text: 'Message 1' }], + }; + const content2: Content = { + role: 'model', + parts: [{ text: 'Message 2' }], + }; + chat.addHistory(content1); + chat.addHistory(content2); + const history = chat.getHistory(); + expect(history.length).toBe(2); + expect(history[0]).toEqual(content1); + expect(history[1]).toEqual(content2); + }); + }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 54f74102..47f3f3a6 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -287,6 +287,15 @@ export class GeminiChat { return structuredClone(history); } + /** + * Adds a new entry to the chat history. + * + * @param content - The content to add to the history. + */ + addHistory(content: Content): void { + this.history.push(content); + } + private async *processStreamResponse( streamResponse: AsyncGenerator, inputContent: Content,