From 52afcb3a1233237b07aa86b1678f4c4eded70800 Mon Sep 17 00:00:00 2001 From: Abhi <43648792+abhipatel12@users.noreply.github.com> Date: Fri, 20 Jun 2025 23:01:44 -0400 Subject: [PATCH] bug: fix cancel after a tool has been used (#1270) --- .../cli/src/ui/hooks/useGeminiStream.test.tsx | 149 +++++++++++++++++- packages/cli/src/ui/hooks/useGeminiStream.ts | 36 ++++- 2 files changed, 177 insertions(+), 8 deletions(-) diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 487738c3..ac168dcd 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -8,6 +8,7 @@ import { describe, it, expect, vi, beforeEach, Mock } from 'vitest'; import { renderHook, act, waitFor } from '@testing-library/react'; import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js'; +import { useInput } from 'ink'; import { useReactToolScheduler, TrackedToolCall, @@ -18,7 +19,7 @@ import { import { Config, EditorType } from '@gemini-cli/core'; import { Part, PartListUnion } from '@google/genai'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; -import { HistoryItem, StreamingState } from '../types.js'; +import { HistoryItem, MessageType, StreamingState } from '../types.js'; import { Dispatch, SetStateAction } from 'react'; import { LoadedSettings } from '../../config/settings.js'; @@ -727,4 +728,150 @@ describe('useGeminiStream', () => { // 5. After submission, the state should remain Responding. expect(result.current.streamingState).toBe(StreamingState.Responding); }); + + describe('User Cancellation', () => { + let useInputCallback: (input: string, key: any) => void; + const mockUseInput = useInput as Mock; + + beforeEach(() => { + // Capture the callback passed to useInput + mockUseInput.mockImplementation((callback) => { + useInputCallback = callback; + }); + }); + + const simulateEscapeKeyPress = () => { + act(() => { + useInputCallback('', { escape: true }); + }); + }; + + it('should cancel an in-progress stream when escape is pressed', async () => { + const mockStream = (async function* () { + yield { type: 'content', value: 'Part 1' }; + // Keep the stream open + await new Promise(() => {}); + })(); + mockSendMessageStream.mockReturnValue(mockStream); + + const { result } = renderTestHook(); + + // Start a query + await act(async () => { + result.current.submitQuery('test query'); + }); + + // Wait for the first part of the response + await waitFor(() => { + expect(result.current.streamingState).toBe(StreamingState.Responding); + }); + + // Simulate escape key press + simulateEscapeKeyPress(); + + // Verify cancellation message is added + await waitFor(() => { + expect(mockAddItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: 'Request cancelled.', + }, + expect.any(Number), + ); + }); + + // Verify state is reset + expect(result.current.streamingState).toBe(StreamingState.Idle); + }); + + it('should not do anything if escape is pressed when not responding', () => { + const { result } = renderTestHook(); + + expect(result.current.streamingState).toBe(StreamingState.Idle); + + // Simulate escape key press + simulateEscapeKeyPress(); + + // No change should happen, no cancellation message + expect(mockAddItem).not.toHaveBeenCalledWith( + expect.objectContaining({ + text: 'Request cancelled.', + }), + expect.any(Number), + ); + }); + + it('should prevent further processing after cancellation', async () => { + let continueStream: () => void; + const streamPromise = new Promise((resolve) => { + continueStream = resolve; + }); + + const mockStream = (async function* () { + yield { type: 'content', value: 'Initial' }; + await streamPromise; // Wait until we manually continue + yield { type: 'content', value: ' Canceled' }; + })(); + mockSendMessageStream.mockReturnValue(mockStream); + + const { result } = renderTestHook(); + + await act(async () => { + result.current.submitQuery('long running query'); + }); + + await waitFor(() => { + expect(result.current.streamingState).toBe(StreamingState.Responding); + }); + + // Cancel the request + simulateEscapeKeyPress(); + + // Allow the stream to continue + act(() => { + continueStream(); + }); + + // Wait a bit to see if the second part is processed + await new Promise((resolve) => setTimeout(resolve, 50)); + + // The text should not have been updated with " Canceled" + const lastCall = mockAddItem.mock.calls.find( + (call) => call[0].type === 'gemini', + ); + expect(lastCall?.[0].text).toBe('Initial'); + + // The final state should be idle after cancellation + expect(result.current.streamingState).toBe(StreamingState.Idle); + }); + + it('should not cancel if a tool call is in progress (not just responding)', async () => { + const toolCalls: TrackedToolCall[] = [ + { + request: { callId: 'call1', name: 'tool1', args: {} }, + status: 'executing', + responseSubmittedToGemini: false, + tool: { + name: 'tool1', + description: 'desc1', + getDescription: vi.fn(), + } as any, + startTime: Date.now(), + liveOutput: '...', + } as TrackedExecutingToolCall, + ]; + + const abortSpy = vi.spyOn(AbortController.prototype, 'abort'); + const { result } = renderTestHook(toolCalls); + + // State is `Responding` because a tool is running + expect(result.current.streamingState).toBe(StreamingState.Responding); + + // Try to cancel + simulateEscapeKeyPress(); + + // Nothing should happen because the state is not `Responding` + expect(abortSpy).not.toHaveBeenCalled(); + }); + }); }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 234652db..fcfa1c57 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -92,6 +92,7 @@ export const useGeminiStream = ( ) => { const [initError, setInitError] = useState(null); const abortControllerRef = useRef(null); + const turnCancelledRef = useRef(false); const [isResponding, setIsResponding] = useState(false); const [thought, setThought] = useState(null); const [pendingHistoryItemRef, setPendingHistoryItem] = @@ -168,15 +169,25 @@ export const useGeminiStream = ( return StreamingState.Idle; }, [isResponding, toolCalls]); - useEffect(() => { - if (streamingState === StreamingState.Idle) { - abortControllerRef.current = null; - } - }, [streamingState]); - useInput((_input, key) => { - if (streamingState !== StreamingState.Idle && key.escape) { + if (streamingState === StreamingState.Responding && key.escape) { + if (turnCancelledRef.current) { + return; + } + turnCancelledRef.current = true; abortControllerRef.current?.abort(); + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, Date.now()); + } + addItem( + { + type: MessageType.INFO, + text: 'Request cancelled.', + }, + Date.now(), + ); + setPendingHistoryItem(null); + setIsResponding(false); } }); @@ -189,6 +200,9 @@ export const useGeminiStream = ( queryToSend: PartListUnion | null; shouldProceed: boolean; }> => { + if (turnCancelledRef.current) { + return { queryToSend: null, shouldProceed: false }; + } if (typeof query === 'string' && query.trim().length === 0) { return { queryToSend: null, shouldProceed: false }; } @@ -285,6 +299,10 @@ export const useGeminiStream = ( currentGeminiMessageBuffer: string, userMessageTimestamp: number, ): string => { + if (turnCancelledRef.current) { + // Prevents additional output after a user initiated cancel. + return ''; + } let newGeminiMessageBuffer = currentGeminiMessageBuffer + eventValue; if ( pendingHistoryItemRef.current?.type !== 'gemini' && @@ -335,6 +353,9 @@ export const useGeminiStream = ( const handleUserCancelledEvent = useCallback( (userMessageTimestamp: number) => { + if (turnCancelledRef.current) { + return; + } if (pendingHistoryItemRef.current) { if (pendingHistoryItemRef.current.type === 'tool_group') { const updatedTools = pendingHistoryItemRef.current.tools.map( @@ -469,6 +490,7 @@ export const useGeminiStream = ( abortControllerRef.current = new AbortController(); const abortSignal = abortControllerRef.current.signal; + turnCancelledRef.current = false; const { queryToSend, shouldProceed } = await prepareQueryForGemini( query,