fix(cli): correctly handle tool invocation cancellation (#844)

This commit is contained in:
N. Taylor Mullen 2025-06-08 11:14:45 -07:00 committed by GitHub
parent 9efca40dae
commit 241c404573
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 156 additions and 4 deletions

View File

@ -30,6 +30,7 @@ const MockedGeminiClientClass = vi.hoisted(() =>
// _config // _config
this.startChat = mockStartChat; this.startChat = mockStartChat;
this.sendMessageStream = mockSendMessageStream; this.sendMessageStream = mockSendMessageStream;
this.addHistory = vi.fn();
}), }),
); );
@ -267,6 +268,7 @@ describe('useGeminiStream', () => {
() => ({ getToolSchemaList: vi.fn(() => []) }) as any, () => ({ getToolSchemaList: vi.fn(() => []) }) as any,
), ),
getGeminiClient: mockGetGeminiClient, getGeminiClient: mockGetGeminiClient,
addHistory: vi.fn(),
} as unknown as Config; } as unknown as Config;
mockOnDebugMessage = vi.fn(); mockOnDebugMessage = vi.fn();
mockHandleSlashCommand = vi.fn().mockReturnValue(false); mockHandleSlashCommand = vi.fn().mockReturnValue(false);
@ -294,7 +296,10 @@ describe('useGeminiStream', () => {
.mockReturnValue((async function* () {})()); .mockReturnValue((async function* () {})());
}); });
const renderTestHook = (initialToolCalls: TrackedToolCall[] = []) => { const renderTestHook = (
initialToolCalls: TrackedToolCall[] = [],
geminiClient?: any,
) => {
mockUseReactToolScheduler.mockReturnValue([ mockUseReactToolScheduler.mockReturnValue([
initialToolCalls, initialToolCalls,
mockScheduleToolCalls, mockScheduleToolCalls,
@ -302,9 +307,11 @@ describe('useGeminiStream', () => {
mockMarkToolsAsSubmitted, mockMarkToolsAsSubmitted,
]); ]);
const client = geminiClient || mockConfig.getGeminiClient();
const { result, rerender } = renderHook(() => const { result, rerender } = renderHook(() =>
useGeminiStream( useGeminiStream(
mockConfig.getGeminiClient(), client,
mockAddItem as unknown as UseHistoryManagerReturn['addItem'], mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
mockSetShowHelp, mockSetShowHelp,
mockConfig, mockConfig,
@ -318,6 +325,7 @@ describe('useGeminiStream', () => {
rerender, rerender,
mockMarkToolsAsSubmitted, mockMarkToolsAsSubmitted,
mockSendMessageStream, mockSendMessageStream,
client,
// mockFilter removed // mockFilter removed
}; };
}; };
@ -444,4 +452,44 @@ describe('useGeminiStream', () => {
expect.any(AbortSignal), expect.any(AbortSignal),
); );
}); });
it('should handle all tool calls being cancelled', async () => {
const toolCalls: TrackedToolCall[] = [
{
request: { callId: '1', name: 'testTool', args: {} },
status: 'cancelled',
response: {
callId: '1',
responseParts: [{ text: 'cancelled' }],
error: undefined,
resultDisplay: 'Tool 1 cancelled display',
},
responseSubmittedToGemini: false,
tool: {
name: 'testTool',
description: 'desc',
getDescription: vi.fn(),
} as any,
},
];
const client = new MockedGeminiClientClass(mockConfig);
const { mockMarkToolsAsSubmitted, rerender } = renderTestHook(
toolCalls,
client,
);
await act(async () => {
rerender({} as any);
});
await waitFor(() => {
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']);
expect(client.addHistory).toHaveBeenCalledTimes(2);
expect(client.addHistory).toHaveBeenCalledWith({
role: 'user',
parts: [{ text: 'cancelled' }],
});
});
});
}); });

View File

