bug: fix cancel after a tool has been used (#1270)
This commit is contained in:
parent
1d32313a30
commit
52afcb3a12
|
@ -8,6 +8,7 @@
|
||||||
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
|
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
|
||||||
import { renderHook, act, waitFor } from '@testing-library/react';
|
import { renderHook, act, waitFor } from '@testing-library/react';
|
||||||
import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js';
|
import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js';
|
||||||
|
import { useInput } from 'ink';
|
||||||
import {
|
import {
|
||||||
useReactToolScheduler,
|
useReactToolScheduler,
|
||||||
TrackedToolCall,
|
TrackedToolCall,
|
||||||
|
@ -18,7 +19,7 @@ import {
|
||||||
import { Config, EditorType } from '@gemini-cli/core';
|
import { Config, EditorType } from '@gemini-cli/core';
|
||||||
import { Part, PartListUnion } from '@google/genai';
|
import { Part, PartListUnion } from '@google/genai';
|
||||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||||
import { HistoryItem, StreamingState } from '../types.js';
|
import { HistoryItem, MessageType, StreamingState } from '../types.js';
|
||||||
import { Dispatch, SetStateAction } from 'react';
|
import { Dispatch, SetStateAction } from 'react';
|
||||||
import { LoadedSettings } from '../../config/settings.js';
|
import { LoadedSettings } from '../../config/settings.js';
|
||||||
|
|
||||||
|
@ -727,4 +728,150 @@ describe('useGeminiStream', () => {
|
||||||
// 5. After submission, the state should remain Responding.
|
// 5. After submission, the state should remain Responding.
|
||||||
expect(result.current.streamingState).toBe(StreamingState.Responding);
|
expect(result.current.streamingState).toBe(StreamingState.Responding);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('User Cancellation', () => {
|
||||||
|
let useInputCallback: (input: string, key: any) => void;
|
||||||
|
const mockUseInput = useInput as Mock;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
// Capture the callback passed to useInput
|
||||||
|
mockUseInput.mockImplementation((callback) => {
|
||||||
|
useInputCallback = callback;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
const simulateEscapeKeyPress = () => {
|
||||||
|
act(() => {
|
||||||
|
useInputCallback('', { escape: true });
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
it('should cancel an in-progress stream when escape is pressed', async () => {
|
||||||
|
const mockStream = (async function* () {
|
||||||
|
yield { type: 'content', value: 'Part 1' };
|
||||||
|
// Keep the stream open
|
||||||
|
await new Promise(() => {});
|
||||||
|
})();
|
||||||
|
mockSendMessageStream.mockReturnValue(mockStream);
|
||||||
|
|
||||||
|
const { result } = renderTestHook();
|
||||||
|
|
||||||
|
// Start a query
|
||||||
|
await act(async () => {
|
||||||
|
result.current.submitQuery('test query');
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wait for the first part of the response
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(result.current.streamingState).toBe(StreamingState.Responding);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Simulate escape key press
|
||||||
|
simulateEscapeKeyPress();
|
||||||
|
|
||||||
|
// Verify cancellation message is added
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockAddItem).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
type: MessageType.INFO,
|
||||||
|
text: 'Request cancelled.',
|
||||||
|
},
|
||||||
|
expect.any(Number),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify state is reset
|
||||||
|
expect(result.current.streamingState).toBe(StreamingState.Idle);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not do anything if escape is pressed when not responding', () => {
|
||||||
|
const { result } = renderTestHook();
|
||||||
|
|
||||||
|
expect(result.current.streamingState).toBe(StreamingState.Idle);
|
||||||
|
|
||||||
|
// Simulate escape key press
|
||||||
|
simulateEscapeKeyPress();
|
||||||
|
|
||||||
|
// No change should happen, no cancellation message
|
||||||
|
expect(mockAddItem).not.toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
text: 'Request cancelled.',
|
||||||
|
}),
|
||||||
|
expect.any(Number),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should prevent further processing after cancellation', async () => {
|
||||||
|
let continueStream: () => void;
|
||||||
|
const streamPromise = new Promise<void>((resolve) => {
|
||||||
|
continueStream = resolve;
|
||||||
|
});
|
||||||
|
|
||||||
|
const mockStream = (async function* () {
|
||||||
|
yield { type: 'content', value: 'Initial' };
|
||||||
|
await streamPromise; // Wait until we manually continue
|
||||||
|
yield { type: 'content', value: ' Canceled' };
|
||||||
|
})();
|
||||||
|
mockSendMessageStream.mockReturnValue(mockStream);
|
||||||
|
|
||||||
|
const { result } = renderTestHook();
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
result.current.submitQuery('long running query');
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(result.current.streamingState).toBe(StreamingState.Responding);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Cancel the request
|
||||||
|
simulateEscapeKeyPress();
|
||||||
|
|
||||||
|
// Allow the stream to continue
|
||||||
|
act(() => {
|
||||||
|
continueStream();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wait a bit to see if the second part is processed
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||||
|
|
||||||
|
// The text should not have been updated with " Canceled"
|
||||||
|
const lastCall = mockAddItem.mock.calls.find(
|
||||||
|
(call) => call[0].type === 'gemini',
|
||||||
|
);
|
||||||
|
expect(lastCall?.[0].text).toBe('Initial');
|
||||||
|
|
||||||
|
// The final state should be idle after cancellation
|
||||||
|
expect(result.current.streamingState).toBe(StreamingState.Idle);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not cancel if a tool call is in progress (not just responding)', async () => {
|
||||||
|
const toolCalls: TrackedToolCall[] = [
|
||||||
|
{
|
||||||
|
request: { callId: 'call1', name: 'tool1', args: {} },
|
||||||
|
status: 'executing',
|
||||||
|
responseSubmittedToGemini: false,
|
||||||
|
tool: {
|
||||||
|
name: 'tool1',
|
||||||
|
description: 'desc1',
|
||||||
|
getDescription: vi.fn(),
|
||||||
|
} as any,
|
||||||
|
startTime: Date.now(),
|
||||||
|
liveOutput: '...',
|
||||||
|
} as TrackedExecutingToolCall,
|
||||||
|
];
|
||||||
|
|
||||||
|
const abortSpy = vi.spyOn(AbortController.prototype, 'abort');
|
||||||
|
const { result } = renderTestHook(toolCalls);
|
||||||
|
|
||||||
|
// State is `Responding` because a tool is running
|
||||||
|
expect(result.current.streamingState).toBe(StreamingState.Responding);
|
||||||
|
|
||||||
|
// Try to cancel
|
||||||
|
simulateEscapeKeyPress();
|
||||||
|
|
||||||
|
// Nothing should happen because the state is not `Responding`
|
||||||
|
expect(abortSpy).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -92,6 +92,7 @@ export const useGeminiStream = (
|
||||||
) => {
|
) => {
|
||||||
const [initError, setInitError] = useState<string | null>(null);
|
const [initError, setInitError] = useState<string | null>(null);
|
||||||
const abortControllerRef = useRef<AbortController | null>(null);
|
const abortControllerRef = useRef<AbortController | null>(null);
|
||||||
|
const turnCancelledRef = useRef(false);
|
||||||
const [isResponding, setIsResponding] = useState<boolean>(false);
|
const [isResponding, setIsResponding] = useState<boolean>(false);
|
||||||
const [thought, setThought] = useState<ThoughtSummary | null>(null);
|
const [thought, setThought] = useState<ThoughtSummary | null>(null);
|
||||||
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
||||||
|
@ -168,15 +169,25 @@ export const useGeminiStream = (
|
||||||
return StreamingState.Idle;
|
return StreamingState.Idle;
|
||||||
}, [isResponding, toolCalls]);
|
}, [isResponding, toolCalls]);
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (streamingState === StreamingState.Idle) {
|
|
||||||
abortControllerRef.current = null;
|
|
||||||
}
|
|
||||||
}, [streamingState]);
|
|
||||||
|
|
||||||
useInput((_input, key) => {
|
useInput((_input, key) => {
|
||||||
if (streamingState !== StreamingState.Idle && key.escape) {
|
if (streamingState === StreamingState.Responding && key.escape) {
|
||||||
|
if (turnCancelledRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
turnCancelledRef.current = true;
|
||||||
abortControllerRef.current?.abort();
|
abortControllerRef.current?.abort();
|
||||||
|
if (pendingHistoryItemRef.current) {
|
||||||
|
addItem(pendingHistoryItemRef.current, Date.now());
|
||||||
|
}
|
||||||
|
addItem(
|
||||||
|
{
|
||||||
|
type: MessageType.INFO,
|
||||||
|
text: 'Request cancelled.',
|
||||||
|
},
|
||||||
|
Date.now(),
|
||||||
|
);
|
||||||
|
setPendingHistoryItem(null);
|
||||||
|
setIsResponding(false);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -189,6 +200,9 @@ export const useGeminiStream = (
|
||||||
queryToSend: PartListUnion | null;
|
queryToSend: PartListUnion | null;
|
||||||
shouldProceed: boolean;
|
shouldProceed: boolean;
|
||||||
}> => {
|
}> => {
|
||||||
|
if (turnCancelledRef.current) {
|
||||||
|
return { queryToSend: null, shouldProceed: false };
|
||||||
|
}
|
||||||
if (typeof query === 'string' && query.trim().length === 0) {
|
if (typeof query === 'string' && query.trim().length === 0) {
|
||||||
return { queryToSend: null, shouldProceed: false };
|
return { queryToSend: null, shouldProceed: false };
|
||||||
}
|
}
|
||||||
|
@ -285,6 +299,10 @@ export const useGeminiStream = (
|
||||||
currentGeminiMessageBuffer: string,
|
currentGeminiMessageBuffer: string,
|
||||||
userMessageTimestamp: number,
|
userMessageTimestamp: number,
|
||||||
): string => {
|
): string => {
|
||||||
|
if (turnCancelledRef.current) {
|
||||||
|
// Prevents additional output after a user initiated cancel.
|
||||||
|
return '';
|
||||||
|
}
|
||||||
let newGeminiMessageBuffer = currentGeminiMessageBuffer + eventValue;
|
let newGeminiMessageBuffer = currentGeminiMessageBuffer + eventValue;
|
||||||
if (
|
if (
|
||||||
pendingHistoryItemRef.current?.type !== 'gemini' &&
|
pendingHistoryItemRef.current?.type !== 'gemini' &&
|
||||||
|
@ -335,6 +353,9 @@ export const useGeminiStream = (
|
||||||
|
|
||||||
const handleUserCancelledEvent = useCallback(
|
const handleUserCancelledEvent = useCallback(
|
||||||
(userMessageTimestamp: number) => {
|
(userMessageTimestamp: number) => {
|
||||||
|
if (turnCancelledRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (pendingHistoryItemRef.current) {
|
if (pendingHistoryItemRef.current) {
|
||||||
if (pendingHistoryItemRef.current.type === 'tool_group') {
|
if (pendingHistoryItemRef.current.type === 'tool_group') {
|
||||||
const updatedTools = pendingHistoryItemRef.current.tools.map(
|
const updatedTools = pendingHistoryItemRef.current.tools.map(
|
||||||
|
@ -469,6 +490,7 @@ export const useGeminiStream = (
|
||||||
|
|
||||||
abortControllerRef.current = new AbortController();
|
abortControllerRef.current = new AbortController();
|
||||||
const abortSignal = abortControllerRef.current.signal;
|
const abortSignal = abortControllerRef.current.signal;
|
||||||
|
turnCancelledRef.current = false;
|
||||||
|
|
||||||
const { queryToSend, shouldProceed } = await prepareQueryForGemini(
|
const { queryToSend, shouldProceed } = await prepareQueryForGemini(
|
||||||
query,
|
query,
|
||||||
|
|
Loading…
Reference in New Issue