fix(cli): correctly handle tool invocation cancellation (#844)
This commit is contained in:
parent
9efca40dae
commit
241c404573
|
@ -30,6 +30,7 @@ const MockedGeminiClientClass = vi.hoisted(() =>
|
|||
// _config
|
||||
this.startChat = mockStartChat;
|
||||
this.sendMessageStream = mockSendMessageStream;
|
||||
this.addHistory = vi.fn();
|
||||
}),
|
||||
);
|
||||
|
||||
|
@ -267,6 +268,7 @@ describe('useGeminiStream', () => {
|
|||
() => ({ getToolSchemaList: vi.fn(() => []) }) as any,
|
||||
),
|
||||
getGeminiClient: mockGetGeminiClient,
|
||||
addHistory: vi.fn(),
|
||||
} as unknown as Config;
|
||||
mockOnDebugMessage = vi.fn();
|
||||
mockHandleSlashCommand = vi.fn().mockReturnValue(false);
|
||||
|
@ -294,7 +296,10 @@ describe('useGeminiStream', () => {
|
|||
.mockReturnValue((async function* () {})());
|
||||
});
|
||||
|
||||
const renderTestHook = (initialToolCalls: TrackedToolCall[] = []) => {
|
||||
const renderTestHook = (
|
||||
initialToolCalls: TrackedToolCall[] = [],
|
||||
geminiClient?: any,
|
||||
) => {
|
||||
mockUseReactToolScheduler.mockReturnValue([
|
||||
initialToolCalls,
|
||||
mockScheduleToolCalls,
|
||||
|
@ -302,9 +307,11 @@ describe('useGeminiStream', () => {
|
|||
mockMarkToolsAsSubmitted,
|
||||
]);
|
||||
|
||||
const client = geminiClient || mockConfig.getGeminiClient();
|
||||
|
||||
const { result, rerender } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
mockConfig.getGeminiClient(),
|
||||
client,
|
||||
mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
mockSetShowHelp,
|
||||
mockConfig,
|
||||
|
@ -318,6 +325,7 @@ describe('useGeminiStream', () => {
|
|||
rerender,
|
||||
mockMarkToolsAsSubmitted,
|
||||
mockSendMessageStream,
|
||||
client,
|
||||
// mockFilter removed
|
||||
};
|
||||
};
|
||||
|
@ -444,4 +452,44 @@ describe('useGeminiStream', () => {
|
|||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle all tool calls being cancelled', async () => {
|
||||
const toolCalls: TrackedToolCall[] = [
|
||||
{
|
||||
request: { callId: '1', name: 'testTool', args: {} },
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
callId: '1',
|
||||
responseParts: [{ text: 'cancelled' }],
|
||||
error: undefined,
|
||||
resultDisplay: 'Tool 1 cancelled display',
|
||||
},
|
||||
responseSubmittedToGemini: false,
|
||||
tool: {
|
||||
name: 'testTool',
|
||||
description: 'desc',
|
||||
getDescription: vi.fn(),
|
||||
} as any,
|
||||
},
|
||||
];
|
||||
|
||||
const client = new MockedGeminiClientClass(mockConfig);
|
||||
const { mockMarkToolsAsSubmitted, rerender } = renderTestHook(
|
||||
toolCalls,
|
||||
client,
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
rerender({} as any);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']);
|
||||
expect(client.addHistory).toHaveBeenCalledTimes(2);
|
||||
expect(client.addHistory).toHaveBeenCalledWith({
|
||||
role: 'user',
|
||||
parts: [{ text: 'cancelled' }],
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -19,7 +19,7 @@ import {
|
|||
ToolCallRequestInfo,
|
||||
logUserPrompt,
|
||||
} from '@gemini-cli/core';
|
||||
import { type PartListUnion } from '@google/genai';
|
||||
import { type Part, type PartListUnion } from '@google/genai';
|
||||
import {
|
||||
StreamingState,
|
||||
HistoryItemWithoutId,
|
||||
|
@ -531,6 +531,41 @@ export const useGeminiStream = (
|
|||
completedAndReadyToSubmitTools.length > 0 &&
|
||||
completedAndReadyToSubmitTools.length === toolCalls.length
|
||||
) {
|
||||
// If all the tools were cancelled, don't submit a response to Gemini.
|
||||
const allToolsCancelled = completedAndReadyToSubmitTools.every(
|
||||
(tc) => tc.status === 'cancelled',
|
||||
);
|
||||
|
||||
if (allToolsCancelled) {
|
||||
if (geminiClient) {
|
||||
// We need to manually add the function responses to the history
|
||||
// so the model knows the tools were cancelled.
|
||||
const responsesToAdd = completedAndReadyToSubmitTools.flatMap(
|
||||
(toolCall) => toolCall.response.responseParts,
|
||||
);
|
||||
for (const response of responsesToAdd) {
|
||||
let parts: Part[];
|
||||
if (Array.isArray(response)) {
|
||||
parts = response;
|
||||
} else if (typeof response === 'string') {
|
||||
parts = [{ text: response }];
|
||||
} else {
|
||||
parts = [response];
|
||||
}
|
||||
geminiClient.addHistory({
|
||||
role: 'user',
|
||||
parts,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map(
|
||||
(toolCall) => toolCall.request.callId,
|
||||
);
|
||||
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
|
||||
return;
|
||||
}
|
||||
|
||||
const responsesToSend: PartListUnion[] =
|
||||
completedAndReadyToSubmitTools.map(
|
||||
(toolCall) => toolCall.response.responseParts,
|
||||
|
@ -542,7 +577,14 @@ export const useGeminiStream = (
|
|||
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
|
||||
submitQuery(mergePartListUnions(responsesToSend));
|
||||
}
|
||||
}, [toolCalls, isResponding, submitQuery, markToolsAsSubmitted, addItem]);
|
||||
}, [
|
||||
toolCalls,
|
||||
isResponding,
|
||||
submitQuery,
|
||||
markToolsAsSubmitted,
|
||||
addItem,
|
||||
geminiClient,
|
||||
]);
|
||||
|
||||
const pendingHistoryItems = [
|
||||
pendingHistoryItemRef.current,
|
||||
|
|
|
@ -219,4 +219,22 @@ describe('Gemini Client (client.ts)', () => {
|
|||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('addHistory', () => {
|
||||
it('should call chat.addHistory with the provided content', async () => {
|
||||
const mockChat = {
|
||||
addHistory: vi.fn(),
|
||||
};
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
client['chat'] = Promise.resolve(mockChat as any);
|
||||
|
||||
const newContent = {
|
||||
role: 'user',
|
||||
parts: [{ text: 'New history item' }],
|
||||
};
|
||||
await client.addHistory(newContent);
|
||||
|
||||
expect(mockChat.addHistory).toHaveBeenCalledWith(newContent);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -58,6 +58,11 @@ export class GeminiClient {
|
|||
this.chat = this.startChat();
|
||||
}
|
||||
|
||||
async addHistory(content: Content) {
|
||||
const chat = await this.chat;
|
||||
chat.addHistory(content);
|
||||
}
|
||||
|
||||
getChat(): Promise<GeminiChat> {
|
||||
return this.chat;
|
||||
}
|
||||
|
|
|
@ -352,4 +352,34 @@ describe('GeminiChat', () => {
|
|||
expect(history[1].parts).toEqual([{ text: 'Visible text' }]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('addHistory', () => {
|
||||
it('should add a new content item to the history', () => {
|
||||
const newContent: Content = {
|
||||
role: 'user',
|
||||
parts: [{ text: 'A new message' }],
|
||||
};
|
||||
chat.addHistory(newContent);
|
||||
const history = chat.getHistory();
|
||||
expect(history.length).toBe(1);
|
||||
expect(history[0]).toEqual(newContent);
|
||||
});
|
||||
|
||||
it('should add multiple items correctly', () => {
|
||||
const content1: Content = {
|
||||
role: 'user',
|
||||
parts: [{ text: 'Message 1' }],
|
||||
};
|
||||
const content2: Content = {
|
||||
role: 'model',
|
||||
parts: [{ text: 'Message 2' }],
|
||||
};
|
||||
chat.addHistory(content1);
|
||||
chat.addHistory(content2);
|
||||
const history = chat.getHistory();
|
||||
expect(history.length).toBe(2);
|
||||
expect(history[0]).toEqual(content1);
|
||||
expect(history[1]).toEqual(content2);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -287,6 +287,15 @@ export class GeminiChat {
|
|||
return structuredClone(history);
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a new entry to the chat history.
|
||||
*
|
||||
* @param content - The content to add to the history.
|
||||
*/
|
||||
addHistory(content: Content): void {
|
||||
this.history.push(content);
|
||||
}
|
||||
|
||||
private async *processStreamResponse(
|
||||
streamResponse: AsyncGenerator<GenerateContentResponse>,
|
||||
inputContent: Content,
|
||||
|
|
Loading…
Reference in New Issue