fix(cli): correctly handle tool invocation cancellation (#844)
This commit is contained in:
parent
9efca40dae
commit
241c404573
|
@ -30,6 +30,7 @@ const MockedGeminiClientClass = vi.hoisted(() =>
|
||||||
// _config
|
// _config
|
||||||
this.startChat = mockStartChat;
|
this.startChat = mockStartChat;
|
||||||
this.sendMessageStream = mockSendMessageStream;
|
this.sendMessageStream = mockSendMessageStream;
|
||||||
|
this.addHistory = vi.fn();
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -267,6 +268,7 @@ describe('useGeminiStream', () => {
|
||||||
() => ({ getToolSchemaList: vi.fn(() => []) }) as any,
|
() => ({ getToolSchemaList: vi.fn(() => []) }) as any,
|
||||||
),
|
),
|
||||||
getGeminiClient: mockGetGeminiClient,
|
getGeminiClient: mockGetGeminiClient,
|
||||||
|
addHistory: vi.fn(),
|
||||||
} as unknown as Config;
|
} as unknown as Config;
|
||||||
mockOnDebugMessage = vi.fn();
|
mockOnDebugMessage = vi.fn();
|
||||||
mockHandleSlashCommand = vi.fn().mockReturnValue(false);
|
mockHandleSlashCommand = vi.fn().mockReturnValue(false);
|
||||||
|
@ -294,7 +296,10 @@ describe('useGeminiStream', () => {
|
||||||
.mockReturnValue((async function* () {})());
|
.mockReturnValue((async function* () {})());
|
||||||
});
|
});
|
||||||
|
|
||||||
const renderTestHook = (initialToolCalls: TrackedToolCall[] = []) => {
|
const renderTestHook = (
|
||||||
|
initialToolCalls: TrackedToolCall[] = [],
|
||||||
|
geminiClient?: any,
|
||||||
|
) => {
|
||||||
mockUseReactToolScheduler.mockReturnValue([
|
mockUseReactToolScheduler.mockReturnValue([
|
||||||
initialToolCalls,
|
initialToolCalls,
|
||||||
mockScheduleToolCalls,
|
mockScheduleToolCalls,
|
||||||
|
@ -302,9 +307,11 @@ describe('useGeminiStream', () => {
|
||||||
mockMarkToolsAsSubmitted,
|
mockMarkToolsAsSubmitted,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
const client = geminiClient || mockConfig.getGeminiClient();
|
||||||
|
|
||||||
const { result, rerender } = renderHook(() =>
|
const { result, rerender } = renderHook(() =>
|
||||||
useGeminiStream(
|
useGeminiStream(
|
||||||
mockConfig.getGeminiClient(),
|
client,
|
||||||
mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||||
mockSetShowHelp,
|
mockSetShowHelp,
|
||||||
mockConfig,
|
mockConfig,
|
||||||
|
@ -318,6 +325,7 @@ describe('useGeminiStream', () => {
|
||||||
rerender,
|
rerender,
|
||||||
mockMarkToolsAsSubmitted,
|
mockMarkToolsAsSubmitted,
|
||||||
mockSendMessageStream,
|
mockSendMessageStream,
|
||||||
|
client,
|
||||||
// mockFilter removed
|
// mockFilter removed
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
@ -444,4 +452,44 @@ describe('useGeminiStream', () => {
|
||||||
expect.any(AbortSignal),
|
expect.any(AbortSignal),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should handle all tool calls being cancelled', async () => {
|
||||||
|
const toolCalls: TrackedToolCall[] = [
|
||||||
|
{
|
||||||
|
request: { callId: '1', name: 'testTool', args: {} },
|
||||||
|
status: 'cancelled',
|
||||||
|
response: {
|
||||||
|
callId: '1',
|
||||||
|
responseParts: [{ text: 'cancelled' }],
|
||||||
|
error: undefined,
|
||||||
|
resultDisplay: 'Tool 1 cancelled display',
|
||||||
|
},
|
||||||
|
responseSubmittedToGemini: false,
|
||||||
|
tool: {
|
||||||
|
name: 'testTool',
|
||||||
|
description: 'desc',
|
||||||
|
getDescription: vi.fn(),
|
||||||
|
} as any,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const client = new MockedGeminiClientClass(mockConfig);
|
||||||
|
const { mockMarkToolsAsSubmitted, rerender } = renderTestHook(
|
||||||
|
toolCalls,
|
||||||
|
client,
|
||||||
|
);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
rerender({} as any);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']);
|
||||||
|
expect(client.addHistory).toHaveBeenCalledTimes(2);
|
||||||
|
expect(client.addHistory).toHaveBeenCalledWith({
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: 'cancelled' }],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -19,7 +19,7 @@ import {
|
||||||
ToolCallRequestInfo,
|
ToolCallRequestInfo,
|
||||||
logUserPrompt,
|
logUserPrompt,
|
||||||
} from '@gemini-cli/core';
|
} from '@gemini-cli/core';
|
||||||
import { type PartListUnion } from '@google/genai';
|
import { type Part, type PartListUnion } from '@google/genai';
|
||||||
import {
|
import {
|
||||||
StreamingState,
|
StreamingState,
|
||||||
HistoryItemWithoutId,
|
HistoryItemWithoutId,
|
||||||
|
@ -531,6 +531,41 @@ export const useGeminiStream = (
|
||||||
completedAndReadyToSubmitTools.length > 0 &&
|
completedAndReadyToSubmitTools.length > 0 &&
|
||||||
completedAndReadyToSubmitTools.length === toolCalls.length
|
completedAndReadyToSubmitTools.length === toolCalls.length
|
||||||
) {
|
) {
|
||||||
|
// If all the tools were cancelled, don't submit a response to Gemini.
|
||||||
|
const allToolsCancelled = completedAndReadyToSubmitTools.every(
|
||||||
|
(tc) => tc.status === 'cancelled',
|
||||||
|
);
|
||||||
|
|
||||||
|
if (allToolsCancelled) {
|
||||||
|
if (geminiClient) {
|
||||||
|
// We need to manually add the function responses to the history
|
||||||
|
// so the model knows the tools were cancelled.
|
||||||
|
const responsesToAdd = completedAndReadyToSubmitTools.flatMap(
|
||||||
|
(toolCall) => toolCall.response.responseParts,
|
||||||
|
);
|
||||||
|
for (const response of responsesToAdd) {
|
||||||
|
let parts: Part[];
|
||||||
|
if (Array.isArray(response)) {
|
||||||
|
parts = response;
|
||||||
|
} else if (typeof response === 'string') {
|
||||||
|
parts = [{ text: response }];
|
||||||
|
} else {
|
||||||
|
parts = [response];
|
||||||
|
}
|
||||||
|
geminiClient.addHistory({
|
||||||
|
role: 'user',
|
||||||
|
parts,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map(
|
||||||
|
(toolCall) => toolCall.request.callId,
|
||||||
|
);
|
||||||
|
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const responsesToSend: PartListUnion[] =
|
const responsesToSend: PartListUnion[] =
|
||||||
completedAndReadyToSubmitTools.map(
|
completedAndReadyToSubmitTools.map(
|
||||||
(toolCall) => toolCall.response.responseParts,
|
(toolCall) => toolCall.response.responseParts,
|
||||||
|
@ -542,7 +577,14 @@ export const useGeminiStream = (
|
||||||
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
|
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
|
||||||
submitQuery(mergePartListUnions(responsesToSend));
|
submitQuery(mergePartListUnions(responsesToSend));
|
||||||
}
|
}
|
||||||
}, [toolCalls, isResponding, submitQuery, markToolsAsSubmitted, addItem]);
|
}, [
|
||||||
|
toolCalls,
|
||||||
|
isResponding,
|
||||||
|
submitQuery,
|
||||||
|
markToolsAsSubmitted,
|
||||||
|
addItem,
|
||||||
|
geminiClient,
|
||||||
|
]);
|
||||||
|
|
||||||
const pendingHistoryItems = [
|
const pendingHistoryItems = [
|
||||||
pendingHistoryItemRef.current,
|
pendingHistoryItemRef.current,
|
||||||
|
|
|
@ -219,4 +219,22 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('addHistory', () => {
|
||||||
|
it('should call chat.addHistory with the provided content', async () => {
|
||||||
|
const mockChat = {
|
||||||
|
addHistory: vi.fn(),
|
||||||
|
};
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
client['chat'] = Promise.resolve(mockChat as any);
|
||||||
|
|
||||||
|
const newContent = {
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: 'New history item' }],
|
||||||
|
};
|
||||||
|
await client.addHistory(newContent);
|
||||||
|
|
||||||
|
expect(mockChat.addHistory).toHaveBeenCalledWith(newContent);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -58,6 +58,11 @@ export class GeminiClient {
|
||||||
this.chat = this.startChat();
|
this.chat = this.startChat();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async addHistory(content: Content) {
|
||||||
|
const chat = await this.chat;
|
||||||
|
chat.addHistory(content);
|
||||||
|
}
|
||||||
|
|
||||||
getChat(): Promise<GeminiChat> {
|
getChat(): Promise<GeminiChat> {
|
||||||
return this.chat;
|
return this.chat;
|
||||||
}
|
}
|
||||||
|
|
|
@ -352,4 +352,34 @@ describe('GeminiChat', () => {
|
||||||
expect(history[1].parts).toEqual([{ text: 'Visible text' }]);
|
expect(history[1].parts).toEqual([{ text: 'Visible text' }]);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('addHistory', () => {
|
||||||
|
it('should add a new content item to the history', () => {
|
||||||
|
const newContent: Content = {
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: 'A new message' }],
|
||||||
|
};
|
||||||
|
chat.addHistory(newContent);
|
||||||
|
const history = chat.getHistory();
|
||||||
|
expect(history.length).toBe(1);
|
||||||
|
expect(history[0]).toEqual(newContent);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should add multiple items correctly', () => {
|
||||||
|
const content1: Content = {
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: 'Message 1' }],
|
||||||
|
};
|
||||||
|
const content2: Content = {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Message 2' }],
|
||||||
|
};
|
||||||
|
chat.addHistory(content1);
|
||||||
|
chat.addHistory(content2);
|
||||||
|
const history = chat.getHistory();
|
||||||
|
expect(history.length).toBe(2);
|
||||||
|
expect(history[0]).toEqual(content1);
|
||||||
|
expect(history[1]).toEqual(content2);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -287,6 +287,15 @@ export class GeminiChat {
|
||||||
return structuredClone(history);
|
return structuredClone(history);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Adds a new entry to the chat history.
|
||||||
|
*
|
||||||
|
* @param content - The content to add to the history.
|
||||||
|
*/
|
||||||
|
addHistory(content: Content): void {
|
||||||
|
this.history.push(content);
|
||||||
|
}
|
||||||
|
|
||||||
private async *processStreamResponse(
|
private async *processStreamResponse(
|
||||||
streamResponse: AsyncGenerator<GenerateContentResponse>,
|
streamResponse: AsyncGenerator<GenerateContentResponse>,
|
||||||
inputContent: Content,
|
inputContent: Content,
|
||||||
|
|
Loading…
Reference in New Issue