feat: Allow cancellation of in-progress Gemini requests and pre-execution checks

- Implements cancellation for Gemini requests while they are actively being processed by the model.
- Extends cancellation support to the  logic within tools. This allows users to cancel operations during the phase where the system is determining if a tool execution requires user confirmation, which can include potentially long-running pre-flight checks or LLM-based corrections.
- Underlying LLM calls for edit corrections (within  and ) and next speaker checks can now also be cancelled.
- Previously, cancellation of the main request was not possible until text started streaming, and pre-execution checks were not cancellable.
- This change leverages the updated SDK's ability to accept an abort token and threads s throughout the request, tool execution, and pre-execution check lifecycle.

Fixes https://github.com/google-gemini/gemini-cli/issues/531
This commit is contained in:
Taylor Mullen 2025-05-27 23:40:25 -07:00 committed by N. Taylor Mullen
parent bfeaac8441
commit f2f2ecf9d8
16 changed files with 260 additions and 61 deletions

View File

@ -95,9 +95,11 @@ async function main() {
const geminiClient = new GeminiClient(config); const geminiClient = new GeminiClient(config);
const chat = await geminiClient.startChat(); const chat = await geminiClient.startChat();
try { try {
for await (const event of geminiClient.sendMessageStream(chat, [ for await (const event of geminiClient.sendMessageStream(
{ text: input }, chat,
])) { [{ text: input }],
new AbortController().signal,
)) {
if (event.type === 'content') { if (event.type === 'content') {
process.stdout.write(event.value); process.stdout.write(event.value);
} }

View File

@ -142,7 +142,10 @@ export function useToolScheduler(
const { request: r, tool } = initialCall; const { request: r, tool } = initialCall;
try { try {
const userApproval = await tool.shouldConfirmExecute(r.args); const userApproval = await tool.shouldConfirmExecute(
r.args,
abortController.signal,
);
if (userApproval) { if (userApproval) {
// Confirmation is needed. Update status to 'awaiting_approval'. // Confirmation is needed. Update status to 'awaiting_approval'.
setToolCalls( setToolCalls(
@ -183,7 +186,7 @@ export function useToolScheduler(
} }
}); });
}, },
[isRunning, setToolCalls, toolRegistry], [isRunning, setToolCalls, toolRegistry, abortController.signal],
); );
const cancel = useCallback( const cancel = useCallback(

View File

@ -157,7 +157,7 @@ export class GeminiClient {
async *sendMessageStream( async *sendMessageStream(
chat: GeminiChat, chat: GeminiChat,
request: PartListUnion, request: PartListUnion,
signal?: AbortSignal, signal: AbortSignal,
turns: number = this.MAX_TURNS, turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent> { ): AsyncGenerator<ServerGeminiStreamEvent> {
if (!turns) { if (!turns) {
@ -169,8 +169,8 @@ export class GeminiClient {
for await (const event of resultStream) { for await (const event of resultStream) {
yield event; yield event;
} }
if (!turn.pendingToolCalls.length) { if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
const nextSpeakerCheck = await checkNextSpeaker(chat, this); const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
if (nextSpeakerCheck?.next_speaker === 'model') { if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }]; const nextRequest = [{ text: 'Please continue.' }];
yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1); yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1);
@ -181,6 +181,7 @@ export class GeminiClient {
async generateJson( async generateJson(
contents: Content[], contents: Content[],
schema: SchemaUnion, schema: SchemaUnion,
abortSignal: AbortSignal,
model: string = 'gemini-2.0-flash', model: string = 'gemini-2.0-flash',
config: GenerateContentConfig = {}, config: GenerateContentConfig = {},
): Promise<Record<string, unknown>> { ): Promise<Record<string, unknown>> {
@ -188,6 +189,7 @@ export class GeminiClient {
const userMemory = this.config.getUserMemory(); const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory); const systemInstruction = getCoreSystemPrompt(userMemory);
const requestConfig = { const requestConfig = {
abortSignal,
...this.generateContentConfig, ...this.generateContentConfig,
...config, ...config,
}; };
@ -232,6 +234,11 @@ export class GeminiClient {
); );
} }
} catch (error) { } catch (error) {
if (abortSignal.aborted) {
// Regular cancellation error, fail normally
throw error;
}
// Avoid double reporting for the empty response case handled above // Avoid double reporting for the empty response case handled above
if ( if (
error instanceof Error && error instanceof Error &&

View File

@ -155,7 +155,7 @@ export class GeminiChat {
const responsePromise = this.modelsModule.generateContent({ const responsePromise = this.modelsModule.generateContent({
model: this.model, model: this.model,
contents: this.getHistory(true).concat(userContent), contents: this.getHistory(true).concat(userContent),
config: params.config ?? this.config, config: { ...this.config, ...params.config },
}); });
this.sendPromise = (async () => { this.sendPromise = (async () => {
const response = await responsePromise; const response = await responsePromise;
@ -219,7 +219,7 @@ export class GeminiChat {
const streamResponse = this.modelsModule.generateContentStream({ const streamResponse = this.modelsModule.generateContentStream({
model: this.model, model: this.model,
contents: this.getHistory(true).concat(userContent), contents: this.getHistory(true).concat(userContent),
config: params.config ?? this.config, config: { ...this.config, ...params.config },
}); });
// Resolve the internal tracking of send completion promise - `sendPromise` // Resolve the internal tracking of send completion promise - `sendPromise`
// for both success and failure response. The actual failure is still // for both success and failure response. The actual failure is still

View File

@ -85,11 +85,17 @@ describe('Turn', () => {
const events = []; const events = [];
const reqParts: Part[] = [{ text: 'Hi' }]; const reqParts: Part[] = [{ text: 'Hi' }];
for await (const event of turn.run(reqParts)) { for await (const event of turn.run(
reqParts,
new AbortController().signal,
)) {
events.push(event); events.push(event);
} }
expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts }); expect(mockSendMessageStream).toHaveBeenCalledWith({
message: reqParts,
config: { abortSignal: expect.any(AbortSignal) },
});
expect(events).toEqual([ expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Hello' }, { type: GeminiEventType.Content, value: 'Hello' },
{ type: GeminiEventType.Content, value: ' world' }, { type: GeminiEventType.Content, value: ' world' },
@ -110,7 +116,10 @@ describe('Turn', () => {
const events = []; const events = [];
const reqParts: Part[] = [{ text: 'Use tools' }]; const reqParts: Part[] = [{ text: 'Use tools' }];
for await (const event of turn.run(reqParts)) { for await (const event of turn.run(
reqParts,
new AbortController().signal,
)) {
events.push(event); events.push(event);
} }
@ -179,7 +188,10 @@ describe('Turn', () => {
mockGetHistory.mockReturnValue(historyContent); mockGetHistory.mockReturnValue(historyContent);
const events = []; const events = [];
for await (const event of turn.run(reqParts)) { for await (const event of turn.run(
reqParts,
new AbortController().signal,
)) {
events.push(event); events.push(event);
} }
@ -210,7 +222,10 @@ describe('Turn', () => {
const events = []; const events = [];
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }]; const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
for await (const event of turn.run(reqParts)) { for await (const event of turn.run(
reqParts,
new AbortController().signal,
)) {
events.push(event); events.push(event);
} }
@ -261,7 +276,7 @@ describe('Turn', () => {
})(); })();
mockSendMessageStream.mockResolvedValue(mockResponseStream); mockSendMessageStream.mockResolvedValue(mockResponseStream);
const reqParts: Part[] = [{ text: 'Hi' }]; const reqParts: Part[] = [{ text: 'Hi' }];
for await (const _ of turn.run(reqParts)) { for await (const _ of turn.run(reqParts, new AbortController().signal)) {
// consume stream // consume stream
} }
expect(turn.getDebugResponses()).toEqual([resp1, resp2]); expect(turn.getDebugResponses()).toEqual([resp1, resp2]);

View File

@ -32,6 +32,7 @@ export interface ServerTool {
): Promise<ToolResult>; ): Promise<ToolResult>;
shouldConfirmExecute( shouldConfirmExecute(
params: Record<string, unknown>, params: Record<string, unknown>,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false>; ): Promise<ToolCallConfirmationDetails | false>;
} }
@ -120,11 +121,14 @@ export class Turn {
// The run method yields simpler events suitable for server logic // The run method yields simpler events suitable for server logic
async *run( async *run(
req: PartListUnion, req: PartListUnion,
signal?: AbortSignal, signal: AbortSignal,
): AsyncGenerator<ServerGeminiStreamEvent> { ): AsyncGenerator<ServerGeminiStreamEvent> {
try { try {
const responseStream = await this.chat.sendMessageStream({ const responseStream = await this.chat.sendMessageStream({
message: req, message: req,
config: {
abortSignal: signal,
},
}); });
for await (const resp of responseStream) { for await (const resp of responseStream) {
@ -150,6 +154,12 @@ export class Turn {
} }
} }
} catch (error) { } catch (error) {
if (signal.aborted) {
yield { type: GeminiEventType.UserCancelled };
// Regular cancellation error, fail gracefully.
return;
}
const contextForReport = [...this.chat.getHistory(/*curated*/ true), req]; const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
await reportError( await reportError(
error, error,

View File

@ -223,7 +223,9 @@ describe('EditTool', () => {
old_string: 'old', old_string: 'old',
new_string: 'new', new_string: 'new',
}; };
expect(await tool.shouldConfirmExecute(params)).toBe(false); expect(
await tool.shouldConfirmExecute(params, new AbortController().signal),
).toBe(false);
}); });
it('should request confirmation for valid edit', async () => { it('should request confirmation for valid edit', async () => {
@ -235,7 +237,10 @@ describe('EditTool', () => {
}; };
// ensureCorrectEdit will be called by shouldConfirmExecute // ensureCorrectEdit will be called by shouldConfirmExecute
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 1 }); mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 1 });
const confirmation = await tool.shouldConfirmExecute(params); const confirmation = await tool.shouldConfirmExecute(
params,
new AbortController().signal,
);
expect(confirmation).toEqual( expect(confirmation).toEqual(
expect.objectContaining({ expect.objectContaining({
title: `Confirm Edit: ${testFile}`, title: `Confirm Edit: ${testFile}`,
@ -253,7 +258,9 @@ describe('EditTool', () => {
new_string: 'new', new_string: 'new',
}; };
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 }); mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
expect(await tool.shouldConfirmExecute(params)).toBe(false); expect(
await tool.shouldConfirmExecute(params, new AbortController().signal),
).toBe(false);
}); });
it('should return false if multiple occurrences of old_string are found (ensureCorrectEdit returns > 1)', async () => { it('should return false if multiple occurrences of old_string are found (ensureCorrectEdit returns > 1)', async () => {
@ -264,7 +271,9 @@ describe('EditTool', () => {
new_string: 'new', new_string: 'new',
}; };
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 2 }); mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 2 });
expect(await tool.shouldConfirmExecute(params)).toBe(false); expect(
await tool.shouldConfirmExecute(params, new AbortController().signal),
).toBe(false);
}); });
it('should request confirmation for creating a new file (empty old_string)', async () => { it('should request confirmation for creating a new file (empty old_string)', async () => {
@ -279,7 +288,10 @@ describe('EditTool', () => {
// as shouldConfirmExecute handles this for diff generation. // as shouldConfirmExecute handles this for diff generation.
// If it is called, it should return 0 occurrences for a new file. // If it is called, it should return 0 occurrences for a new file.
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 }); mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
const confirmation = await tool.shouldConfirmExecute(params); const confirmation = await tool.shouldConfirmExecute(
params,
new AbortController().signal,
);
expect(confirmation).toEqual( expect(confirmation).toEqual(
expect.objectContaining({ expect.objectContaining({
title: `Confirm Edit: ${newFileName}`, title: `Confirm Edit: ${newFileName}`,
@ -328,6 +340,7 @@ describe('EditTool', () => {
const confirmation = (await tool.shouldConfirmExecute( const confirmation = (await tool.shouldConfirmExecute(
params, params,
new AbortController().signal,
)) as FileDiff; )) as FileDiff;
expect(mockCalled).toBe(true); // Check if the mock implementation was run expect(mockCalled).toBe(true); // Check if the mock implementation was run

View File

@ -174,7 +174,10 @@ Expectation for parameters:
* @returns An object describing the potential edit outcome * @returns An object describing the potential edit outcome
* @throws File system errors if reading the file fails unexpectedly (e.g., permissions) * @throws File system errors if reading the file fails unexpectedly (e.g., permissions)
*/ */
private async calculateEdit(params: EditToolParams): Promise<CalculatedEdit> { private async calculateEdit(
params: EditToolParams,
abortSignal: AbortSignal,
): Promise<CalculatedEdit> {
const expectedReplacements = 1; const expectedReplacements = 1;
let currentContent: string | null = null; let currentContent: string | null = null;
let fileExists = false; let fileExists = false;
@ -210,6 +213,7 @@ Expectation for parameters:
currentContent, currentContent,
params, params,
this.client, this.client,
abortSignal,
); );
finalOldString = correctedEdit.params.old_string; finalOldString = correctedEdit.params.old_string;
finalNewString = correctedEdit.params.new_string; finalNewString = correctedEdit.params.new_string;
@ -262,6 +266,7 @@ Expectation for parameters:
*/ */
async shouldConfirmExecute( async shouldConfirmExecute(
params: EditToolParams, params: EditToolParams,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> { ): Promise<ToolCallConfirmationDetails | false> {
if (this.config.getAlwaysSkipModificationConfirmation()) { if (this.config.getAlwaysSkipModificationConfirmation()) {
return false; return false;
@ -300,6 +305,7 @@ Expectation for parameters:
currentContent, currentContent,
params, params,
this.client, this.client,
abortSignal,
); );
finalOldString = correctedEdit.params.old_string; finalOldString = correctedEdit.params.old_string;
finalNewString = correctedEdit.params.new_string; finalNewString = correctedEdit.params.new_string;
@ -376,7 +382,7 @@ Expectation for parameters:
let editData: CalculatedEdit; let editData: CalculatedEdit;
try { try {
editData = await this.calculateEdit(params); editData = await this.calculateEdit(params, _signal);
} catch (error) { } catch (error) {
const errorMsg = error instanceof Error ? error.message : String(error); const errorMsg = error instanceof Error ? error.message : String(error);
return { return {

View File

@ -98,6 +98,7 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
async shouldConfirmExecute( async shouldConfirmExecute(
params: ShellToolParams, params: ShellToolParams,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> { ): Promise<ToolCallConfirmationDetails | false> {
if (this.validateToolParams(params)) { if (this.validateToolParams(params)) {
return false; // skip confirmation, execute call will fail immediately return false; // skip confirmation, execute call will fail immediately

View File

@ -57,6 +57,7 @@ export interface Tool<
*/ */
shouldConfirmExecute( shouldConfirmExecute(
params: TParams, params: TParams,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false>; ): Promise<ToolCallConfirmationDetails | false>;
/** /**
@ -137,6 +138,8 @@ export abstract class BaseTool<
shouldConfirmExecute( shouldConfirmExecute(
// eslint-disable-next-line @typescript-eslint/no-unused-vars // eslint-disable-next-line @typescript-eslint/no-unused-vars
params: TParams, params: TParams,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> { ): Promise<ToolCallConfirmationDetails | false> {
return Promise.resolve(false); return Promise.resolve(false);
} }

View File

@ -110,18 +110,32 @@ describe('WriteFileTool', () => {
// Default mock implementations that return valid structures // Default mock implementations that return valid structures
mockEnsureCorrectEdit.mockImplementation( mockEnsureCorrectEdit.mockImplementation(
async ( async (
currentContent: string, _currentContent: string,
params: EditToolParams, params: EditToolParams,
_client: GeminiClient, _client: GeminiClient,
): Promise<CorrectedEditResult> => signal?: AbortSignal, // Make AbortSignal optional to match usage
Promise.resolve({ ): Promise<CorrectedEditResult> => {
if (signal?.aborted) {
return Promise.reject(new Error('Aborted'));
}
return Promise.resolve({
params: { ...params, new_string: params.new_string ?? '' }, params: { ...params, new_string: params.new_string ?? '' },
occurrences: 1, occurrences: 1,
}), });
},
); );
mockEnsureCorrectFileContent.mockImplementation( mockEnsureCorrectFileContent.mockImplementation(
async (content: string, _client: GeminiClient): Promise<string> => async (
Promise.resolve(content ?? ''), content: string,
_client: GeminiClient,
signal?: AbortSignal,
): Promise<string> => {
// Make AbortSignal optional
if (signal?.aborted) {
return Promise.reject(new Error('Aborted'));
}
return Promise.resolve(content ?? '');
},
); );
}); });
@ -181,6 +195,7 @@ describe('WriteFileTool', () => {
const filePath = path.join(rootDir, 'new_corrected_file.txt'); const filePath = path.join(rootDir, 'new_corrected_file.txt');
const proposedContent = 'Proposed new content.'; const proposedContent = 'Proposed new content.';
const correctedContent = 'Corrected new content.'; const correctedContent = 'Corrected new content.';
const abortSignal = new AbortController().signal;
// Ensure the mock is set for this specific test case if needed, or rely on beforeEach // Ensure the mock is set for this specific test case if needed, or rely on beforeEach
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent); mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
@ -188,11 +203,13 @@ describe('WriteFileTool', () => {
const result = await tool._getCorrectedFileContent( const result = await tool._getCorrectedFileContent(
filePath, filePath,
proposedContent, proposedContent,
abortSignal,
); );
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith( expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
proposedContent, proposedContent,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled(); expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
expect(result.correctedContent).toBe(correctedContent); expect(result.correctedContent).toBe(correctedContent);
@ -206,6 +223,7 @@ describe('WriteFileTool', () => {
const originalContent = 'Original existing content.'; const originalContent = 'Original existing content.';
const proposedContent = 'Proposed replacement content.'; const proposedContent = 'Proposed replacement content.';
const correctedProposedContent = 'Corrected replacement content.'; const correctedProposedContent = 'Corrected replacement content.';
const abortSignal = new AbortController().signal;
fs.writeFileSync(filePath, originalContent, 'utf8'); fs.writeFileSync(filePath, originalContent, 'utf8');
// Ensure this mock is active and returns the correct structure // Ensure this mock is active and returns the correct structure
@ -222,6 +240,7 @@ describe('WriteFileTool', () => {
const result = await tool._getCorrectedFileContent( const result = await tool._getCorrectedFileContent(
filePath, filePath,
proposedContent, proposedContent,
abortSignal,
); );
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith( expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
@ -232,6 +251,7 @@ describe('WriteFileTool', () => {
file_path: filePath, file_path: filePath,
}, },
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled(); expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
expect(result.correctedContent).toBe(correctedProposedContent); expect(result.correctedContent).toBe(correctedProposedContent);
@ -243,6 +263,7 @@ describe('WriteFileTool', () => {
it('should return error if reading an existing file fails (e.g. permissions)', async () => { it('should return error if reading an existing file fails (e.g. permissions)', async () => {
const filePath = path.join(rootDir, 'unreadable_file.txt'); const filePath = path.join(rootDir, 'unreadable_file.txt');
const proposedContent = 'some content'; const proposedContent = 'some content';
const abortSignal = new AbortController().signal;
fs.writeFileSync(filePath, 'content', { mode: 0o000 }); fs.writeFileSync(filePath, 'content', { mode: 0o000 });
const readError = new Error('Permission denied'); const readError = new Error('Permission denied');
@ -255,6 +276,7 @@ describe('WriteFileTool', () => {
const result = await tool._getCorrectedFileContent( const result = await tool._getCorrectedFileContent(
filePath, filePath,
proposedContent, proposedContent,
abortSignal,
); );
expect(fs.readFileSync).toHaveBeenCalledWith(filePath, 'utf8'); expect(fs.readFileSync).toHaveBeenCalledWith(filePath, 'utf8');
@ -274,16 +296,17 @@ describe('WriteFileTool', () => {
}); });
describe('shouldConfirmExecute', () => { describe('shouldConfirmExecute', () => {
const abortSignal = new AbortController().signal;
it('should return false if params are invalid (relative path)', async () => { it('should return false if params are invalid (relative path)', async () => {
const params = { file_path: 'relative.txt', content: 'test' }; const params = { file_path: 'relative.txt', content: 'test' };
const confirmation = await tool.shouldConfirmExecute(params); const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
expect(confirmation).toBe(false); expect(confirmation).toBe(false);
}); });
it('should return false if params are invalid (outside root)', async () => { it('should return false if params are invalid (outside root)', async () => {
const outsidePath = path.resolve(tempDir, 'outside-root.txt'); const outsidePath = path.resolve(tempDir, 'outside-root.txt');
const params = { file_path: outsidePath, content: 'test' }; const params = { file_path: outsidePath, content: 'test' };
const confirmation = await tool.shouldConfirmExecute(params); const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
expect(confirmation).toBe(false); expect(confirmation).toBe(false);
}); });
@ -298,7 +321,7 @@ describe('WriteFileTool', () => {
throw readError; throw readError;
}); });
const confirmation = await tool.shouldConfirmExecute(params); const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
expect(confirmation).toBe(false); expect(confirmation).toBe(false);
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync); vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
@ -314,11 +337,13 @@ describe('WriteFileTool', () => {
const params = { file_path: filePath, content: proposedContent }; const params = { file_path: filePath, content: proposedContent };
const confirmation = (await tool.shouldConfirmExecute( const confirmation = (await tool.shouldConfirmExecute(
params, params,
abortSignal,
)) as ToolEditConfirmationDetails; )) as ToolEditConfirmationDetails;
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith( expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
proposedContent, proposedContent,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(confirmation).toEqual( expect(confirmation).toEqual(
expect.objectContaining({ expect.objectContaining({
@ -343,7 +368,6 @@ describe('WriteFileTool', () => {
'Corrected replacement for confirmation.'; 'Corrected replacement for confirmation.';
fs.writeFileSync(filePath, originalContent, 'utf8'); fs.writeFileSync(filePath, originalContent, 'utf8');
// Ensure this mock is active and returns the correct structure
mockEnsureCorrectEdit.mockResolvedValue({ mockEnsureCorrectEdit.mockResolvedValue({
params: { params: {
file_path: filePath, file_path: filePath,
@ -356,6 +380,7 @@ describe('WriteFileTool', () => {
const params = { file_path: filePath, content: proposedContent }; const params = { file_path: filePath, content: proposedContent };
const confirmation = (await tool.shouldConfirmExecute( const confirmation = (await tool.shouldConfirmExecute(
params, params,
abortSignal,
)) as ToolEditConfirmationDetails; )) as ToolEditConfirmationDetails;
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith( expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
@ -366,6 +391,7 @@ describe('WriteFileTool', () => {
file_path: filePath, file_path: filePath,
}, },
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(confirmation).toEqual( expect(confirmation).toEqual(
expect.objectContaining({ expect.objectContaining({
@ -381,9 +407,10 @@ describe('WriteFileTool', () => {
}); });
describe('execute', () => { describe('execute', () => {
const abortSignal = new AbortController().signal;
it('should return error if params are invalid (relative path)', async () => { it('should return error if params are invalid (relative path)', async () => {
const params = { file_path: 'relative.txt', content: 'test' }; const params = { file_path: 'relative.txt', content: 'test' };
const result = await tool.execute(params, new AbortController().signal); const result = await tool.execute(params, abortSignal);
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/); expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
expect(result.returnDisplay).toMatch(/Error: File path must be absolute/); expect(result.returnDisplay).toMatch(/Error: File path must be absolute/);
}); });
@ -391,7 +418,7 @@ describe('WriteFileTool', () => {
it('should return error if params are invalid (path outside root)', async () => { it('should return error if params are invalid (path outside root)', async () => {
const outsidePath = path.resolve(tempDir, 'outside-root.txt'); const outsidePath = path.resolve(tempDir, 'outside-root.txt');
const params = { file_path: outsidePath, content: 'test' }; const params = { file_path: outsidePath, content: 'test' };
const result = await tool.execute(params, new AbortController().signal); const result = await tool.execute(params, abortSignal);
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/); expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
expect(result.returnDisplay).toMatch( expect(result.returnDisplay).toMatch(
/Error: File path must be within the root directory/, /Error: File path must be within the root directory/,
@ -409,7 +436,7 @@ describe('WriteFileTool', () => {
throw readError; throw readError;
}); });
const result = await tool.execute(params, new AbortController().signal); const result = await tool.execute(params, abortSignal);
expect(result.llmContent).toMatch(/Error checking existing file/); expect(result.llmContent).toMatch(/Error checking existing file/);
expect(result.returnDisplay).toMatch( expect(result.returnDisplay).toMatch(
/Error checking existing file: Simulated read error for execute/, /Error checking existing file: Simulated read error for execute/,
@ -427,16 +454,20 @@ describe('WriteFileTool', () => {
const params = { file_path: filePath, content: proposedContent }; const params = { file_path: filePath, content: proposedContent };
const confirmDetails = await tool.shouldConfirmExecute(params); const confirmDetails = await tool.shouldConfirmExecute(
params,
abortSignal,
);
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) { if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
} }
const result = await tool.execute(params, new AbortController().signal); const result = await tool.execute(params, abortSignal);
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith( expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
proposedContent, proposedContent,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(result.llmContent).toMatch( expect(result.llmContent).toMatch(
/Successfully created and wrote to new file/, /Successfully created and wrote to new file/,
@ -477,12 +508,15 @@ describe('WriteFileTool', () => {
const params = { file_path: filePath, content: proposedContent }; const params = { file_path: filePath, content: proposedContent };
const confirmDetails = await tool.shouldConfirmExecute(params); const confirmDetails = await tool.shouldConfirmExecute(
params,
abortSignal,
);
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) { if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
} }
const result = await tool.execute(params, new AbortController().signal); const result = await tool.execute(params, abortSignal);
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith( expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
initialContent, initialContent,
@ -492,6 +526,7 @@ describe('WriteFileTool', () => {
file_path: filePath, file_path: filePath,
}, },
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(result.llmContent).toMatch(/Successfully overwrote file/); expect(result.llmContent).toMatch(/Successfully overwrote file/);
expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedProposedContent); expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedProposedContent);
@ -513,12 +548,15 @@ describe('WriteFileTool', () => {
const params = { file_path: filePath, content }; const params = { file_path: filePath, content };
// Simulate confirmation if your logic requires it before execute, or remove if not needed for this path // Simulate confirmation if your logic requires it before execute, or remove if not needed for this path
const confirmDetails = await tool.shouldConfirmExecute(params); const confirmDetails = await tool.shouldConfirmExecute(
params,
abortSignal,
);
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) { if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce); await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
} }
await tool.execute(params, new AbortController().signal); await tool.execute(params, abortSignal);
expect(fs.existsSync(dirPath)).toBe(true); expect(fs.existsSync(dirPath)).toBe(true);
expect(fs.statSync(dirPath).isDirectory()).toBe(true); expect(fs.statSync(dirPath).isDirectory()).toBe(true);

View File

@ -141,6 +141,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
*/ */
async shouldConfirmExecute( async shouldConfirmExecute(
params: WriteFileToolParams, params: WriteFileToolParams,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> { ): Promise<ToolCallConfirmationDetails | false> {
if (this.config.getAlwaysSkipModificationConfirmation()) { if (this.config.getAlwaysSkipModificationConfirmation()) {
return false; return false;
@ -154,6 +155,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
const correctedContentResult = await this._getCorrectedFileContent( const correctedContentResult = await this._getCorrectedFileContent(
params.file_path, params.file_path,
params.content, params.content,
abortSignal,
); );
if (correctedContentResult.error) { if (correctedContentResult.error) {
@ -193,7 +195,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
async execute( async execute(
params: WriteFileToolParams, params: WriteFileToolParams,
_signal: AbortSignal, abortSignal: AbortSignal,
): Promise<ToolResult> { ): Promise<ToolResult> {
const validationError = this.validateToolParams(params); const validationError = this.validateToolParams(params);
if (validationError) { if (validationError) {
@ -206,6 +208,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
const correctedContentResult = await this._getCorrectedFileContent( const correctedContentResult = await this._getCorrectedFileContent(
params.file_path, params.file_path,
params.content, params.content,
abortSignal,
); );
if (correctedContentResult.error) { if (correctedContentResult.error) {
@ -277,6 +280,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
private async _getCorrectedFileContent( private async _getCorrectedFileContent(
filePath: string, filePath: string,
proposedContent: string, proposedContent: string,
abortSignal: AbortSignal,
): Promise<GetCorrectedFileContentResult> { ): Promise<GetCorrectedFileContentResult> {
let originalContent = ''; let originalContent = '';
let fileExists = false; let fileExists = false;
@ -316,6 +320,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
file_path: filePath, file_path: filePath,
}, },
this.client, this.client,
abortSignal,
); );
correctedContent = correctedParams.new_string; correctedContent = correctedParams.new_string;
} else { } else {
@ -323,6 +328,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
correctedContent = await ensureCorrectFileContent( correctedContent = await ensureCorrectFileContent(
proposedContent, proposedContent,
this.client, this.client,
abortSignal,
); );
} }
return { originalContent, correctedContent, fileExists }; return { originalContent, correctedContent, fileExists };

View File

@ -132,6 +132,7 @@ describe('editCorrector', () => {
let mockGeminiClientInstance: Mocked<GeminiClient>; let mockGeminiClientInstance: Mocked<GeminiClient>;
let mockToolRegistry: Mocked<ToolRegistry>; let mockToolRegistry: Mocked<ToolRegistry>;
let mockConfigInstance: Config; let mockConfigInstance: Config;
const abortSignal = new AbortController().signal;
beforeEach(() => { beforeEach(() => {
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>; mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
@ -187,12 +188,18 @@ describe('editCorrector', () => {
callCount = 0; callCount = 0;
mockResponses.length = 0; mockResponses.length = 0;
mockGenerateJson = vi.fn().mockImplementation(() => { mockGenerateJson = vi
const response = mockResponses[callCount]; .fn()
callCount++; .mockImplementation((_contents, _schema, signal) => {
if (response === undefined) return Promise.resolve({}); // Check if the signal is aborted. If so, throw an error or return a specific response.
return Promise.resolve(response); if (signal && signal.aborted) {
}); return Promise.reject(new Error('Aborted')); // Or some other specific error/response
}
const response = mockResponses[callCount];
callCount++;
if (response === undefined) return Promise.resolve({});
return Promise.resolve(response);
});
mockStartChat = vi.fn(); mockStartChat = vi.fn();
mockSendMessageStream = vi.fn(); mockSendMessageStream = vi.fn();
@ -217,6 +224,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"'); expect(result.params.new_string).toBe('replace with "this"');
@ -234,6 +242,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params.new_string).toBe('replace with this'); expect(result.params.new_string).toBe('replace with this');
@ -254,6 +263,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"'); expect(result.params.new_string).toBe('replace with "this"');
@ -271,6 +281,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params.new_string).toBe('replace with this'); expect(result.params.new_string).toBe('replace with this');
@ -292,6 +303,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"'); expect(result.params.new_string).toBe('replace with "this"');
@ -309,6 +321,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params.new_string).toBe('replace with this'); expect(result.params.new_string).toBe('replace with this');
@ -329,6 +342,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with foobar'); expect(result.params.new_string).toBe('replace with foobar');
@ -351,6 +365,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe(llmNewString); expect(result.params.new_string).toBe(llmNewString);
@ -372,6 +387,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(2); expect(mockGenerateJson).toHaveBeenCalledTimes(2);
expect(result.params.new_string).toBe(llmNewString); expect(result.params.new_string).toBe(llmNewString);
@ -391,6 +407,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe('replace with "this"'); expect(result.params.new_string).toBe('replace with "this"');
@ -412,6 +429,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM); expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM);
@ -432,6 +450,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(1); expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(result.params).toEqual(originalParams); expect(result.params).toEqual(originalParams);
@ -449,6 +468,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(0); expect(mockGenerateJson).toHaveBeenCalledTimes(0);
expect(result.params).toEqual(originalParams); expect(result.params).toEqual(originalParams);
@ -471,6 +491,7 @@ describe('editCorrector', () => {
currentContent, currentContent,
originalParams, originalParams,
mockGeminiClientInstance, mockGeminiClientInstance,
abortSignal,
); );
expect(mockGenerateJson).toHaveBeenCalledTimes(2); expect(mockGenerateJson).toHaveBeenCalledTimes(2);
expect(result.params.old_string).toBe(currentContent); expect(result.params.old_string).toBe(currentContent);

View File

@ -63,6 +63,7 @@ export async function ensureCorrectEdit(
currentContent: string, currentContent: string,
originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\' originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
client: GeminiClient, client: GeminiClient,
abortSignal: AbortSignal,
): Promise<CorrectedEditResult> { ): Promise<CorrectedEditResult> {
const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`; const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
const cachedResult = editCorrectionCache.get(cacheKey); const cachedResult = editCorrectionCache.get(cacheKey);
@ -84,6 +85,7 @@ export async function ensureCorrectEdit(
client, client,
finalOldString, finalOldString,
originalParams.new_string, originalParams.new_string,
abortSignal,
); );
} }
} else if (occurrences > 1) { } else if (occurrences > 1) {
@ -108,6 +110,7 @@ export async function ensureCorrectEdit(
originalParams.old_string, // original old originalParams.old_string, // original old
unescapedOldStringAttempt, // corrected old unescapedOldStringAttempt, // corrected old
originalParams.new_string, // original new (which is potentially escaped) originalParams.new_string, // original new (which is potentially escaped)
abortSignal,
); );
} }
} else if (occurrences === 0) { } else if (occurrences === 0) {
@ -115,6 +118,7 @@ export async function ensureCorrectEdit(
client, client,
currentContent, currentContent,
unescapedOldStringAttempt, unescapedOldStringAttempt,
abortSignal,
); );
const llmOldOccurrences = countOccurrences( const llmOldOccurrences = countOccurrences(
currentContent, currentContent,
@ -134,6 +138,7 @@ export async function ensureCorrectEdit(
originalParams.old_string, // original old originalParams.old_string, // original old
llmCorrectedOldString, // corrected old llmCorrectedOldString, // corrected old
baseNewStringForLLMCorrection, // base new for correction baseNewStringForLLMCorrection, // base new for correction
abortSignal,
); );
} }
} else { } else {
@ -180,6 +185,7 @@ export async function ensureCorrectEdit(
export async function ensureCorrectFileContent( export async function ensureCorrectFileContent(
content: string, content: string,
client: GeminiClient, client: GeminiClient,
abortSignal: AbortSignal,
): Promise<string> { ): Promise<string> {
const cachedResult = fileContentCorrectionCache.get(content); const cachedResult = fileContentCorrectionCache.get(content);
if (cachedResult) { if (cachedResult) {
@ -193,7 +199,11 @@ export async function ensureCorrectFileContent(
return content; return content;
} }
const correctedContent = await correctStringEscaping(content, client); const correctedContent = await correctStringEscaping(
content,
client,
abortSignal,
);
fileContentCorrectionCache.set(content, correctedContent); fileContentCorrectionCache.set(content, correctedContent);
return correctedContent; return correctedContent;
} }
@ -215,6 +225,7 @@ export async function correctOldStringMismatch(
geminiClient: GeminiClient, geminiClient: GeminiClient,
fileContent: string, fileContent: string,
problematicSnippet: string, problematicSnippet: string,
abortSignal: AbortSignal,
): Promise<string> { ): Promise<string> {
const prompt = ` const prompt = `
Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped. Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped.
@ -243,6 +254,7 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
const result = await geminiClient.generateJson( const result = await geminiClient.generateJson(
contents, contents,
OLD_STRING_CORRECTION_SCHEMA, OLD_STRING_CORRECTION_SCHEMA,
abortSignal,
EditModel, EditModel,
EditConfig, EditConfig,
); );
@ -257,10 +269,15 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
return problematicSnippet; return problematicSnippet;
} }
} catch (error) { } catch (error) {
if (abortSignal.aborted) {
throw error;
}
console.error( console.error(
'Error during LLM call for old string snippet correction:', 'Error during LLM call for old string snippet correction:',
error, error,
); );
return problematicSnippet; return problematicSnippet;
} }
} }
@ -286,6 +303,7 @@ export async function correctNewString(
originalOldString: string, originalOldString: string,
correctedOldString: string, correctedOldString: string,
originalNewString: string, originalNewString: string,
abortSignal: AbortSignal,
): Promise<string> { ): Promise<string> {
if (originalOldString === correctedOldString) { if (originalOldString === correctedOldString) {
return originalNewString; return originalNewString;
@ -324,6 +342,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
const result = await geminiClient.generateJson( const result = await geminiClient.generateJson(
contents, contents,
NEW_STRING_CORRECTION_SCHEMA, NEW_STRING_CORRECTION_SCHEMA,
abortSignal,
EditModel, EditModel,
EditConfig, EditConfig,
); );
@ -338,6 +357,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
return originalNewString; return originalNewString;
} }
} catch (error) { } catch (error) {
if (abortSignal.aborted) {
throw error;
}
console.error('Error during LLM call for new_string correction:', error); console.error('Error during LLM call for new_string correction:', error);
return originalNewString; return originalNewString;
} }
@ -359,6 +382,7 @@ export async function correctNewStringEscaping(
geminiClient: GeminiClient, geminiClient: GeminiClient,
oldString: string, oldString: string,
potentiallyProblematicNewString: string, potentiallyProblematicNewString: string,
abortSignal: AbortSignal,
): Promise<string> { ): Promise<string> {
const prompt = ` const prompt = `
Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello"). Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
@ -387,6 +411,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
const result = await geminiClient.generateJson( const result = await geminiClient.generateJson(
contents, contents,
CORRECT_NEW_STRING_ESCAPING_SCHEMA, CORRECT_NEW_STRING_ESCAPING_SCHEMA,
abortSignal,
EditModel, EditModel,
EditConfig, EditConfig,
); );
@ -401,6 +426,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
return potentiallyProblematicNewString; return potentiallyProblematicNewString;
} }
} catch (error) { } catch (error) {
if (abortSignal.aborted) {
throw error;
}
console.error( console.error(
'Error during LLM call for new_string escaping correction:', 'Error during LLM call for new_string escaping correction:',
error, error,
@ -424,6 +453,7 @@ const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = {
export async function correctStringEscaping( export async function correctStringEscaping(
potentiallyProblematicString: string, potentiallyProblematicString: string,
client: GeminiClient, client: GeminiClient,
abortSignal: AbortSignal,
): Promise<string> { ): Promise<string> {
const prompt = ` const prompt = `
Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello"). Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
@ -447,6 +477,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
const result = await client.generateJson( const result = await client.generateJson(
contents, contents,
CORRECT_STRING_ESCAPING_SCHEMA, CORRECT_STRING_ESCAPING_SCHEMA,
abortSignal,
EditModel, EditModel,
EditConfig, EditConfig,
); );
@ -461,6 +492,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
return potentiallyProblematicString; return potentiallyProblematicString;
} }
} catch (error) { } catch (error) {
if (abortSignal.aborted) {
throw error;
}
console.error( console.error(
'Error during LLM call for string escaping correction:', 'Error during LLM call for string escaping correction:',
error, error,

View File

@ -44,6 +44,7 @@ describe('checkNextSpeaker', () => {
let chatInstance: GeminiChat; let chatInstance: GeminiChat;
let mockGeminiClient: GeminiClient; let mockGeminiClient: GeminiClient;
let MockConfig: Mock; let MockConfig: Mock;
const abortSignal = new AbortController().signal;
beforeEach(() => { beforeEach(() => {
MockConfig = vi.mocked(Config); MockConfig = vi.mocked(Config);
@ -71,7 +72,7 @@ describe('checkNextSpeaker', () => {
mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
'gemini-pro', // model name 'gemini-pro', // model name
{}, // config {},
[], // initial history [], // initial history
); );
@ -85,7 +86,11 @@ describe('checkNextSpeaker', () => {
it('should return null if history is empty', async () => { it('should return null if history is empty', async () => {
(chatInstance.getHistory as Mock).mockReturnValue([]); (chatInstance.getHistory as Mock).mockReturnValue([]);
const result = await checkNextSpeaker(chatInstance, mockGeminiClient); const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull(); expect(result).toBeNull();
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled(); expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
}); });
@ -94,7 +99,11 @@ describe('checkNextSpeaker', () => {
(chatInstance.getHistory as Mock).mockReturnValue([ (chatInstance.getHistory as Mock).mockReturnValue([
{ role: 'user', parts: [{ text: 'Hello' }] }, { role: 'user', parts: [{ text: 'Hello' }] },
] as Content[]); ] as Content[]);
const result = await checkNextSpeaker(chatInstance, mockGeminiClient); const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull(); expect(result).toBeNull();
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled(); expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
}); });
@ -109,7 +118,11 @@ describe('checkNextSpeaker', () => {
}; };
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
const result = await checkNextSpeaker(chatInstance, mockGeminiClient); const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toEqual(mockApiResponse); expect(result).toEqual(mockApiResponse);
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1); expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
}); });
@ -124,7 +137,11 @@ describe('checkNextSpeaker', () => {
}; };
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
const result = await checkNextSpeaker(chatInstance, mockGeminiClient); const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toEqual(mockApiResponse); expect(result).toEqual(mockApiResponse);
}); });
@ -138,7 +155,11 @@ describe('checkNextSpeaker', () => {
}; };
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse); (mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
const result = await checkNextSpeaker(chatInstance, mockGeminiClient); const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toEqual(mockApiResponse); expect(result).toEqual(mockApiResponse);
}); });
@ -153,7 +174,11 @@ describe('checkNextSpeaker', () => {
new Error('API Error'), new Error('API Error'),
); );
const result = await checkNextSpeaker(chatInstance, mockGeminiClient); const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull(); expect(result).toBeNull();
consoleWarnSpy.mockRestore(); consoleWarnSpy.mockRestore();
}); });
@ -166,7 +191,11 @@ describe('checkNextSpeaker', () => {
reasoning: 'This is incomplete.', reasoning: 'This is incomplete.',
} as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response } as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response
const result = await checkNextSpeaker(chatInstance, mockGeminiClient); const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull(); expect(result).toBeNull();
}); });
@ -179,7 +208,11 @@ describe('checkNextSpeaker', () => {
next_speaker: 123, // Invalid type next_speaker: 123, // Invalid type
} as unknown as NextSpeakerResponse); } as unknown as NextSpeakerResponse);
const result = await checkNextSpeaker(chatInstance, mockGeminiClient); const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull(); expect(result).toBeNull();
}); });
@ -192,7 +225,11 @@ describe('checkNextSpeaker', () => {
next_speaker: 'neither', // Invalid enum value next_speaker: 'neither', // Invalid enum value
} as unknown as NextSpeakerResponse); } as unknown as NextSpeakerResponse);
const result = await checkNextSpeaker(chatInstance, mockGeminiClient); const result = await checkNextSpeaker(
chatInstance,
mockGeminiClient,
abortSignal,
);
expect(result).toBeNull(); expect(result).toBeNull();
}); });
}); });

View File

@ -61,6 +61,7 @@ export interface NextSpeakerResponse {
export async function checkNextSpeaker( export async function checkNextSpeaker(
chat: GeminiChat, chat: GeminiChat,
geminiClient: GeminiClient, geminiClient: GeminiClient,
abortSignal: AbortSignal,
): Promise<NextSpeakerResponse | null> { ): Promise<NextSpeakerResponse | null> {
// We need to capture the curated history because there are many moments when the model will return invalid turns // We need to capture the curated history because there are many moments when the model will return invalid turns
// that when passed back up to the endpoint will break subsequent calls. An example of this is when the model decides // that when passed back up to the endpoint will break subsequent calls. An example of this is when the model decides
@ -129,6 +130,7 @@ export async function checkNextSpeaker(
const parsedResponse = (await geminiClient.generateJson( const parsedResponse = (await geminiClient.generateJson(
contents, contents,
RESPONSE_SCHEMA, RESPONSE_SCHEMA,
abortSignal,
)) as unknown as NextSpeakerResponse; )) as unknown as NextSpeakerResponse;
if ( if (