fix: Ensure all tool calls are complete before submitting responses (#689)
This commit is contained in:
parent
27ba28ef76
commit
34b81abd9c
|
@ -4,19 +4,102 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { mergePartListUnions } from './useGeminiStream.js';
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
|
||||
import { renderHook, act, waitFor } from '@testing-library/react';
|
||||
import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js';
|
||||
import {
|
||||
useReactToolScheduler,
|
||||
TrackedToolCall,
|
||||
TrackedCompletedToolCall,
|
||||
TrackedExecutingToolCall,
|
||||
TrackedCancelledToolCall,
|
||||
} from './useReactToolScheduler.js';
|
||||
import { Config } from '@gemini-code/core';
|
||||
import { Part, PartListUnion } from '@google/genai';
|
||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
|
||||
// Mock useToolScheduler
|
||||
vi.mock('./useReactToolScheduler', async () => {
|
||||
const actual = await vi.importActual('./useReactToolScheduler');
|
||||
// --- MOCKS ---
|
||||
const mockSendMessageStream = vi
|
||||
.fn()
|
||||
.mockReturnValue((async function* () {})());
|
||||
const mockStartChat = vi.fn();
|
||||
|
||||
vi.mock('@gemini-code/core', async (importOriginal) => {
|
||||
const actualCoreModule = (await importOriginal()) as any;
|
||||
const MockedGeminiClientClass = vi.fn().mockImplementation(function (
|
||||
this: any,
|
||||
_config: any,
|
||||
) {
|
||||
// _config
|
||||
this.startChat = mockStartChat;
|
||||
this.sendMessageStream = mockSendMessageStream;
|
||||
});
|
||||
return {
|
||||
...actual, // We need mapToDisplay from actual
|
||||
...(actualCoreModule || {}),
|
||||
GeminiClient: MockedGeminiClientClass,
|
||||
// GeminiChat will be from actualCoreModule if it exists, otherwise undefined
|
||||
};
|
||||
});
|
||||
|
||||
const mockUseReactToolScheduler = useReactToolScheduler as Mock;
|
||||
vi.mock('./useReactToolScheduler.js', async (importOriginal) => {
|
||||
const actualSchedulerModule = (await importOriginal()) as any;
|
||||
return {
|
||||
...(actualSchedulerModule || {}),
|
||||
useReactToolScheduler: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('ink', async (importOriginal) => {
|
||||
const actualInkModule = (await importOriginal()) as any;
|
||||
return { ...(actualInkModule || {}), useInput: vi.fn() };
|
||||
});
|
||||
|
||||
vi.mock('./shellCommandProcessor.js', () => ({
|
||||
useShellCommandProcessor: vi.fn().mockReturnValue({
|
||||
handleShellCommand: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock('./atCommandProcessor.js', () => ({
|
||||
handleAtCommand: vi
|
||||
.fn()
|
||||
.mockResolvedValue({ shouldProceed: true, processedQuery: 'mocked' }),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/markdownUtilities.js', () => ({
|
||||
findLastSafeSplitPoint: vi.fn((s: string) => s.length),
|
||||
}));
|
||||
|
||||
vi.mock('./useStateAndRef.js', () => ({
|
||||
useStateAndRef: vi.fn((initial) => {
|
||||
let val = initial;
|
||||
const ref = { current: val };
|
||||
const setVal = vi.fn((updater) => {
|
||||
if (typeof updater === 'function') {
|
||||
val = updater(val);
|
||||
} else {
|
||||
val = updater;
|
||||
}
|
||||
ref.current = val;
|
||||
});
|
||||
return [ref, setVal];
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock('./useLogger.js', () => ({
|
||||
useLogger: vi.fn().mockReturnValue({
|
||||
logMessage: vi.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock('./slashCommandProcessor.js', () => ({
|
||||
handleSlashCommand: vi.fn().mockReturnValue(false),
|
||||
}));
|
||||
|
||||
// --- END MOCKS ---
|
||||
|
||||
describe('mergePartListUnions', () => {
|
||||
it('should merge multiple PartListUnion arrays', () => {
|
||||
const list1: PartListUnion = [{ text: 'Hello' }];
|
||||
|
@ -135,3 +218,222 @@ describe('mergePartListUnions', () => {
|
|||
]);
|
||||
});
|
||||
});
|
||||
|
||||
// --- Tests for useGeminiStream Hook ---
|
||||
describe('useGeminiStream', () => {
|
||||
let mockAddItem: Mock;
|
||||
let mockSetShowHelp: Mock;
|
||||
let mockConfig: Config;
|
||||
let mockOnDebugMessage: Mock;
|
||||
let mockHandleSlashCommand: Mock;
|
||||
let mockScheduleToolCalls: Mock;
|
||||
let mockCancelAllToolCalls: Mock;
|
||||
let mockMarkToolsAsSubmitted: Mock;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks(); // Clear mocks before each test
|
||||
|
||||
mockAddItem = vi.fn();
|
||||
mockSetShowHelp = vi.fn();
|
||||
mockConfig = {
|
||||
apiKey: 'test-api-key',
|
||||
model: 'gemini-pro',
|
||||
sandbox: false,
|
||||
targetDir: '/test/dir',
|
||||
debugMode: false,
|
||||
question: undefined,
|
||||
fullContext: false,
|
||||
coreTools: [],
|
||||
toolDiscoveryCommand: undefined,
|
||||
toolCallCommand: undefined,
|
||||
mcpServerCommand: undefined,
|
||||
mcpServers: undefined,
|
||||
userAgent: 'test-agent',
|
||||
userMemory: '',
|
||||
geminiMdFileCount: 0,
|
||||
alwaysSkipModificationConfirmation: false,
|
||||
vertexai: false,
|
||||
showMemoryUsage: false,
|
||||
contextFileName: undefined,
|
||||
getToolRegistry: vi.fn(
|
||||
() => ({ getToolSchemaList: vi.fn(() => []) }) as any,
|
||||
),
|
||||
} as unknown as Config;
|
||||
mockOnDebugMessage = vi.fn();
|
||||
mockHandleSlashCommand = vi.fn().mockReturnValue(false);
|
||||
|
||||
// Mock return value for useReactToolScheduler
|
||||
mockScheduleToolCalls = vi.fn();
|
||||
mockCancelAllToolCalls = vi.fn();
|
||||
mockMarkToolsAsSubmitted = vi.fn();
|
||||
|
||||
// Default mock for useReactToolScheduler to prevent toolCalls being undefined initially
|
||||
mockUseReactToolScheduler.mockReturnValue([
|
||||
[], // Default to empty array for toolCalls
|
||||
mockScheduleToolCalls,
|
||||
mockCancelAllToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
]);
|
||||
|
||||
// Reset mocks for GeminiClient instance methods (startChat and sendMessageStream)
|
||||
// The GeminiClient constructor itself is mocked at the module level.
|
||||
mockStartChat.mockClear().mockResolvedValue({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
} as unknown as any); // GeminiChat -> any
|
||||
mockSendMessageStream
|
||||
.mockClear()
|
||||
.mockReturnValue((async function* () {})());
|
||||
});
|
||||
|
||||
const renderTestHook = (initialToolCalls: TrackedToolCall[] = []) => {
|
||||
mockUseReactToolScheduler.mockReturnValue([
|
||||
initialToolCalls,
|
||||
mockScheduleToolCalls,
|
||||
mockCancelAllToolCalls,
|
||||
mockMarkToolsAsSubmitted,
|
||||
]);
|
||||
|
||||
const { result, rerender } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
mockSetShowHelp,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false, // shellModeActive
|
||||
),
|
||||
);
|
||||
return {
|
||||
result,
|
||||
rerender,
|
||||
mockMarkToolsAsSubmitted,
|
||||
mockSendMessageStream,
|
||||
// mockFilter removed
|
||||
};
|
||||
};
|
||||
|
||||
it('should not submit tool responses if not all tool calls are completed', () => {
|
||||
const toolCalls: TrackedToolCall[] = [
|
||||
{
|
||||
request: { callId: 'call1', name: 'tool1', args: {} },
|
||||
status: 'success',
|
||||
responseSubmittedToGemini: false,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
responseParts: [{ text: 'tool 1 response' }],
|
||||
error: undefined,
|
||||
resultDisplay: 'Tool 1 success display',
|
||||
},
|
||||
tool: {
|
||||
name: 'tool1',
|
||||
description: 'desc1',
|
||||
getDescription: vi.fn(),
|
||||
} as any,
|
||||
startTime: Date.now(),
|
||||
endTime: Date.now(),
|
||||
} as TrackedCompletedToolCall,
|
||||
{
|
||||
request: { callId: 'call2', name: 'tool2', args: {} },
|
||||
status: 'executing',
|
||||
responseSubmittedToGemini: false,
|
||||
tool: {
|
||||
name: 'tool2',
|
||||
description: 'desc2',
|
||||
getDescription: vi.fn(),
|
||||
} as any,
|
||||
startTime: Date.now(),
|
||||
liveOutput: '...',
|
||||
} as TrackedExecutingToolCall,
|
||||
];
|
||||
|
||||
const { mockMarkToolsAsSubmitted, mockSendMessageStream } =
|
||||
renderTestHook(toolCalls);
|
||||
|
||||
// Effect for submitting tool responses depends on toolCalls and isResponding
|
||||
// isResponding is initially false, so the effect should run.
|
||||
|
||||
expect(mockMarkToolsAsSubmitted).not.toHaveBeenCalled();
|
||||
expect(mockSendMessageStream).not.toHaveBeenCalled(); // submitQuery uses this
|
||||
});
|
||||
|
||||
it('should submit tool responses when all tool calls are completed and ready', async () => {
|
||||
const toolCall1ResponseParts: PartListUnion = [
|
||||
{ text: 'tool 1 final response' },
|
||||
];
|
||||
const toolCall2ResponseParts: PartListUnion = [
|
||||
{ text: 'tool 2 final response' },
|
||||
];
|
||||
|
||||
// Simplified toolCalls to ensure the filter logic is the focus
|
||||
const simplifiedToolCalls: TrackedToolCall[] = [
|
||||
{
|
||||
request: { callId: 'call1', name: 'tool1', args: {} },
|
||||
status: 'success',
|
||||
responseSubmittedToGemini: false,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
responseParts: toolCall1ResponseParts,
|
||||
error: undefined,
|
||||
resultDisplay: 'Tool 1 success display',
|
||||
},
|
||||
tool: {
|
||||
name: 'tool1',
|
||||
description: 'desc',
|
||||
getDescription: vi.fn(),
|
||||
} as any,
|
||||
startTime: Date.now(),
|
||||
endTime: Date.now(),
|
||||
} as TrackedCompletedToolCall,
|
||||
{
|
||||
request: { callId: 'call2', name: 'tool2', args: {} },
|
||||
status: 'cancelled',
|
||||
responseSubmittedToGemini: false,
|
||||
response: {
|
||||
callId: 'call2',
|
||||
responseParts: toolCall2ResponseParts,
|
||||
error: undefined,
|
||||
resultDisplay: 'Tool 2 cancelled display',
|
||||
},
|
||||
tool: {
|
||||
name: 'tool2',
|
||||
description: 'desc',
|
||||
getDescription: vi.fn(),
|
||||
} as any,
|
||||
startTime: Date.now(),
|
||||
endTime: Date.now(),
|
||||
reason: 'test cancellation',
|
||||
} as TrackedCancelledToolCall,
|
||||
];
|
||||
|
||||
let hookResult: any;
|
||||
await act(async () => {
|
||||
hookResult = renderTestHook(simplifiedToolCalls);
|
||||
});
|
||||
|
||||
const {
|
||||
mockMarkToolsAsSubmitted,
|
||||
mockSendMessageStream: localMockSendMessageStream,
|
||||
} = hookResult!;
|
||||
|
||||
// It seems the initial render + effect run should be enough.
|
||||
// If rerender was for a specific state change, it might still be needed.
|
||||
// For now, let's test if the initial effect run (covered by the first act) is sufficient.
|
||||
// If not, we can add back: await act(async () => { rerender({}); });
|
||||
|
||||
expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['call1', 'call2']);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(localMockSendMessageStream).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
const expectedMergedResponse = mergePartListUnions([
|
||||
toolCall1ResponseParts,
|
||||
toolCall2ResponseParts,
|
||||
]);
|
||||
expect(localMockSendMessageStream).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expectedMergedResponse,
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -530,7 +530,10 @@ export const useGeminiStream = (
|
|||
},
|
||||
);
|
||||
|
||||
if (completedAndReadyToSubmitTools.length > 0) {
|
||||
if (
|
||||
completedAndReadyToSubmitTools.length > 0 &&
|
||||
completedAndReadyToSubmitTools.length === toolCalls.length
|
||||
) {
|
||||
const responsesToSend: PartListUnion[] =
|
||||
completedAndReadyToSubmitTools.map(
|
||||
(toolCall) => toolCall.response.responseParts,
|
||||
|
|
Loading…
Reference in New Issue