fix: Ensure all tool calls are complete before submitting responses (#689)

This commit is contained in:
N. Taylor Mullen 2025-06-02 01:50:28 -07:00 committed by GitHub
parent 27ba28ef76
commit 34b81abd9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 312 additions and 7 deletions

View File

@ -4,19 +4,102 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import { mergePartListUnions } from './useGeminiStream.js';
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
import { renderHook, act, waitFor } from '@testing-library/react';
import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js';
import {
useReactToolScheduler,
TrackedToolCall,
TrackedCompletedToolCall,
TrackedExecutingToolCall,
TrackedCancelledToolCall,
} from './useReactToolScheduler.js';
import { Config } from '@gemini-code/core';
import { Part, PartListUnion } from '@google/genai';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
// Mock useToolScheduler
vi.mock('./useReactToolScheduler', async () => {
const actual = await vi.importActual('./useReactToolScheduler');
// --- MOCKS ---
const mockSendMessageStream = vi
.fn()
.mockReturnValue((async function* () {})());
const mockStartChat = vi.fn();
vi.mock('@gemini-code/core', async (importOriginal) => {
const actualCoreModule = (await importOriginal()) as any;
const MockedGeminiClientClass = vi.fn().mockImplementation(function (
this: any,
_config: any,
) {
// _config
this.startChat = mockStartChat;
this.sendMessageStream = mockSendMessageStream;
});
return {
...actual, // We need mapToDisplay from actual
...(actualCoreModule || {}),
GeminiClient: MockedGeminiClientClass,
// GeminiChat will be from actualCoreModule if it exists, otherwise undefined
};
});
// Mock-typed view of the imported scheduler hook, used by the tests to
// program what the hook returns.
const mockUseReactToolScheduler = useReactToolScheduler as Mock;
// Partial mock: every original export of the scheduler module is kept, and
// only `useReactToolScheduler` is swapped for a bare spy.
vi.mock('./useReactToolScheduler.js', async (importOriginal) => {
  const realScheduler = (await importOriginal()) as any;
  const hookSpy = vi.fn();
  return Object.assign({}, realScheduler, { useReactToolScheduler: hookSpy });
});
// Partial mock of `ink`: original exports are preserved, `useInput` becomes a spy.
vi.mock('ink', async (importOriginal) => {
  const realInk = (await importOriginal()) as any;
  const useInput = vi.fn();
  return Object.assign({}, realInk, { useInput });
});
// Full replacement of the shell command processor module. The hook spy always
// returns the same object, whose `handleShellCommand` is itself a spy.
vi.mock('./shellCommandProcessor.js', () => {
  const shellProcessorApi = { handleShellCommand: vi.fn() };
  const useShellCommandProcessor = vi.fn().mockReturnValue(shellProcessorApi);
  return { useShellCommandProcessor };
});
// Full replacement of the @-command processor module: `handleAtCommand`
// resolves to a fixed "proceed" result carrying the query text 'mocked'.
vi.mock('./atCommandProcessor.js', () => {
  const cannedResult = { shouldProceed: true, processedQuery: 'mocked' };
  const handleAtCommand = vi.fn().mockResolvedValue(cannedResult);
  return { handleAtCommand };
});
// Full replacement of the markdown utilities module: the split-point finder
// simply reports the length of the string it is given.
vi.mock('../utils/markdownUtilities.js', () => {
  const findLastSafeSplitPoint = vi.fn((content: string) => content.length);
  return { findLastSafeSplitPoint };
});
// Full replacement of `useStateAndRef` with a React-free stand-in.
// Each call builds a fresh `{ current }` holder seeded with `initial` and a spy
// setter; the pair is returned as `[ref, setter]`.
vi.mock('./useStateAndRef.js', () => {
  const useStateAndRef = vi.fn((initial) => {
    // Tracked separately from `ref.current` so functional updaters always
    // receive the last value written through the setter.
    let latest = initial;
    const ref = { current: latest };
    const setVal = vi.fn((updater) => {
      // Accept either a plain value or an updater function of the previous value.
      latest = typeof updater === 'function' ? updater(latest) : updater;
      ref.current = latest;
    });
    return [ref, setVal];
  });
  return { useStateAndRef };
});
// Full replacement of the logger hook: it always returns the same logger
// object, whose `logMessage` spy resolves to `undefined`.
vi.mock('./useLogger.js', () => {
  const loggerStub = { logMessage: vi.fn().mockResolvedValue(undefined) };
  const useLogger = vi.fn().mockReturnValue(loggerStub);
  return { useLogger };
});
// Full replacement of the slash command processor module: the handler spy
// always returns `false`.
vi.mock('./slashCommandProcessor.js', () => {
  const handleSlashCommand = vi.fn().mockReturnValue(false);
  return { handleSlashCommand };
});
// --- END MOCKS ---
describe('mergePartListUnions', () => {
it('should merge multiple PartListUnion arrays', () => {
const list1: PartListUnion = [{ text: 'Hello' }];
@ -135,3 +218,222 @@ describe('mergePartListUnions', () => {
]);
});
});
// --- Tests for useGeminiStream Hook ---
describe('useGeminiStream', () => {
let mockAddItem: Mock;
let mockSetShowHelp: Mock;
let mockConfig: Config;
let mockOnDebugMessage: Mock;
let mockHandleSlashCommand: Mock;
let mockScheduleToolCalls: Mock;
let mockCancelAllToolCalls: Mock;
let mockMarkToolsAsSubmitted: Mock;
beforeEach(() => {
  // Drop call history recorded by earlier tests before building new doubles.
  vi.clearAllMocks();

  // Fresh spies for every callback handed to the hook or returned by the
  // mocked scheduler.
  mockAddItem = vi.fn();
  mockSetShowHelp = vi.fn();
  mockOnDebugMessage = vi.fn();
  mockHandleSlashCommand = vi.fn().mockReturnValue(false);
  mockScheduleToolCalls = vi.fn();
  mockCancelAllToolCalls = vi.fn();
  mockMarkToolsAsSubmitted = vi.fn();

  // Plain-object stand-in for Config; the tool registry exposes an empty
  // schema list.
  const getToolRegistry = vi.fn(
    () => ({ getToolSchemaList: vi.fn(() => []) }) as any,
  );
  mockConfig = {
    apiKey: 'test-api-key',
    model: 'gemini-pro',
    sandbox: false,
    targetDir: '/test/dir',
    debugMode: false,
    question: undefined,
    fullContext: false,
    coreTools: [],
    toolDiscoveryCommand: undefined,
    toolCallCommand: undefined,
    mcpServerCommand: undefined,
    mcpServers: undefined,
    userAgent: 'test-agent',
    userMemory: '',
    geminiMdFileCount: 0,
    alwaysSkipModificationConfirmation: false,
    vertexai: false,
    showMemoryUsage: false,
    contextFileName: undefined,
    getToolRegistry,
  } as unknown as Config;

  // Default scheduler tuple: no tool calls, so `toolCalls` is never
  // undefined on first render. renderTestHook may override this.
  const defaultSchedulerTuple = [
    [],
    mockScheduleToolCalls,
    mockCancelAllToolCalls,
    mockMarkToolsAsSubmitted,
  ];
  mockUseReactToolScheduler.mockReturnValue(defaultSchedulerTuple);

  // Re-arm the module-level client spies: startChat resolves to a chat object
  // exposing the shared sendMessageStream spy, which returns an empty async
  // stream.
  mockStartChat.mockClear();
  mockStartChat.mockResolvedValue({
    sendMessageStream: mockSendMessageStream,
  } as any);

  const emptyStream = (async function* () {})();
  mockSendMessageStream.mockClear();
  mockSendMessageStream.mockReturnValue(emptyStream);
});
/**
 * Renders `useGeminiStream` with the per-test spies, after priming the mocked
 * scheduler hook to report `initialToolCalls` as its tracked tool calls.
 *
 * @param initialToolCalls Tool calls the mocked scheduler should expose
 *   (defaults to none).
 * @returns The hook's `result` and `rerender` handles plus the spies the
 *   tests assert on.
 */
const renderTestHook = (initialToolCalls: TrackedToolCall[] = []) => {
  const schedulerTuple = [
    initialToolCalls,
    mockScheduleToolCalls,
    mockCancelAllToolCalls,
    mockMarkToolsAsSubmitted,
  ];
  mockUseReactToolScheduler.mockReturnValue(schedulerTuple);

  const rendered = renderHook(() => {
    const addItem = mockAddItem as unknown as UseHistoryManagerReturn['addItem'];
    const shellModeActive = false;
    return useGeminiStream(
      addItem,
      mockSetShowHelp,
      mockConfig,
      mockOnDebugMessage,
      mockHandleSlashCommand,
      shellModeActive,
    );
  });

  return {
    result: rendered.result,
    rerender: rendered.rerender,
    mockMarkToolsAsSubmitted,
    mockSendMessageStream,
  };
};
it('should not submit tool responses if not all tool calls are completed', () => {
  // First call has finished successfully and has a response ready to send.
  const finishedCall = {
    request: { callId: 'call1', name: 'tool1', args: {} },
    status: 'success',
    responseSubmittedToGemini: false,
    response: {
      callId: 'call1',
      responseParts: [{ text: 'tool 1 response' }],
      error: undefined,
      resultDisplay: 'Tool 1 success display',
    },
    tool: {
      name: 'tool1',
      description: 'desc1',
      getDescription: vi.fn(),
    } as any,
    startTime: Date.now(),
    endTime: Date.now(),
  } as TrackedCompletedToolCall;

  // Second call is still running, so the batch as a whole is incomplete.
  const runningCall = {
    request: { callId: 'call2', name: 'tool2', args: {} },
    status: 'executing',
    responseSubmittedToGemini: false,
    tool: {
      name: 'tool2',
      description: 'desc2',
      getDescription: vi.fn(),
    } as any,
    startTime: Date.now(),
    liveOutput: '...',
  } as TrackedExecutingToolCall;

  const pendingBatch: TrackedToolCall[] = [finishedCall, runningCall];
  const harness = renderTestHook(pendingBatch);

  // With one call still executing, nothing may be marked as submitted and no
  // follow-up message may be streamed to the model.
  expect(harness.mockMarkToolsAsSubmitted).not.toHaveBeenCalled();
  expect(harness.mockSendMessageStream).not.toHaveBeenCalled();
});
it('should submit tool responses when all tool calls are completed and ready', async () => {
  const toolCall1ResponseParts: PartListUnion = [
    { text: 'tool 1 final response' },
  ];
  const toolCall2ResponseParts: PartListUnion = [
    { text: 'tool 2 final response' },
  ];

  // One successful call and one cancelled call: both are in a final state
  // and neither has been submitted yet.
  const succeededCall = {
    request: { callId: 'call1', name: 'tool1', args: {} },
    status: 'success',
    responseSubmittedToGemini: false,
    response: {
      callId: 'call1',
      responseParts: toolCall1ResponseParts,
      error: undefined,
      resultDisplay: 'Tool 1 success display',
    },
    tool: {
      name: 'tool1',
      description: 'desc',
      getDescription: vi.fn(),
    } as any,
    startTime: Date.now(),
    endTime: Date.now(),
  } as TrackedCompletedToolCall;

  const cancelledCall = {
    request: { callId: 'call2', name: 'tool2', args: {} },
    status: 'cancelled',
    responseSubmittedToGemini: false,
    response: {
      callId: 'call2',
      responseParts: toolCall2ResponseParts,
      error: undefined,
      resultDisplay: 'Tool 2 cancelled display',
    },
    tool: {
      name: 'tool2',
      description: 'desc',
      getDescription: vi.fn(),
    } as any,
    startTime: Date.now(),
    endTime: Date.now(),
    reason: 'test cancellation',
  } as TrackedCancelledToolCall;

  const finishedBatch: TrackedToolCall[] = [succeededCall, cancelledCall];

  // Render inside an async act so effects triggered by the initial render
  // are flushed before asserting.
  let harness: any;
  await act(async () => {
    harness = renderTestHook(finishedBatch);
  });
  const markSubmittedSpy = harness.mockMarkToolsAsSubmitted;
  const sendStreamSpy = harness.mockSendMessageStream;

  // Both call ids are flagged as submitted in one go.
  expect(markSubmittedSpy).toHaveBeenCalledWith(['call1', 'call2']);

  // The follow-up request is issued asynchronously, so poll for it.
  await waitFor(() => {
    expect(sendStreamSpy).toHaveBeenCalledTimes(1);
  });

  // The payload sent back must be the merge of both calls' response parts.
  const expectedMergedResponse = mergePartListUnions([
    toolCall1ResponseParts,
    toolCall2ResponseParts,
  ]);
  expect(sendStreamSpy).toHaveBeenCalledWith(
    expect.anything(),
    expectedMergedResponse,
    expect.anything(),
  );
});
});

View File

@ -530,7 +530,10 @@ export const useGeminiStream = (
},
);
if (completedAndReadyToSubmitTools.length > 0) {
if (
completedAndReadyToSubmitTools.length > 0 &&
completedAndReadyToSubmitTools.length === toolCalls.length
) {
const responsesToSend: PartListUnion[] =
completedAndReadyToSubmitTools.map(
(toolCall) => toolCall.response.responseParts,