diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts index 7403f788..837e0d32 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts @@ -98,10 +98,6 @@ describe('handleAtCommand', () => { processedQuery: [{ text: query }], shouldProceed: true, }); - expect(mockAddItem).toHaveBeenCalledWith( - { type: 'user', text: query }, - 123, - ); }); it('should pass through original query if only a lone @ symbol is present', async () => { @@ -120,10 +116,6 @@ describe('handleAtCommand', () => { processedQuery: [{ text: queryWithSpaces }], shouldProceed: true, }); - expect(mockAddItem).toHaveBeenCalledWith( - { type: 'user', text: queryWithSpaces }, - 124, - ); expect(mockOnDebugMessage).toHaveBeenCalledWith( 'Lone @ detected, will be treated as text in the modified query.', ); @@ -156,10 +148,6 @@ describe('handleAtCommand', () => { ], shouldProceed: true, }); - expect(mockAddItem).toHaveBeenCalledWith( - { type: 'user', text: query }, - 125, - ); expect(mockAddItem).toHaveBeenCalledWith( expect.objectContaining({ type: 'tool_group', @@ -198,10 +186,6 @@ describe('handleAtCommand', () => { ], shouldProceed: true, }); - expect(mockAddItem).toHaveBeenCalledWith( - { type: 'user', text: query }, - 126, - ); expect(mockOnDebugMessage).toHaveBeenCalledWith( `Path ${dirPath} resolved to directory, using glob: ${resolvedGlob}`, ); @@ -236,10 +220,6 @@ describe('handleAtCommand', () => { ], shouldProceed: true, }); - expect(mockAddItem).toHaveBeenCalledWith( - { type: 'user', text: query }, - 128, - ); }); it('should correctly unescape paths with escaped spaces', async () => { @@ -270,10 +250,6 @@ describe('handleAtCommand', () => { ], shouldProceed: true, }); - expect(mockAddItem).toHaveBeenCalledWith( - { type: 'user', text: query }, - 125, - ); expect(mockAddItem).toHaveBeenCalledWith( expect.objectContaining({ type: 'tool_group', @@ -1090,4 +1066,37 @@ describe('handleAtCommand', () => { }); }); }); + + it("should not add the user's turn to history, as that is the caller's responsibility", async () => { + // Arrange + const fileContent = 'This is the file content.'; + const filePath = await createTestFile( + path.join(testRootDir, 'path', 'to', 'another-file.txt'), + fileContent, + ); + const query = `A query with @${filePath}`; + + // Act + await handleAtCommand({ + query, + config: mockConfig, + addItem: mockAddItem, + onDebugMessage: mockOnDebugMessage, + messageId: 999, + signal: abortController.signal, + }); + + // Assert + // It SHOULD be called for the tool_group + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ type: 'tool_group' }), + 999, + ); + + // It should NOT have been called for the user turn + const userTurnCalls = mockAddItem.mock.calls.filter( + (call) => call[0].type === 'user', + ); + expect(userTurnCalls).toHaveLength(0); + }); }); diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts index 85ad6f6f..3d139db8 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts @@ -137,12 +137,9 @@ export async function handleAtCommand({ ); if (atPathCommandParts.length === 0) { - addItem({ type: 'user', text: query }, userMessageTimestamp); return { processedQuery: [{ text: query }], shouldProceed: true }; } - addItem({ type: 'user', text: query }, userMessageTimestamp); - // Get centralized file discovery service const fileDiscovery = config.getFileService(); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 9eed0912..f08f6606 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -5,10 +5,19 @@ */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { describe, it, expect, vi, beforeEach, Mock } from 'vitest'; +import { + describe, + it, + expect, + vi, + beforeEach, + Mock, + MockInstance, +} from 'vitest'; import { renderHook, act, waitFor } from '@testing-library/react'; import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js'; import { useKeypress } from './useKeypress.js'; +import * as atCommandProcessor from './atCommandProcessor.js'; import { useReactToolScheduler, TrackedToolCall, @@ -20,8 +29,10 @@ import { Config, EditorType, AuthType, + GeminiClient, GeminiEventType as ServerGeminiEventType, AnyToolInvocation, + ToolErrorType, // <-- Import ToolErrorType } from '@google/gemini-cli-core'; import { Part, PartListUnion } from '@google/genai'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; @@ -83,11 +94,7 @@ vi.mock('./shellCommandProcessor.js', () => ({ }), })); -vi.mock('./atCommandProcessor.js', () => ({ - handleAtCommand: vi - .fn() - .mockResolvedValue({ shouldProceed: true, processedQuery: 'mocked' }), -})); +vi.mock('./atCommandProcessor.js'); vi.mock('../utils/markdownUtilities.js', () => ({ findLastSafeSplitPoint: vi.fn((s: string) => s.length), @@ -259,6 +266,7 @@ describe('useGeminiStream', () => { let mockScheduleToolCalls: Mock; let mockCancelAllToolCalls: Mock; let mockMarkToolsAsSubmitted: Mock; + let handleAtCommandSpy: MockInstance; beforeEach(() => { vi.clearAllMocks(); // Clear mocks before each test @@ -342,6 +350,7 @@ describe('useGeminiStream', () => { mockSendMessageStream .mockClear() .mockReturnValue((async function* () {})()); + handleAtCommandSpy = vi.spyOn(atCommandProcessor, 'handleAtCommand'); }); const mockLoadedSettings: LoadedSettings = { @@ -447,6 +456,7 @@ describe('useGeminiStream', () => { callId: 'call1', responseParts: [{ text: 'tool 1 response' }], error: undefined, + errorType: undefined, // FIX: Added missing property resultDisplay: 'Tool 1 success display', }, tool: { @@ -512,7 +522,11 @@ describe('useGeminiStream', () => { }, status: 'success', responseSubmittedToGemini: false, - response: { callId: 'call1', responseParts: toolCall1ResponseParts }, + response: { + callId: 'call1', + responseParts: toolCall1ResponseParts, + errorType: undefined, // FIX: Added missing property + }, tool: { displayName: 'MockTool', }, @@ -530,7 +544,11 @@ describe('useGeminiStream', () => { }, status: 'error', responseSubmittedToGemini: false, - response: { callId: 'call2', responseParts: toolCall2ResponseParts }, + response: { + callId: 'call2', + responseParts: toolCall2ResponseParts, + errorType: ToolErrorType.UNHANDLED_EXCEPTION, // FIX: Added missing property + }, } as TrackedCompletedToolCall, // Treat error as a form of completion for submission ]; @@ -597,7 +615,11 @@ describe('useGeminiStream', () => { prompt_id: 'prompt-id-3', }, status: 'cancelled', - response: { callId: '1', responseParts: [{ text: 'cancelled' }] }, + response: { + callId: '1', + responseParts: [{ text: 'cancelled' }], + errorType: undefined, // FIX: Added missing property + }, responseSubmittedToGemini: false, tool: { displayName: 'mock tool', @@ -682,6 +704,7 @@ describe('useGeminiStream', () => { ], resultDisplay: undefined, error: undefined, + errorType: undefined, // FIX: Added missing property }, responseSubmittedToGemini: false, }; @@ -710,6 +733,7 @@ describe('useGeminiStream', () => { ], resultDisplay: undefined, error: undefined, + errorType: undefined, // FIX: Added missing property }, responseSubmittedToGemini: false, }; @@ -812,6 +836,7 @@ describe('useGeminiStream', () => { callId: 'call1', responseParts: toolCallResponseParts, error: undefined, + errorType: undefined, // FIX: Added missing property resultDisplay: 'Tool 1 success display', }, endTime: Date.now(), @@ -1214,6 +1239,7 @@ describe('useGeminiStream', () => { responseParts: [{ text: 'Memory saved' }], resultDisplay: 'Success: Memory saved', error: undefined, + errorType: undefined, // FIX: Added missing property }, tool: { name: 'save_memory', @@ -1757,4 +1783,68 @@ describe('useGeminiStream', () => { ); }); }); + + it('should process @include commands, adding user turn after processing to prevent race conditions', async () => { + const rawQuery = '@include file.txt Summarize this.'; + const processedQueryParts = [ + { text: 'Summarize this with content from @file.txt' }, + { text: 'File content...' }, + ]; + const userMessageTimestamp = Date.now(); + vi.spyOn(Date, 'now').mockReturnValue(userMessageTimestamp); + + // Mock the behavior of handleAtCommand + handleAtCommandSpy.mockResolvedValue({ + processedQuery: processedQueryParts, + shouldProceed: true, + }); + + const { result } = renderHook(() => + useGeminiStream( + mockConfig.getGeminiClient() as GeminiClient, + [], + mockAddItem, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, // shellModeActive + vi.fn(), // getPreferredEditor + vi.fn(), // onAuthError + vi.fn(), // performMemoryRefresh + false, // modelSwitched + vi.fn(), // setModelSwitched + vi.fn(), // onEditorClose + vi.fn(), // onCancelSubmit + ), + ); + + // Act: Submit the query + await act(async () => { + await result.current.submitQuery(rawQuery); + }); + + // Assert + // 1. Verify handleAtCommand was called with the raw query. + expect(handleAtCommandSpy).toHaveBeenCalledWith( + expect.objectContaining({ + query: rawQuery, + }), + ); + + // 2. Verify the user's turn was added to history *after* processing. + expect(mockAddItem).toHaveBeenCalledWith( + { + type: MessageType.USER, + text: rawQuery, + }, + userMessageTimestamp, + ); + + // 3. Verify the *processed* query was sent to the model, not the raw one. + expect(mockSendMessageStream).toHaveBeenCalledWith( + processedQueryParts, + expect.any(AbortSignal), + expect.any(String), + ); + }); }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index abfe28c7..45344c73 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -307,6 +307,13 @@ export const useGeminiStream = ( messageId: userMessageTimestamp, signal: abortSignal, }); + + // Add user's turn after @ command processing is done. + addItem( + { type: MessageType.USER, text: trimmedQuery }, + userMessageTimestamp, + ); + if (!atCommandResult.shouldProceed) { return { queryToSend: null, shouldProceed: false }; }