@ -19,7 +19,7 @@ import {
ToolCallRequestInfo, ToolCallRequestInfo,
logUserPrompt, logUserPrompt,
} from '@gemini-cli/core'; } from '@gemini-cli/core';
import { type PartListUnion } from '@google/genai'; import { type Part, type PartListUnion } from '@google/genai';
import { import {
StreamingState, StreamingState,
HistoryItemWithoutId, HistoryItemWithoutId,
@ -531,6 +531,41 @@ export const useGeminiStream = (
completedAndReadyToSubmitTools.length > 0 && completedAndReadyToSubmitTools.length > 0 &&
completedAndReadyToSubmitTools.length === toolCalls.length completedAndReadyToSubmitTools.length === toolCalls.length
) { ) {
// If all the tools were cancelled, don't submit a response to Gemini.
const allToolsCancelled = completedAndReadyToSubmitTools.every(
(tc) => tc.status === 'cancelled',
);
if (allToolsCancelled) {
if (geminiClient) {
// We need to manually add the function responses to the history
// so the model knows the tools were cancelled.
const responsesToAdd = completedAndReadyToSubmitTools.flatMap(
(toolCall) => toolCall.response.responseParts,
);
for (const response of responsesToAdd) {
let parts: Part[];
if (Array.isArray(response)) {
parts = response;
} else if (typeof response === 'string') {
parts = [{ text: response }];
} else {
parts = [response];
}
geminiClient.addHistory({
role: 'user',
parts,
});
}
}
const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map(
(toolCall) => toolCall.request.callId,
);
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
return;
}
const responsesToSend: PartListUnion[] = const responsesToSend: PartListUnion[] =
completedAndReadyToSubmitTools.map( completedAndReadyToSubmitTools.map(
(toolCall) => toolCall.response.responseParts, (toolCall) => toolCall.response.responseParts,
@ -542,7 +577,14 @@ export const useGeminiStream = (
markToolsAsSubmitted(callIdsToMarkAsSubmitted); markToolsAsSubmitted(callIdsToMarkAsSubmitted);
submitQuery(mergePartListUnions(responsesToSend)); submitQuery(mergePartListUnions(responsesToSend));
} }
}, [toolCalls, isResponding, submitQuery, markToolsAsSubmitted, addItem]); }, [
toolCalls,
isResponding,
submitQuery,
markToolsAsSubmitted,
addItem,
geminiClient,
]);
const pendingHistoryItems = [ const pendingHistoryItems = [
pendingHistoryItemRef.current, pendingHistoryItemRef.current,

View File

@ -219,4 +219,22 @@ describe('Gemini Client (client.ts)', () => {
); );
}); });
}); });
describe('addHistory', () => {
it('should call chat.addHistory with the provided content', async () => {
const mockChat = {
addHistory: vi.fn(),
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
client['chat'] = Promise.resolve(mockChat as any);
const newContent = {
role: 'user',
parts: [{ text: 'New history item' }],
};
await client.addHistory(newContent);
expect(mockChat.addHistory).toHaveBeenCalledWith(newContent);
});
});
}); });

View File

@ -58,6 +58,11 @@ export class GeminiClient {
this.chat = this.startChat(); this.chat = this.startChat();
} }
async addHistory(content: Content) {
const chat = await this.chat;
chat.addHistory(content);
}
getChat(): Promise<GeminiChat> { getChat(): Promise<GeminiChat> {
return this.chat; return this.chat;
} }

View File

@ -352,4 +352,34 @@ describe('GeminiChat', () => {
expect(history[1].parts).toEqual([{ text: 'Visible text' }]); expect(history[1].parts).toEqual([{ text: 'Visible text' }]);
}); });
}); });
describe('addHistory', () => {
it('should add a new content item to the history', () => {
const newContent: Content = {
role: 'user',
parts: [{ text: 'A new message' }],
};
chat.addHistory(newContent);
const history = chat.getHistory();
expect(history.length).toBe(1);
expect(history[0]).toEqual(newContent);
});
it('should add multiple items correctly', () => {
const content1: Content = {
role: 'user',
parts: [{ text: 'Message 1' }],
};
const content2: Content = {
role: 'model',
parts: [{ text: 'Message 2' }],
};
chat.addHistory(content1);
chat.addHistory(content2);
const history = chat.getHistory();
expect(history.length).toBe(2);
expect(history[0]).toEqual(content1);
expect(history[1]).toEqual(content2);
});
});
}); });

View File

@ -287,6 +287,15 @@ export class GeminiChat {
return structuredClone(history); return structuredClone(history);
} }
/**
* Adds a new entry to the chat history.
*
* @param content - The content to add to the history.
*/
addHistory(content: Content): void {
this.history.push(content);
}
private async *processStreamResponse( private async *processStreamResponse(
streamResponse: AsyncGenerator<GenerateContentResponse>, streamResponse: AsyncGenerator<GenerateContentResponse>,
inputContent: Content, inputContent: Content,