From f2f2ecf9d83224778e5fc38cfcc4a1edddf9f7d4 Mon Sep 17 00:00:00 2001 From: Taylor Mullen Date: Tue, 27 May 2025 23:40:25 -0700 Subject: [PATCH] feat: Allow cancellation of in-progress Gemini requests and pre-execution checks - Implements cancellation for Gemini requests while they are actively being processed by the model. - Extends cancellation support to the logic within tools. This allows users to cancel operations during the phase where the system is determining if a tool execution requires user confirmation, which can include potentially long-running pre-flight checks or LLM-based corrections. - Underlying LLM calls for edit corrections (within and ) and next speaker checks can now also be cancelled. - Previously, cancellation of the main request was not possible until text started streaming, and pre-execution checks were not cancellable. - This change leverages the updated SDK's ability to accept an abort token and threads s throughout the request, tool execution, and pre-execution check lifecycle. Fixes https://github.com/google-gemini/gemini-cli/issues/531 --- packages/cli/src/gemini.tsx | 8 +- packages/cli/src/ui/hooks/useToolScheduler.ts | 7 +- packages/server/src/core/client.ts | 13 +++- packages/server/src/core/geminiChat.ts | 4 +- packages/server/src/core/turn.test.ts | 27 +++++-- packages/server/src/core/turn.ts | 12 ++- packages/server/src/tools/edit.test.ts | 23 ++++-- packages/server/src/tools/edit.ts | 10 ++- packages/server/src/tools/shell.ts | 1 + packages/server/src/tools/tools.ts | 3 + packages/server/src/tools/write-file.test.ts | 76 ++++++++++++++----- packages/server/src/tools/write-file.ts | 8 +- .../server/src/utils/editCorrector.test.ts | 33 ++++++-- packages/server/src/utils/editCorrector.ts | 37 ++++++++- .../src/utils/nextSpeakerChecker.test.ts | 57 +++++++++++--- .../server/src/utils/nextSpeakerChecker.ts | 2 + 16 files changed, 260 insertions(+), 61 deletions(-) diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 11875593..9cfaef37 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -95,9 +95,11 @@ async function main() { const geminiClient = new GeminiClient(config); const chat = await geminiClient.startChat(); try { - for await (const event of geminiClient.sendMessageStream(chat, [ - { text: input }, - ])) { + for await (const event of geminiClient.sendMessageStream( + chat, + [{ text: input }], + new AbortController().signal, + )) { if (event.type === 'content') { process.stdout.write(event.value); } diff --git a/packages/cli/src/ui/hooks/useToolScheduler.ts b/packages/cli/src/ui/hooks/useToolScheduler.ts index f1eee9fd..7d8cfbe4 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.ts @@ -142,7 +142,10 @@ export function useToolScheduler( const { request: r, tool } = initialCall; try { - const userApproval = await tool.shouldConfirmExecute(r.args); + const userApproval = await tool.shouldConfirmExecute( + r.args, + abortController.signal, + ); if (userApproval) { // Confirmation is needed. Update status to 'awaiting_approval'. setToolCalls( @@ -183,7 +186,7 @@ export function useToolScheduler( } }); }, - [isRunning, setToolCalls, toolRegistry], + [isRunning, setToolCalls, toolRegistry, abortController.signal], ); const cancel = useCallback( diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts index 341ce021..69b815ab 100644 --- a/packages/server/src/core/client.ts +++ b/packages/server/src/core/client.ts @@ -157,7 +157,7 @@ export class GeminiClient { async *sendMessageStream( chat: GeminiChat, request: PartListUnion, - signal?: AbortSignal, + signal: AbortSignal, turns: number = this.MAX_TURNS, ): AsyncGenerator { if (!turns) { @@ -169,8 +169,8 @@ export class GeminiClient { for await (const event of resultStream) { yield event; } - if (!turn.pendingToolCalls.length) { - const nextSpeakerCheck = await checkNextSpeaker(chat, this); + if (!turn.pendingToolCalls.length && signal && !signal.aborted) { + const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal); if (nextSpeakerCheck?.next_speaker === 'model') { const nextRequest = [{ text: 'Please continue.' }]; yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1); @@ -181,6 +181,7 @@ export class GeminiClient { async generateJson( contents: Content[], schema: SchemaUnion, + abortSignal: AbortSignal, model: string = 'gemini-2.0-flash', config: GenerateContentConfig = {}, ): Promise> { @@ -188,6 +189,7 @@ export class GeminiClient { const userMemory = this.config.getUserMemory(); const systemInstruction = getCoreSystemPrompt(userMemory); const requestConfig = { + abortSignal, ...this.generateContentConfig, ...config, }; @@ -232,6 +234,11 @@ export class GeminiClient { ); } } catch (error) { + if (abortSignal.aborted) { + // Regular cancellation error, fail normally + throw error; + } + // Avoid double reporting for the empty response case handled above if ( error instanceof Error && diff --git a/packages/server/src/core/geminiChat.ts b/packages/server/src/core/geminiChat.ts index c971e2cc..5ba8ce2d 100644 --- a/packages/server/src/core/geminiChat.ts +++ b/packages/server/src/core/geminiChat.ts @@ -155,7 +155,7 @@ export class GeminiChat { const responsePromise = this.modelsModule.generateContent({ model: this.model, contents: this.getHistory(true).concat(userContent), - config: params.config ?? this.config, + config: { ...this.config, ...params.config }, }); this.sendPromise = (async () => { const response = await responsePromise; @@ -219,7 +219,7 @@ export class GeminiChat { const streamResponse = this.modelsModule.generateContentStream({ model: this.model, contents: this.getHistory(true).concat(userContent), - config: params.config ?? this.config, + config: { ...this.config, ...params.config }, }); // Resolve the internal tracking of send completion promise - `sendPromise` // for both success and failure response. The actual failure is still diff --git a/packages/server/src/core/turn.test.ts b/packages/server/src/core/turn.test.ts index 44bb983f..8fb3a4c1 100644 --- a/packages/server/src/core/turn.test.ts +++ b/packages/server/src/core/turn.test.ts @@ -85,11 +85,17 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Hi' }]; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } - expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts }); + expect(mockSendMessageStream).toHaveBeenCalledWith({ + message: reqParts, + config: { abortSignal: expect.any(AbortSignal) }, + }); expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Hello' }, { type: GeminiEventType.Content, value: ' world' }, @@ -110,7 +116,10 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Use tools' }]; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } @@ -179,7 +188,10 @@ describe('Turn', () => { mockGetHistory.mockReturnValue(historyContent); const events = []; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } @@ -210,7 +222,10 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test undefined tool parts' }]; - for await (const event of turn.run(reqParts)) { + for await (const event of turn.run( + reqParts, + new AbortController().signal, + )) { events.push(event); } @@ -261,7 +276,7 @@ describe('Turn', () => { })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); const reqParts: Part[] = [{ text: 'Hi' }]; - for await (const _ of turn.run(reqParts)) { + for await (const _ of turn.run(reqParts, new AbortController().signal)) { // consume stream } expect(turn.getDebugResponses()).toEqual([resp1, resp2]); diff --git a/packages/server/src/core/turn.ts b/packages/server/src/core/turn.ts index d5c7eb58..97e93f59 100644 --- a/packages/server/src/core/turn.ts +++ b/packages/server/src/core/turn.ts @@ -32,6 +32,7 @@ export interface ServerTool { ): Promise; shouldConfirmExecute( params: Record, + abortSignal: AbortSignal, ): Promise; } @@ -120,11 +121,14 @@ export class Turn { // The run method yields simpler events suitable for server logic async *run( req: PartListUnion, - signal?: AbortSignal, + signal: AbortSignal, ): AsyncGenerator { try { const responseStream = await this.chat.sendMessageStream({ message: req, + config: { + abortSignal: signal, + }, }); for await (const resp of responseStream) { @@ -150,6 +154,12 @@ export class Turn { } } } catch (error) { + if (signal.aborted) { + yield { type: GeminiEventType.UserCancelled }; + // Regular cancellation error, fail gracefully. + return; + } + const contextForReport = [...this.chat.getHistory(/*curated*/ true), req]; await reportError( error, diff --git a/packages/server/src/tools/edit.test.ts b/packages/server/src/tools/edit.test.ts index 88216d53..08d0860d 100644 --- a/packages/server/src/tools/edit.test.ts +++ b/packages/server/src/tools/edit.test.ts @@ -223,7 +223,9 @@ describe('EditTool', () => { old_string: 'old', new_string: 'new', }; - expect(await tool.shouldConfirmExecute(params)).toBe(false); + expect( + await tool.shouldConfirmExecute(params, new AbortController().signal), + ).toBe(false); }); it('should request confirmation for valid edit', async () => { @@ -235,7 +237,10 @@ describe('EditTool', () => { }; // ensureCorrectEdit will be called by shouldConfirmExecute mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 1 }); - const confirmation = await tool.shouldConfirmExecute(params); + const confirmation = await tool.shouldConfirmExecute( + params, + new AbortController().signal, + ); expect(confirmation).toEqual( expect.objectContaining({ title: `Confirm Edit: ${testFile}`, @@ -253,7 +258,9 @@ describe('EditTool', () => { new_string: 'new', }; mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 }); - expect(await tool.shouldConfirmExecute(params)).toBe(false); + expect( + await tool.shouldConfirmExecute(params, new AbortController().signal), + ).toBe(false); }); it('should return false if multiple occurrences of old_string are found (ensureCorrectEdit returns > 1)', async () => { @@ -264,7 +271,9 @@ describe('EditTool', () => { new_string: 'new', }; mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 2 }); - expect(await tool.shouldConfirmExecute(params)).toBe(false); + expect( + await tool.shouldConfirmExecute(params, new AbortController().signal), + ).toBe(false); }); it('should request confirmation for creating a new file (empty old_string)', async () => { @@ -279,7 +288,10 @@ describe('EditTool', () => { // as shouldConfirmExecute handles this for diff generation. // If it is called, it should return 0 occurrences for a new file. mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 }); - const confirmation = await tool.shouldConfirmExecute(params); + const confirmation = await tool.shouldConfirmExecute( + params, + new AbortController().signal, + ); expect(confirmation).toEqual( expect.objectContaining({ title: `Confirm Edit: ${newFileName}`, @@ -328,6 +340,7 @@ describe('EditTool', () => { const confirmation = (await tool.shouldConfirmExecute( params, + new AbortController().signal, )) as FileDiff; expect(mockCalled).toBe(true); // Check if the mock implementation was run diff --git a/packages/server/src/tools/edit.ts b/packages/server/src/tools/edit.ts index 781483ae..d85c89b0 100644 --- a/packages/server/src/tools/edit.ts +++ b/packages/server/src/tools/edit.ts @@ -174,7 +174,10 @@ Expectation for parameters: * @returns An object describing the potential edit outcome * @throws File system errors if reading the file fails unexpectedly (e.g., permissions) */ - private async calculateEdit(params: EditToolParams): Promise { + private async calculateEdit( + params: EditToolParams, + abortSignal: AbortSignal, + ): Promise { const expectedReplacements = 1; let currentContent: string | null = null; let fileExists = false; @@ -210,6 +213,7 @@ Expectation for parameters: currentContent, params, this.client, + abortSignal, ); finalOldString = correctedEdit.params.old_string; finalNewString = correctedEdit.params.new_string; @@ -262,6 +266,7 @@ Expectation for parameters: */ async shouldConfirmExecute( params: EditToolParams, + abortSignal: AbortSignal, ): Promise { if (this.config.getAlwaysSkipModificationConfirmation()) { return false; @@ -300,6 +305,7 @@ Expectation for parameters: currentContent, params, this.client, + abortSignal, ); finalOldString = correctedEdit.params.old_string; finalNewString = correctedEdit.params.new_string; @@ -376,7 +382,7 @@ Expectation for parameters: let editData: CalculatedEdit; try { - editData = await this.calculateEdit(params); + editData = await this.calculateEdit(params, _signal); } catch (error) { const errorMsg = error instanceof Error ? error.message : String(error); return { diff --git a/packages/server/src/tools/shell.ts b/packages/server/src/tools/shell.ts index 6ee36b8d..a708c93f 100644 --- a/packages/server/src/tools/shell.ts +++ b/packages/server/src/tools/shell.ts @@ -98,6 +98,7 @@ export class ShellTool extends BaseTool { async shouldConfirmExecute( params: ShellToolParams, + _abortSignal: AbortSignal, ): Promise { if (this.validateToolParams(params)) { return false; // skip confirmation, execute call will fail immediately diff --git a/packages/server/src/tools/tools.ts b/packages/server/src/tools/tools.ts index c57bbd39..8ec11bf0 100644 --- a/packages/server/src/tools/tools.ts +++ b/packages/server/src/tools/tools.ts @@ -57,6 +57,7 @@ export interface Tool< */ shouldConfirmExecute( params: TParams, + abortSignal: AbortSignal, ): Promise; /** @@ -137,6 +138,8 @@ export abstract class BaseTool< shouldConfirmExecute( // eslint-disable-next-line @typescript-eslint/no-unused-vars params: TParams, + // eslint-disable-next-line @typescript-eslint/no-unused-vars + abortSignal: AbortSignal, ): Promise { return Promise.resolve(false); } diff --git a/packages/server/src/tools/write-file.test.ts b/packages/server/src/tools/write-file.test.ts index 83e75a33..3fd97c9e 100644 --- a/packages/server/src/tools/write-file.test.ts +++ b/packages/server/src/tools/write-file.test.ts @@ -110,18 +110,32 @@ describe('WriteFileTool', () => { // Default mock implementations that return valid structures mockEnsureCorrectEdit.mockImplementation( async ( - currentContent: string, + _currentContent: string, params: EditToolParams, _client: GeminiClient, - ): Promise => - Promise.resolve({ + signal?: AbortSignal, // Make AbortSignal optional to match usage + ): Promise => { + if (signal?.aborted) { + return Promise.reject(new Error('Aborted')); + } + return Promise.resolve({ params: { ...params, new_string: params.new_string ?? '' }, occurrences: 1, - }), + }); + }, ); mockEnsureCorrectFileContent.mockImplementation( - async (content: string, _client: GeminiClient): Promise => - Promise.resolve(content ?? ''), + async ( + content: string, + _client: GeminiClient, + signal?: AbortSignal, + ): Promise => { + // Make AbortSignal optional + if (signal?.aborted) { + return Promise.reject(new Error('Aborted')); + } + return Promise.resolve(content ?? ''); + }, ); }); @@ -181,6 +195,7 @@ describe('WriteFileTool', () => { const filePath = path.join(rootDir, 'new_corrected_file.txt'); const proposedContent = 'Proposed new content.'; const correctedContent = 'Corrected new content.'; + const abortSignal = new AbortController().signal; // Ensure the mock is set for this specific test case if needed, or rely on beforeEach mockEnsureCorrectFileContent.mockResolvedValue(correctedContent); @@ -188,11 +203,13 @@ describe('WriteFileTool', () => { const result = await tool._getCorrectedFileContent( filePath, proposedContent, + abortSignal, ); expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith( proposedContent, mockGeminiClientInstance, + abortSignal, ); expect(mockEnsureCorrectEdit).not.toHaveBeenCalled(); expect(result.correctedContent).toBe(correctedContent); @@ -206,6 +223,7 @@ describe('WriteFileTool', () => { const originalContent = 'Original existing content.'; const proposedContent = 'Proposed replacement content.'; const correctedProposedContent = 'Corrected replacement content.'; + const abortSignal = new AbortController().signal; fs.writeFileSync(filePath, originalContent, 'utf8'); // Ensure this mock is active and returns the correct structure @@ -222,6 +240,7 @@ describe('WriteFileTool', () => { const result = await tool._getCorrectedFileContent( filePath, proposedContent, + abortSignal, ); expect(mockEnsureCorrectEdit).toHaveBeenCalledWith( @@ -232,6 +251,7 @@ describe('WriteFileTool', () => { file_path: filePath, }, mockGeminiClientInstance, + abortSignal, ); expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled(); expect(result.correctedContent).toBe(correctedProposedContent); @@ -243,6 +263,7 @@ describe('WriteFileTool', () => { it('should return error if reading an existing file fails (e.g. permissions)', async () => { const filePath = path.join(rootDir, 'unreadable_file.txt'); const proposedContent = 'some content'; + const abortSignal = new AbortController().signal; fs.writeFileSync(filePath, 'content', { mode: 0o000 }); const readError = new Error('Permission denied'); @@ -255,6 +276,7 @@ describe('WriteFileTool', () => { const result = await tool._getCorrectedFileContent( filePath, proposedContent, + abortSignal, ); expect(fs.readFileSync).toHaveBeenCalledWith(filePath, 'utf8'); @@ -274,16 +296,17 @@ describe('WriteFileTool', () => { }); describe('shouldConfirmExecute', () => { + const abortSignal = new AbortController().signal; it('should return false if params are invalid (relative path)', async () => { const params = { file_path: 'relative.txt', content: 'test' }; - const confirmation = await tool.shouldConfirmExecute(params); + const confirmation = await tool.shouldConfirmExecute(params, abortSignal); expect(confirmation).toBe(false); }); it('should return false if params are invalid (outside root)', async () => { const outsidePath = path.resolve(tempDir, 'outside-root.txt'); const params = { file_path: outsidePath, content: 'test' }; - const confirmation = await tool.shouldConfirmExecute(params); + const confirmation = await tool.shouldConfirmExecute(params, abortSignal); expect(confirmation).toBe(false); }); @@ -298,7 +321,7 @@ describe('WriteFileTool', () => { throw readError; }); - const confirmation = await tool.shouldConfirmExecute(params); + const confirmation = await tool.shouldConfirmExecute(params, abortSignal); expect(confirmation).toBe(false); vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync); @@ -314,11 +337,13 @@ describe('WriteFileTool', () => { const params = { file_path: filePath, content: proposedContent }; const confirmation = (await tool.shouldConfirmExecute( params, + abortSignal, )) as ToolEditConfirmationDetails; expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith( proposedContent, mockGeminiClientInstance, + abortSignal, ); expect(confirmation).toEqual( expect.objectContaining({ @@ -343,7 +368,6 @@ describe('WriteFileTool', () => { 'Corrected replacement for confirmation.'; fs.writeFileSync(filePath, originalContent, 'utf8'); - // Ensure this mock is active and returns the correct structure mockEnsureCorrectEdit.mockResolvedValue({ params: { file_path: filePath, @@ -356,6 +380,7 @@ describe('WriteFileTool', () => { const params = { file_path: filePath, content: proposedContent }; const confirmation = (await tool.shouldConfirmExecute( params, + abortSignal, )) as ToolEditConfirmationDetails; expect(mockEnsureCorrectEdit).toHaveBeenCalledWith( @@ -366,6 +391,7 @@ describe('WriteFileTool', () => { file_path: filePath, }, mockGeminiClientInstance, + abortSignal, ); expect(confirmation).toEqual( expect.objectContaining({ @@ -381,9 +407,10 @@ describe('WriteFileTool', () => { }); describe('execute', () => { + const abortSignal = new AbortController().signal; it('should return error if params are invalid (relative path)', async () => { const params = { file_path: 'relative.txt', content: 'test' }; - const result = await tool.execute(params, new AbortController().signal); + const result = await tool.execute(params, abortSignal); expect(result.llmContent).toMatch(/Error: Invalid parameters provided/); expect(result.returnDisplay).toMatch(/Error: File path must be absolute/); }); @@ -391,7 +418,7 @@ describe('WriteFileTool', () => { it('should return error if params are invalid (path outside root)', async () => { const outsidePath = path.resolve(tempDir, 'outside-root.txt'); const params = { file_path: outsidePath, content: 'test' }; - const result = await tool.execute(params, new AbortController().signal); + const result = await tool.execute(params, abortSignal); expect(result.llmContent).toMatch(/Error: Invalid parameters provided/); expect(result.returnDisplay).toMatch( /Error: File path must be within the root directory/, @@ -409,7 +436,7 @@ describe('WriteFileTool', () => { throw readError; }); - const result = await tool.execute(params, new AbortController().signal); + const result = await tool.execute(params, abortSignal); expect(result.llmContent).toMatch(/Error checking existing file/); expect(result.returnDisplay).toMatch( /Error checking existing file: Simulated read error for execute/, @@ -427,16 +454,20 @@ describe('WriteFileTool', () => { const params = { file_path: filePath, content: proposedContent }; - const confirmDetails = await tool.shouldConfirmExecute(params); + const confirmDetails = await tool.shouldConfirmExecute( + params, + abortSignal, + ); if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) { await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); } - const result = await tool.execute(params, new AbortController().signal); + const result = await tool.execute(params, abortSignal); expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith( proposedContent, mockGeminiClientInstance, + abortSignal, ); expect(result.llmContent).toMatch( /Successfully created and wrote to new file/, @@ -477,12 +508,15 @@ describe('WriteFileTool', () => { const params = { file_path: filePath, content: proposedContent }; - const confirmDetails = await tool.shouldConfirmExecute(params); + const confirmDetails = await tool.shouldConfirmExecute( + params, + abortSignal, + ); if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) { await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); } - const result = await tool.execute(params, new AbortController().signal); + const result = await tool.execute(params, abortSignal); expect(mockEnsureCorrectEdit).toHaveBeenCalledWith( initialContent, @@ -492,6 +526,7 @@ describe('WriteFileTool', () => { file_path: filePath, }, mockGeminiClientInstance, + abortSignal, ); expect(result.llmContent).toMatch(/Successfully overwrote file/); expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedProposedContent); @@ -513,12 +548,15 @@ describe('WriteFileTool', () => { const params = { file_path: filePath, content }; // Simulate confirmation if your logic requires it before execute, or remove if not needed for this path - const confirmDetails = await tool.shouldConfirmExecute(params); + const confirmDetails = await tool.shouldConfirmExecute( + params, + abortSignal, + ); if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) { await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); } - await tool.execute(params, new AbortController().signal); + await tool.execute(params, abortSignal); expect(fs.existsSync(dirPath)).toBe(true); expect(fs.statSync(dirPath).isDirectory()).toBe(true); diff --git a/packages/server/src/tools/write-file.ts b/packages/server/src/tools/write-file.ts index 60646cc2..2285c819 100644 --- a/packages/server/src/tools/write-file.ts +++ b/packages/server/src/tools/write-file.ts @@ -141,6 +141,7 @@ export class WriteFileTool extends BaseTool { */ async shouldConfirmExecute( params: WriteFileToolParams, + abortSignal: AbortSignal, ): Promise { if (this.config.getAlwaysSkipModificationConfirmation()) { return false; @@ -154,6 +155,7 @@ export class WriteFileTool extends BaseTool { const correctedContentResult = await this._getCorrectedFileContent( params.file_path, params.content, + abortSignal, ); if (correctedContentResult.error) { @@ -193,7 +195,7 @@ export class WriteFileTool extends BaseTool { async execute( params: WriteFileToolParams, - _signal: AbortSignal, + abortSignal: AbortSignal, ): Promise { const validationError = this.validateToolParams(params); if (validationError) { @@ -206,6 +208,7 @@ export class WriteFileTool extends BaseTool { const correctedContentResult = await this._getCorrectedFileContent( params.file_path, params.content, + abortSignal, ); if (correctedContentResult.error) { @@ -277,6 +280,7 @@ export class WriteFileTool extends BaseTool { private async _getCorrectedFileContent( filePath: string, proposedContent: string, + abortSignal: AbortSignal, ): Promise { let originalContent = ''; let fileExists = false; @@ -316,6 +320,7 @@ export class WriteFileTool extends BaseTool { file_path: filePath, }, this.client, + abortSignal, ); correctedContent = correctedParams.new_string; } else { @@ -323,6 +328,7 @@ export class WriteFileTool extends BaseTool { correctedContent = await ensureCorrectFileContent( proposedContent, this.client, + abortSignal, ); } return { originalContent, correctedContent, fileExists }; diff --git a/packages/server/src/utils/editCorrector.test.ts b/packages/server/src/utils/editCorrector.test.ts index 27c9ffe8..7d6f5a53 100644 --- a/packages/server/src/utils/editCorrector.test.ts +++ b/packages/server/src/utils/editCorrector.test.ts @@ -132,6 +132,7 @@ describe('editCorrector', () => { let mockGeminiClientInstance: Mocked; let mockToolRegistry: Mocked; let mockConfigInstance: Config; + const abortSignal = new AbortController().signal; beforeEach(() => { mockToolRegistry = new ToolRegistry({} as Config) as Mocked; @@ -187,12 +188,18 @@ describe('editCorrector', () => { callCount = 0; mockResponses.length = 0; - mockGenerateJson = vi.fn().mockImplementation(() => { - const response = mockResponses[callCount]; - callCount++; - if (response === undefined) return Promise.resolve({}); - return Promise.resolve(response); - }); + mockGenerateJson = vi + .fn() + .mockImplementation((_contents, _schema, signal) => { + // Check if the signal is aborted. If so, throw an error or return a specific response. + if (signal && signal.aborted) { + return Promise.reject(new Error('Aborted')); // Or some other specific error/response + } + const response = mockResponses[callCount]; + callCount++; + if (response === undefined) return Promise.resolve({}); + return Promise.resolve(response); + }); mockStartChat = vi.fn(); mockSendMessageStream = vi.fn(); @@ -217,6 +224,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe('replace with "this"'); @@ -234,6 +242,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(result.params.new_string).toBe('replace with this'); @@ -254,6 +263,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe('replace with "this"'); @@ -271,6 +281,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(result.params.new_string).toBe('replace with this'); @@ -292,6 +303,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe('replace with "this"'); @@ -309,6 +321,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(result.params.new_string).toBe('replace with this'); @@ -329,6 +342,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe('replace with foobar'); @@ -351,6 +365,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe(llmNewString); @@ -372,6 +387,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(2); expect(result.params.new_string).toBe(llmNewString); @@ -391,6 +407,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe('replace with "this"'); @@ -412,6 +429,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM); @@ -432,6 +450,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(result.params).toEqual(originalParams); @@ -449,6 +468,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(result.params).toEqual(originalParams); @@ -471,6 +491,7 @@ describe('editCorrector', () => { currentContent, originalParams, mockGeminiClientInstance, + abortSignal, ); expect(mockGenerateJson).toHaveBeenCalledTimes(2); expect(result.params.old_string).toBe(currentContent); diff --git a/packages/server/src/utils/editCorrector.ts b/packages/server/src/utils/editCorrector.ts index 92551478..78663954 100644 --- a/packages/server/src/utils/editCorrector.ts +++ b/packages/server/src/utils/editCorrector.ts @@ -63,6 +63,7 @@ export async function ensureCorrectEdit( currentContent: string, originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\' client: GeminiClient, + abortSignal: AbortSignal, ): Promise { const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`; const cachedResult = editCorrectionCache.get(cacheKey); @@ -84,6 +85,7 @@ export async function ensureCorrectEdit( client, finalOldString, originalParams.new_string, + abortSignal, ); } } else if (occurrences > 1) { @@ -108,6 +110,7 @@ export async function ensureCorrectEdit( originalParams.old_string, // original old unescapedOldStringAttempt, // corrected old originalParams.new_string, // original new (which is potentially escaped) + abortSignal, ); } } else if (occurrences === 0) { @@ -115,6 +118,7 @@ export async function ensureCorrectEdit( client, currentContent, unescapedOldStringAttempt, + abortSignal, ); const llmOldOccurrences = countOccurrences( currentContent, @@ -134,6 +138,7 @@ export async function ensureCorrectEdit( originalParams.old_string, // original old llmCorrectedOldString, // corrected old baseNewStringForLLMCorrection, // base new for correction + abortSignal, ); } } else { @@ -180,6 +185,7 @@ export async function ensureCorrectEdit( export async function ensureCorrectFileContent( content: string, client: GeminiClient, + abortSignal: AbortSignal, ): Promise { const cachedResult = fileContentCorrectionCache.get(content); if (cachedResult) { @@ -193,7 +199,11 @@ export async function ensureCorrectFileContent( return content; } - const correctedContent = await correctStringEscaping(content, client); + const correctedContent = await correctStringEscaping( + content, + client, + abortSignal, + ); fileContentCorrectionCache.set(content, correctedContent); return correctedContent; } @@ -215,6 +225,7 @@ export async function correctOldStringMismatch( geminiClient: GeminiClient, fileContent: string, problematicSnippet: string, + abortSignal: AbortSignal, ): Promise { const prompt = ` Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped. @@ -243,6 +254,7 @@ Return ONLY the corrected target snippet in the specified JSON format with the k const result = await geminiClient.generateJson( contents, OLD_STRING_CORRECTION_SCHEMA, + abortSignal, EditModel, EditConfig, ); @@ -257,10 +269,15 @@ Return ONLY the corrected target snippet in the specified JSON format with the k return problematicSnippet; } } catch (error) { + if (abortSignal.aborted) { + throw error; + } + console.error( 'Error during LLM call for old string snippet correction:', error, ); + return problematicSnippet; } } @@ -286,6 +303,7 @@ export async function correctNewString( originalOldString: string, correctedOldString: string, originalNewString: string, + abortSignal: AbortSignal, ): Promise { if (originalOldString === correctedOldString) { return originalNewString; @@ -324,6 +342,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr const result = await geminiClient.generateJson( contents, NEW_STRING_CORRECTION_SCHEMA, + abortSignal, EditModel, EditConfig, ); @@ -338,6 +357,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr return originalNewString; } } catch (error) { + if (abortSignal.aborted) { + throw error; + } + console.error('Error during LLM call for new_string correction:', error); return originalNewString; } @@ -359,6 +382,7 @@ export async function correctNewStringEscaping( geminiClient: GeminiClient, oldString: string, potentiallyProblematicNewString: string, + abortSignal: AbortSignal, ): Promise { const prompt = ` Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello"). @@ -387,6 +411,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr const result = await geminiClient.generateJson( contents, CORRECT_NEW_STRING_ESCAPING_SCHEMA, + abortSignal, EditModel, EditConfig, ); @@ -401,6 +426,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr return potentiallyProblematicNewString; } } catch (error) { + if (abortSignal.aborted) { + throw error; + } + console.error( 'Error during LLM call for new_string escaping correction:', error, @@ -424,6 +453,7 @@ const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = { export async function correctStringEscaping( potentiallyProblematicString: string, client: GeminiClient, + abortSignal: AbortSignal, ): Promise { const prompt = ` Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello"). @@ -447,6 +477,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr const result = await client.generateJson( contents, CORRECT_STRING_ESCAPING_SCHEMA, + abortSignal, EditModel, EditConfig, ); @@ -461,6 +492,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr return potentiallyProblematicString; } } catch (error) { + if (abortSignal.aborted) { + throw error; + } + console.error( 'Error during LLM call for string escaping correction:', error, diff --git a/packages/server/src/utils/nextSpeakerChecker.test.ts b/packages/server/src/utils/nextSpeakerChecker.test.ts index 1d87bffb..872e00f6 100644 --- a/packages/server/src/utils/nextSpeakerChecker.test.ts +++ b/packages/server/src/utils/nextSpeakerChecker.test.ts @@ -44,6 +44,7 @@ describe('checkNextSpeaker', () => { let chatInstance: GeminiChat; let mockGeminiClient: GeminiClient; let MockConfig: Mock; + const abortSignal = new AbortController().signal; beforeEach(() => { MockConfig = vi.mocked(Config); @@ -71,7 +72,7 @@ describe('checkNextSpeaker', () => { mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel 'gemini-pro', // model name - {}, // config + {}, [], // initial history ); @@ -85,7 +86,11 @@ describe('checkNextSpeaker', () => { it('should return null if history is empty', async () => { (chatInstance.getHistory as Mock).mockReturnValue([]); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); expect(mockGeminiClient.generateJson).not.toHaveBeenCalled(); }); @@ -94,7 +99,11 @@ describe('checkNextSpeaker', () => { (chatInstance.getHistory as Mock).mockReturnValue([ { role: 'user', parts: [{ text: 'Hello' }] }, ] as Content[]); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); expect(mockGeminiClient.generateJson).not.toHaveBeenCalled(); }); @@ -109,7 +118,11 @@ describe('checkNextSpeaker', () => { }; (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toEqual(mockApiResponse); expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1); }); @@ -124,7 +137,11 @@ describe('checkNextSpeaker', () => { }; (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toEqual(mockApiResponse); }); @@ -138,7 +155,11 @@ describe('checkNextSpeaker', () => { }; (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toEqual(mockApiResponse); }); @@ -153,7 +174,11 @@ describe('checkNextSpeaker', () => { new Error('API Error'), ); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); consoleWarnSpy.mockRestore(); }); @@ -166,7 +191,11 @@ describe('checkNextSpeaker', () => { reasoning: 'This is incomplete.', } as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); }); @@ -179,7 +208,11 @@ describe('checkNextSpeaker', () => { next_speaker: 123, // Invalid type } as unknown as NextSpeakerResponse); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); }); @@ -192,7 +225,11 @@ describe('checkNextSpeaker', () => { next_speaker: 'neither', // Invalid enum value } as unknown as NextSpeakerResponse); - const result = await checkNextSpeaker(chatInstance, mockGeminiClient); + const result = await checkNextSpeaker( + chatInstance, + mockGeminiClient, + abortSignal, + ); expect(result).toBeNull(); }); }); diff --git a/packages/server/src/utils/nextSpeakerChecker.ts b/packages/server/src/utils/nextSpeakerChecker.ts index fb00b39c..66fa4395 100644 --- a/packages/server/src/utils/nextSpeakerChecker.ts +++ b/packages/server/src/utils/nextSpeakerChecker.ts @@ -61,6 +61,7 @@ export interface NextSpeakerResponse { export async function checkNextSpeaker( chat: GeminiChat, geminiClient: GeminiClient, + abortSignal: AbortSignal, ): Promise { // We need to capture the curated history because there are many moments when the model will return invalid turns // that when passed back up to the endpoint will break subsequent calls. An example of this is when the model decides @@ -129,6 +130,7 @@ export async function checkNextSpeaker( const parsedResponse = (await geminiClient.generateJson( contents, RESPONSE_SCHEMA, + abortSignal, )) as unknown as NextSpeakerResponse; if (