/** * @license * Copyright 2025 Google LLC * SPDX-License-Identifier: Apache-2.0 */ /* eslint-disable @typescript-eslint/no-explicit-any */ import { describe, it, expect, vi, beforeEach, Mock } from 'vitest'; import { renderHook, act, waitFor } from '@testing-library/react'; import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js'; import { useReactToolScheduler, TrackedToolCall, TrackedCompletedToolCall, TrackedExecutingToolCall, TrackedCancelledToolCall, } from './useReactToolScheduler.js'; import { Config } from '@gemini-cli/core'; import { Part, PartListUnion } from '@google/genai'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; // --- MOCKS --- const mockSendMessageStream = vi .fn() .mockReturnValue((async function* () {})()); const mockStartChat = vi.fn(); const MockedGeminiClientClass = vi.hoisted(() => vi.fn().mockImplementation(function (this: any, _config: any) { // _config this.startChat = mockStartChat; this.sendMessageStream = mockSendMessageStream; this.addHistory = vi.fn(); }), ); vi.mock('@gemini-cli/core', async (importOriginal) => { const actualCoreModule = (await importOriginal()) as any; return { ...(actualCoreModule || {}), GeminiClient: MockedGeminiClientClass, // Export the class for type checking or other direct uses Config: actualCoreModule.Config, // Ensure Config is passed through }; }); const mockUseReactToolScheduler = useReactToolScheduler as Mock; vi.mock('./useReactToolScheduler.js', async (importOriginal) => { const actualSchedulerModule = (await importOriginal()) as any; return { ...(actualSchedulerModule || {}), useReactToolScheduler: vi.fn(), }; }); vi.mock('ink', async (importOriginal) => { const actualInkModule = (await importOriginal()) as any; return { ...(actualInkModule || {}), useInput: vi.fn() }; }); vi.mock('./shellCommandProcessor.js', () => ({ useShellCommandProcessor: vi.fn().mockReturnValue({ handleShellCommand: vi.fn(), }), })); vi.mock('./atCommandProcessor.js', () => ({ handleAtCommand: vi .fn() .mockResolvedValue({ shouldProceed: true, processedQuery: 'mocked' }), })); vi.mock('../utils/markdownUtilities.js', () => ({ findLastSafeSplitPoint: vi.fn((s: string) => s.length), })); vi.mock('./useStateAndRef.js', () => ({ useStateAndRef: vi.fn((initial) => { let val = initial; const ref = { current: val }; const setVal = vi.fn((updater) => { if (typeof updater === 'function') { val = updater(val); } else { val = updater; } ref.current = val; }); return [ref, setVal]; }), })); vi.mock('./useLogger.js', () => ({ useLogger: vi.fn().mockReturnValue({ logMessage: vi.fn().mockResolvedValue(undefined), }), })); vi.mock('./slashCommandProcessor.js', () => ({ handleSlashCommand: vi.fn().mockReturnValue(false), })); // --- END MOCKS --- describe('mergePartListUnions', () => { it('should merge multiple PartListUnion arrays', () => { const list1: PartListUnion = [{ text: 'Hello' }]; const list2: PartListUnion = [ { inlineData: { mimeType: 'image/png', data: 'abc' } }, ]; const list3: PartListUnion = [{ text: 'World' }, { text: '!' }]; const result = mergePartListUnions([list1, list2, list3]); expect(result).toEqual([ { text: 'Hello' }, { inlineData: { mimeType: 'image/png', data: 'abc' } }, { text: 'World' }, { text: '!' }, ]); }); it('should handle empty arrays in the input list', () => { const list1: PartListUnion = [{ text: 'First' }]; const list2: PartListUnion = []; const list3: PartListUnion = [{ text: 'Last' }]; const result = mergePartListUnions([list1, list2, list3]); expect(result).toEqual([{ text: 'First' }, { text: 'Last' }]); }); it('should handle a single PartListUnion array', () => { const list1: PartListUnion = [ { text: 'One' }, { inlineData: { mimeType: 'image/jpeg', data: 'xyz' } }, ]; const result = mergePartListUnions([list1]); expect(result).toEqual(list1); }); it('should return an empty array if all input arrays are empty', () => { const list1: PartListUnion = []; const list2: PartListUnion = []; const result = mergePartListUnions([list1, list2]); expect(result).toEqual([]); }); it('should handle input list being empty', () => { const result = mergePartListUnions([]); expect(result).toEqual([]); }); it('should correctly merge when PartListUnion items are single Parts not in arrays', () => { const part1: Part = { text: 'Single part 1' }; const part2: Part = { inlineData: { mimeType: 'image/gif', data: 'gif' } }; const listContainingSingleParts: PartListUnion[] = [ part1, [part2], { text: 'Another single part' }, ]; const result = mergePartListUnions(listContainingSingleParts); expect(result).toEqual([ { text: 'Single part 1' }, { inlineData: { mimeType: 'image/gif', data: 'gif' } }, { text: 'Another single part' }, ]); }); it('should handle a mix of arrays and single parts, including empty arrays and undefined/null parts if they were possible (though PartListUnion typing restricts this)', () => { const list1: PartListUnion = [{ text: 'A' }]; const list2: PartListUnion = []; const part3: Part = { text: 'B' }; const list4: PartListUnion = [ { text: 'C' }, { inlineData: { mimeType: 'text/plain', data: 'D' } }, ]; const result = mergePartListUnions([list1, list2, part3, list4]); expect(result).toEqual([ { text: 'A' }, { text: 'B' }, { text: 'C' }, { inlineData: { mimeType: 'text/plain', data: 'D' } }, ]); }); it('should preserve the order of parts from the input arrays', () => { const listA: PartListUnion = [{ text: '1' }, { text: '2' }]; const listB: PartListUnion = [{ text: '3' }]; const listC: PartListUnion = [{ text: '4' }, { text: '5' }]; const result = mergePartListUnions([listA, listB, listC]); expect(result).toEqual([ { text: '1' }, { text: '2' }, { text: '3' }, { text: '4' }, { text: '5' }, ]); }); it('should handle cases where some PartListUnion items are single Parts and others are arrays of Parts', () => { const singlePart1: Part = { text: 'First single' }; const arrayPart1: Part[] = [ { text: 'Array item 1' }, { text: 'Array item 2' }, ]; const singlePart2: Part = { inlineData: { mimeType: 'application/json', data: 'e30=' }, }; // {} const arrayPart2: Part[] = [{ text: 'Last array item' }]; const result = mergePartListUnions([ singlePart1, arrayPart1, singlePart2, arrayPart2, ]); expect(result).toEqual([ { text: 'First single' }, { text: 'Array item 1' }, { text: 'Array item 2' }, { inlineData: { mimeType: 'application/json', data: 'e30=' } }, { text: 'Last array item' }, ]); }); }); // --- Tests for useGeminiStream Hook --- describe('useGeminiStream', () => { let mockAddItem: Mock; let mockSetShowHelp: Mock; let mockConfig: Config; let mockOnDebugMessage: Mock; let mockHandleSlashCommand: Mock; let mockScheduleToolCalls: Mock; let mockCancelAllToolCalls: Mock; let mockMarkToolsAsSubmitted: Mock; beforeEach(() => { vi.clearAllMocks(); // Clear mocks before each test mockAddItem = vi.fn(); mockSetShowHelp = vi.fn(); // Define the mock for getGeminiClient const mockGetGeminiClient = vi.fn().mockImplementation(() => { // MockedGeminiClientClass is defined in the module scope by the previous change. // It will use the mockStartChat and mockSendMessageStream that are managed within beforeEach. const clientInstance = new MockedGeminiClientClass(mockConfig); return clientInstance; }); mockConfig = { apiKey: 'test-api-key', model: 'gemini-pro', sandbox: false, targetDir: '/test/dir', debugMode: false, question: undefined, fullContext: false, coreTools: [], toolDiscoveryCommand: undefined, toolCallCommand: undefined, mcpServerCommand: undefined, mcpServers: undefined, userAgent: 'test-agent', userMemory: '', geminiMdFileCount: 0, alwaysSkipModificationConfirmation: false, vertexai: false, showMemoryUsage: false, contextFileName: undefined, getToolRegistry: vi.fn( () => ({ getToolSchemaList: vi.fn(() => []) }) as any, ), getGeminiClient: mockGetGeminiClient, addHistory: vi.fn(), } as unknown as Config; mockOnDebugMessage = vi.fn(); mockHandleSlashCommand = vi.fn().mockReturnValue(false); // Mock return value for useReactToolScheduler mockScheduleToolCalls = vi.fn(); mockCancelAllToolCalls = vi.fn(); mockMarkToolsAsSubmitted = vi.fn(); // Default mock for useReactToolScheduler to prevent toolCalls being undefined initially mockUseReactToolScheduler.mockReturnValue([ [], // Default to empty array for toolCalls mockScheduleToolCalls, mockCancelAllToolCalls, mockMarkToolsAsSubmitted, ]); // Reset mocks for GeminiClient instance methods (startChat and sendMessageStream) // The GeminiClient constructor itself is mocked at the module level. mockStartChat.mockClear().mockResolvedValue({ sendMessageStream: mockSendMessageStream, } as unknown as any); // GeminiChat -> any mockSendMessageStream .mockClear() .mockReturnValue((async function* () {})()); }); const renderTestHook = ( initialToolCalls: TrackedToolCall[] = [], geminiClient?: any, ) => { mockUseReactToolScheduler.mockReturnValue([ initialToolCalls, mockScheduleToolCalls, mockCancelAllToolCalls, mockMarkToolsAsSubmitted, ]); const client = geminiClient || mockConfig.getGeminiClient(); const { result, rerender } = renderHook(() => useGeminiStream( client, mockAddItem as unknown as UseHistoryManagerReturn['addItem'], mockSetShowHelp, mockConfig, mockOnDebugMessage, mockHandleSlashCommand, false, // shellModeActive ), ); return { result, rerender, mockMarkToolsAsSubmitted, mockSendMessageStream, client, // mockFilter removed }; }; it('should not submit tool responses if not all tool calls are completed', () => { const toolCalls: TrackedToolCall[] = [ { request: { callId: 'call1', name: 'tool1', args: {} }, status: 'success', responseSubmittedToGemini: false, response: { callId: 'call1', responseParts: [{ text: 'tool 1 response' }], error: undefined, resultDisplay: 'Tool 1 success display', }, tool: { name: 'tool1', description: 'desc1', getDescription: vi.fn(), } as any, startTime: Date.now(), endTime: Date.now(), } as TrackedCompletedToolCall, { request: { callId: 'call2', name: 'tool2', args: {} }, status: 'executing', responseSubmittedToGemini: false, tool: { name: 'tool2', description: 'desc2', getDescription: vi.fn(), } as any, startTime: Date.now(), liveOutput: '...', } as TrackedExecutingToolCall, ]; const { mockMarkToolsAsSubmitted, mockSendMessageStream } = renderTestHook(toolCalls); // Effect for submitting tool responses depends on toolCalls and isResponding // isResponding is initially false, so the effect should run. expect(mockMarkToolsAsSubmitted).not.toHaveBeenCalled(); expect(mockSendMessageStream).not.toHaveBeenCalled(); // submitQuery uses this }); it('should submit tool responses when all tool calls are completed and ready', async () => { const toolCall1ResponseParts: PartListUnion = [ { text: 'tool 1 final response' }, ]; const toolCall2ResponseParts: PartListUnion = [ { text: 'tool 2 final response' }, ]; // Simplified toolCalls to ensure the filter logic is the focus const simplifiedToolCalls: TrackedToolCall[] = [ { request: { callId: 'call1', name: 'tool1', args: {} }, status: 'success', responseSubmittedToGemini: false, response: { callId: 'call1', responseParts: toolCall1ResponseParts, error: undefined, resultDisplay: 'Tool 1 success display', }, tool: { name: 'tool1', description: 'desc', getDescription: vi.fn(), } as any, startTime: Date.now(), endTime: Date.now(), } as TrackedCompletedToolCall, { request: { callId: 'call2', name: 'tool2', args: {} }, status: 'cancelled', responseSubmittedToGemini: false, response: { callId: 'call2', responseParts: toolCall2ResponseParts, error: undefined, resultDisplay: 'Tool 2 cancelled display', }, tool: { name: 'tool2', description: 'desc', getDescription: vi.fn(), } as any, startTime: Date.now(), endTime: Date.now(), reason: 'test cancellation', } as TrackedCancelledToolCall, ]; const hookResult = await act(async () => renderTestHook(simplifiedToolCalls), ); const { mockMarkToolsAsSubmitted, mockSendMessageStream: localMockSendMessageStream, } = hookResult!; // It seems the initial render + effect run should be enough. // If rerender was for a specific state change, it might still be needed. // For now, let's test if the initial effect run (covered by the first act) is sufficient. // If not, we can add back: await act(async () => { rerender({}); }); expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['call1', 'call2']); await waitFor(() => { expect(localMockSendMessageStream).toHaveBeenCalledTimes(1); }); const expectedMergedResponse = mergePartListUnions([ toolCall1ResponseParts, toolCall2ResponseParts, ]); expect(localMockSendMessageStream).toHaveBeenCalledWith( expectedMergedResponse, expect.any(AbortSignal), ); }); it('should handle all tool calls being cancelled', async () => { const toolCalls: TrackedToolCall[] = [ { request: { callId: '1', name: 'testTool', args: {} }, status: 'cancelled', response: { callId: '1', responseParts: [{ text: 'cancelled' }], error: undefined, resultDisplay: 'Tool 1 cancelled display', }, responseSubmittedToGemini: false, tool: { name: 'testTool', description: 'desc', getDescription: vi.fn(), } as any, }, ]; const client = new MockedGeminiClientClass(mockConfig); const { mockMarkToolsAsSubmitted, rerender } = renderTestHook( toolCalls, client, ); await act(async () => { rerender({} as any); }); await waitFor(() => { expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']); expect(client.addHistory).toHaveBeenCalledTimes(2); expect(client.addHistory).toHaveBeenCalledWith({ role: 'user', parts: [{ text: 'cancelled' }], }); }); }); });