Fixing at command race condition (#6663)
This commit is contained in:
parent
52e340a11b
commit
4642de2a5c
|
@ -98,10 +98,6 @@ describe('handleAtCommand', () => {
|
||||||
processedQuery: [{ text: query }],
|
processedQuery: [{ text: query }],
|
||||||
shouldProceed: true,
|
shouldProceed: true,
|
||||||
});
|
});
|
||||||
expect(mockAddItem).toHaveBeenCalledWith(
|
|
||||||
{ type: 'user', text: query },
|
|
||||||
123,
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should pass through original query if only a lone @ symbol is present', async () => {
|
it('should pass through original query if only a lone @ symbol is present', async () => {
|
||||||
|
@ -120,10 +116,6 @@ describe('handleAtCommand', () => {
|
||||||
processedQuery: [{ text: queryWithSpaces }],
|
processedQuery: [{ text: queryWithSpaces }],
|
||||||
shouldProceed: true,
|
shouldProceed: true,
|
||||||
});
|
});
|
||||||
expect(mockAddItem).toHaveBeenCalledWith(
|
|
||||||
{ type: 'user', text: queryWithSpaces },
|
|
||||||
124,
|
|
||||||
);
|
|
||||||
expect(mockOnDebugMessage).toHaveBeenCalledWith(
|
expect(mockOnDebugMessage).toHaveBeenCalledWith(
|
||||||
'Lone @ detected, will be treated as text in the modified query.',
|
'Lone @ detected, will be treated as text in the modified query.',
|
||||||
);
|
);
|
||||||
|
@ -156,10 +148,6 @@ describe('handleAtCommand', () => {
|
||||||
],
|
],
|
||||||
shouldProceed: true,
|
shouldProceed: true,
|
||||||
});
|
});
|
||||||
expect(mockAddItem).toHaveBeenCalledWith(
|
|
||||||
{ type: 'user', text: query },
|
|
||||||
125,
|
|
||||||
);
|
|
||||||
expect(mockAddItem).toHaveBeenCalledWith(
|
expect(mockAddItem).toHaveBeenCalledWith(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
type: 'tool_group',
|
type: 'tool_group',
|
||||||
|
@ -198,10 +186,6 @@ describe('handleAtCommand', () => {
|
||||||
],
|
],
|
||||||
shouldProceed: true,
|
shouldProceed: true,
|
||||||
});
|
});
|
||||||
expect(mockAddItem).toHaveBeenCalledWith(
|
|
||||||
{ type: 'user', text: query },
|
|
||||||
126,
|
|
||||||
);
|
|
||||||
expect(mockOnDebugMessage).toHaveBeenCalledWith(
|
expect(mockOnDebugMessage).toHaveBeenCalledWith(
|
||||||
`Path ${dirPath} resolved to directory, using glob: ${resolvedGlob}`,
|
`Path ${dirPath} resolved to directory, using glob: ${resolvedGlob}`,
|
||||||
);
|
);
|
||||||
|
@ -236,10 +220,6 @@ describe('handleAtCommand', () => {
|
||||||
],
|
],
|
||||||
shouldProceed: true,
|
shouldProceed: true,
|
||||||
});
|
});
|
||||||
expect(mockAddItem).toHaveBeenCalledWith(
|
|
||||||
{ type: 'user', text: query },
|
|
||||||
128,
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should correctly unescape paths with escaped spaces', async () => {
|
it('should correctly unescape paths with escaped spaces', async () => {
|
||||||
|
@ -270,10 +250,6 @@ describe('handleAtCommand', () => {
|
||||||
],
|
],
|
||||||
shouldProceed: true,
|
shouldProceed: true,
|
||||||
});
|
});
|
||||||
expect(mockAddItem).toHaveBeenCalledWith(
|
|
||||||
{ type: 'user', text: query },
|
|
||||||
125,
|
|
||||||
);
|
|
||||||
expect(mockAddItem).toHaveBeenCalledWith(
|
expect(mockAddItem).toHaveBeenCalledWith(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
type: 'tool_group',
|
type: 'tool_group',
|
||||||
|
@ -1090,4 +1066,37 @@ describe('handleAtCommand', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("should not add the user's turn to history, as that is the caller's responsibility", async () => {
|
||||||
|
// Arrange
|
||||||
|
const fileContent = 'This is the file content.';
|
||||||
|
const filePath = await createTestFile(
|
||||||
|
path.join(testRootDir, 'path', 'to', 'another-file.txt'),
|
||||||
|
fileContent,
|
||||||
|
);
|
||||||
|
const query = `A query with @${filePath}`;
|
||||||
|
|
||||||
|
// Act
|
||||||
|
await handleAtCommand({
|
||||||
|
query,
|
||||||
|
config: mockConfig,
|
||||||
|
addItem: mockAddItem,
|
||||||
|
onDebugMessage: mockOnDebugMessage,
|
||||||
|
messageId: 999,
|
||||||
|
signal: abortController.signal,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
// It SHOULD be called for the tool_group
|
||||||
|
expect(mockAddItem).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({ type: 'tool_group' }),
|
||||||
|
999,
|
||||||
|
);
|
||||||
|
|
||||||
|
// It should NOT have been called for the user turn
|
||||||
|
const userTurnCalls = mockAddItem.mock.calls.filter(
|
||||||
|
(call) => call[0].type === 'user',
|
||||||
|
);
|
||||||
|
expect(userTurnCalls).toHaveLength(0);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -137,12 +137,9 @@ export async function handleAtCommand({
|
||||||
);
|
);
|
||||||
|
|
||||||
if (atPathCommandParts.length === 0) {
|
if (atPathCommandParts.length === 0) {
|
||||||
addItem({ type: 'user', text: query }, userMessageTimestamp);
|
|
||||||
return { processedQuery: [{ text: query }], shouldProceed: true };
|
return { processedQuery: [{ text: query }], shouldProceed: true };
|
||||||
}
|
}
|
||||||
|
|
||||||
addItem({ type: 'user', text: query }, userMessageTimestamp);
|
|
||||||
|
|
||||||
// Get centralized file discovery service
|
// Get centralized file discovery service
|
||||||
const fileDiscovery = config.getFileService();
|
const fileDiscovery = config.getFileService();
|
||||||
|
|
||||||
|
|
|
@ -5,10 +5,19 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||||
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
|
import {
|
||||||
|
describe,
|
||||||
|
it,
|
||||||
|
expect,
|
||||||
|
vi,
|
||||||
|
beforeEach,
|
||||||
|
Mock,
|
||||||
|
MockInstance,
|
||||||
|
} from 'vitest';
|
||||||
import { renderHook, act, waitFor } from '@testing-library/react';
|
import { renderHook, act, waitFor } from '@testing-library/react';
|
||||||
import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js';
|
import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js';
|
||||||
import { useKeypress } from './useKeypress.js';
|
import { useKeypress } from './useKeypress.js';
|
||||||
|
import * as atCommandProcessor from './atCommandProcessor.js';
|
||||||
import {
|
import {
|
||||||
useReactToolScheduler,
|
useReactToolScheduler,
|
||||||
TrackedToolCall,
|
TrackedToolCall,
|
||||||
|
@ -20,8 +29,10 @@ import {
|
||||||
Config,
|
Config,
|
||||||
EditorType,
|
EditorType,
|
||||||
AuthType,
|
AuthType,
|
||||||
|
GeminiClient,
|
||||||
GeminiEventType as ServerGeminiEventType,
|
GeminiEventType as ServerGeminiEventType,
|
||||||
AnyToolInvocation,
|
AnyToolInvocation,
|
||||||
|
ToolErrorType, // <-- Import ToolErrorType
|
||||||
} from '@google/gemini-cli-core';
|
} from '@google/gemini-cli-core';
|
||||||
import { Part, PartListUnion } from '@google/genai';
|
import { Part, PartListUnion } from '@google/genai';
|
||||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||||
|
@ -83,11 +94,7 @@ vi.mock('./shellCommandProcessor.js', () => ({
|
||||||
}),
|
}),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
vi.mock('./atCommandProcessor.js', () => ({
|
vi.mock('./atCommandProcessor.js');
|
||||||
handleAtCommand: vi
|
|
||||||
.fn()
|
|
||||||
.mockResolvedValue({ shouldProceed: true, processedQuery: 'mocked' }),
|
|
||||||
}));
|
|
||||||
|
|
||||||
vi.mock('../utils/markdownUtilities.js', () => ({
|
vi.mock('../utils/markdownUtilities.js', () => ({
|
||||||
findLastSafeSplitPoint: vi.fn((s: string) => s.length),
|
findLastSafeSplitPoint: vi.fn((s: string) => s.length),
|
||||||
|
@ -259,6 +266,7 @@ describe('useGeminiStream', () => {
|
||||||
let mockScheduleToolCalls: Mock;
|
let mockScheduleToolCalls: Mock;
|
||||||
let mockCancelAllToolCalls: Mock;
|
let mockCancelAllToolCalls: Mock;
|
||||||
let mockMarkToolsAsSubmitted: Mock;
|
let mockMarkToolsAsSubmitted: Mock;
|
||||||
|
let handleAtCommandSpy: MockInstance;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks(); // Clear mocks before each test
|
vi.clearAllMocks(); // Clear mocks before each test
|
||||||
|
@ -342,6 +350,7 @@ describe('useGeminiStream', () => {
|
||||||
mockSendMessageStream
|
mockSendMessageStream
|
||||||
.mockClear()
|
.mockClear()
|
||||||
.mockReturnValue((async function* () {})());
|
.mockReturnValue((async function* () {})());
|
||||||
|
handleAtCommandSpy = vi.spyOn(atCommandProcessor, 'handleAtCommand');
|
||||||
});
|
});
|
||||||
|
|
||||||
const mockLoadedSettings: LoadedSettings = {
|
const mockLoadedSettings: LoadedSettings = {
|
||||||
|
@ -447,6 +456,7 @@ describe('useGeminiStream', () => {
|
||||||
callId: 'call1',
|
callId: 'call1',
|
||||||
responseParts: [{ text: 'tool 1 response' }],
|
responseParts: [{ text: 'tool 1 response' }],
|
||||||
error: undefined,
|
error: undefined,
|
||||||
|
errorType: undefined, // FIX: Added missing property
|
||||||
resultDisplay: 'Tool 1 success display',
|
resultDisplay: 'Tool 1 success display',
|
||||||
},
|
},
|
||||||
tool: {
|
tool: {
|
||||||
|
@ -512,7 +522,11 @@ describe('useGeminiStream', () => {
|
||||||
},
|
},
|
||||||
status: 'success',
|
status: 'success',
|
||||||
responseSubmittedToGemini: false,
|
responseSubmittedToGemini: false,
|
||||||
response: { callId: 'call1', responseParts: toolCall1ResponseParts },
|
response: {
|
||||||
|
callId: 'call1',
|
||||||
|
responseParts: toolCall1ResponseParts,
|
||||||
|
errorType: undefined, // FIX: Added missing property
|
||||||
|
},
|
||||||
tool: {
|
tool: {
|
||||||
displayName: 'MockTool',
|
displayName: 'MockTool',
|
||||||
},
|
},
|
||||||
|
@ -530,7 +544,11 @@ describe('useGeminiStream', () => {
|
||||||
},
|
},
|
||||||
status: 'error',
|
status: 'error',
|
||||||
responseSubmittedToGemini: false,
|
responseSubmittedToGemini: false,
|
||||||
response: { callId: 'call2', responseParts: toolCall2ResponseParts },
|
response: {
|
||||||
|
callId: 'call2',
|
||||||
|
responseParts: toolCall2ResponseParts,
|
||||||
|
errorType: ToolErrorType.UNHANDLED_EXCEPTION, // FIX: Added missing property
|
||||||
|
},
|
||||||
} as TrackedCompletedToolCall, // Treat error as a form of completion for submission
|
} as TrackedCompletedToolCall, // Treat error as a form of completion for submission
|
||||||
];
|
];
|
||||||
|
|
||||||
|
@ -597,7 +615,11 @@ describe('useGeminiStream', () => {
|
||||||
prompt_id: 'prompt-id-3',
|
prompt_id: 'prompt-id-3',
|
||||||
},
|
},
|
||||||
status: 'cancelled',
|
status: 'cancelled',
|
||||||
response: { callId: '1', responseParts: [{ text: 'cancelled' }] },
|
response: {
|
||||||
|
callId: '1',
|
||||||
|
responseParts: [{ text: 'cancelled' }],
|
||||||
|
errorType: undefined, // FIX: Added missing property
|
||||||
|
},
|
||||||
responseSubmittedToGemini: false,
|
responseSubmittedToGemini: false,
|
||||||
tool: {
|
tool: {
|
||||||
displayName: 'mock tool',
|
displayName: 'mock tool',
|
||||||
|
@ -682,6 +704,7 @@ describe('useGeminiStream', () => {
|
||||||
],
|
],
|
||||||
resultDisplay: undefined,
|
resultDisplay: undefined,
|
||||||
error: undefined,
|
error: undefined,
|
||||||
|
errorType: undefined, // FIX: Added missing property
|
||||||
},
|
},
|
||||||
responseSubmittedToGemini: false,
|
responseSubmittedToGemini: false,
|
||||||
};
|
};
|
||||||
|
@ -710,6 +733,7 @@ describe('useGeminiStream', () => {
|
||||||
],
|
],
|
||||||
resultDisplay: undefined,
|
resultDisplay: undefined,
|
||||||
error: undefined,
|
error: undefined,
|
||||||
|
errorType: undefined, // FIX: Added missing property
|
||||||
},
|
},
|
||||||
responseSubmittedToGemini: false,
|
responseSubmittedToGemini: false,
|
||||||
};
|
};
|
||||||
|
@ -812,6 +836,7 @@ describe('useGeminiStream', () => {
|
||||||
callId: 'call1',
|
callId: 'call1',
|
||||||
responseParts: toolCallResponseParts,
|
responseParts: toolCallResponseParts,
|
||||||
error: undefined,
|
error: undefined,
|
||||||
|
errorType: undefined, // FIX: Added missing property
|
||||||
resultDisplay: 'Tool 1 success display',
|
resultDisplay: 'Tool 1 success display',
|
||||||
},
|
},
|
||||||
endTime: Date.now(),
|
endTime: Date.now(),
|
||||||
|
@ -1214,6 +1239,7 @@ describe('useGeminiStream', () => {
|
||||||
responseParts: [{ text: 'Memory saved' }],
|
responseParts: [{ text: 'Memory saved' }],
|
||||||
resultDisplay: 'Success: Memory saved',
|
resultDisplay: 'Success: Memory saved',
|
||||||
error: undefined,
|
error: undefined,
|
||||||
|
errorType: undefined, // FIX: Added missing property
|
||||||
},
|
},
|
||||||
tool: {
|
tool: {
|
||||||
name: 'save_memory',
|
name: 'save_memory',
|
||||||
|
@ -1757,4 +1783,68 @@ describe('useGeminiStream', () => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should process @include commands, adding user turn after processing to prevent race conditions', async () => {
|
||||||
|
const rawQuery = '@include file.txt Summarize this.';
|
||||||
|
const processedQueryParts = [
|
||||||
|
{ text: 'Summarize this with content from @file.txt' },
|
||||||
|
{ text: 'File content...' },
|
||||||
|
];
|
||||||
|
const userMessageTimestamp = Date.now();
|
||||||
|
vi.spyOn(Date, 'now').mockReturnValue(userMessageTimestamp);
|
||||||
|
|
||||||
|
// Mock the behavior of handleAtCommand
|
||||||
|
handleAtCommandSpy.mockResolvedValue({
|
||||||
|
processedQuery: processedQueryParts,
|
||||||
|
shouldProceed: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
const { result } = renderHook(() =>
|
||||||
|
useGeminiStream(
|
||||||
|
mockConfig.getGeminiClient() as GeminiClient,
|
||||||
|
[],
|
||||||
|
mockAddItem,
|
||||||
|
mockConfig,
|
||||||
|
mockOnDebugMessage,
|
||||||
|
mockHandleSlashCommand,
|
||||||
|
false, // shellModeActive
|
||||||
|
vi.fn(), // getPreferredEditor
|
||||||
|
vi.fn(), // onAuthError
|
||||||
|
vi.fn(), // performMemoryRefresh
|
||||||
|
false, // modelSwitched
|
||||||
|
vi.fn(), // setModelSwitched
|
||||||
|
vi.fn(), // onEditorClose
|
||||||
|
vi.fn(), // onCancelSubmit
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Act: Submit the query
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.submitQuery(rawQuery);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
// 1. Verify handleAtCommand was called with the raw query.
|
||||||
|
expect(handleAtCommandSpy).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
query: rawQuery,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// 2. Verify the user's turn was added to history *after* processing.
|
||||||
|
expect(mockAddItem).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
type: MessageType.USER,
|
||||||
|
text: rawQuery,
|
||||||
|
},
|
||||||
|
userMessageTimestamp,
|
||||||
|
);
|
||||||
|
|
||||||
|
// 3. Verify the *processed* query was sent to the model, not the raw one.
|
||||||
|
expect(mockSendMessageStream).toHaveBeenCalledWith(
|
||||||
|
processedQueryParts,
|
||||||
|
expect.any(AbortSignal),
|
||||||
|
expect.any(String),
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -307,6 +307,13 @@ export const useGeminiStream = (
|
||||||
messageId: userMessageTimestamp,
|
messageId: userMessageTimestamp,
|
||||||
signal: abortSignal,
|
signal: abortSignal,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Add user's turn after @ command processing is done.
|
||||||
|
addItem(
|
||||||
|
{ type: MessageType.USER, text: trimmedQuery },
|
||||||
|
userMessageTimestamp,
|
||||||
|
);
|
||||||
|
|
||||||
if (!atCommandResult.shouldProceed) {
|
if (!atCommandResult.shouldProceed) {
|
||||||
return { queryToSend: null, shouldProceed: false };
|
return { queryToSend: null, shouldProceed: false };
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue