feat: Allow cancellation of in-progress Gemini requests and pre-execution checks
- Implements cancellation for Gemini requests while they are actively being processed by the model. - Extends cancellation support to the logic within tools. This allows users to cancel operations during the phase where the system is determining if a tool execution requires user confirmation, which can include potentially long-running pre-flight checks or LLM-based corrections. - Underlying LLM calls for edit corrections (within and ) and next speaker checks can now also be cancelled. - Previously, cancellation of the main request was not possible until text started streaming, and pre-execution checks were not cancellable. - This change leverages the updated SDK's ability to accept an abort token and threads s throughout the request, tool execution, and pre-execution check lifecycle. Fixes https://github.com/google-gemini/gemini-cli/issues/531
This commit is contained in:
parent
bfeaac8441
commit
f2f2ecf9d8
|
@ -95,9 +95,11 @@ async function main() {
|
||||||
const geminiClient = new GeminiClient(config);
|
const geminiClient = new GeminiClient(config);
|
||||||
const chat = await geminiClient.startChat();
|
const chat = await geminiClient.startChat();
|
||||||
try {
|
try {
|
||||||
for await (const event of geminiClient.sendMessageStream(chat, [
|
for await (const event of geminiClient.sendMessageStream(
|
||||||
{ text: input },
|
chat,
|
||||||
])) {
|
[{ text: input }],
|
||||||
|
new AbortController().signal,
|
||||||
|
)) {
|
||||||
if (event.type === 'content') {
|
if (event.type === 'content') {
|
||||||
process.stdout.write(event.value);
|
process.stdout.write(event.value);
|
||||||
}
|
}
|
||||||
|
|
|
@ -142,7 +142,10 @@ export function useToolScheduler(
|
||||||
|
|
||||||
const { request: r, tool } = initialCall;
|
const { request: r, tool } = initialCall;
|
||||||
try {
|
try {
|
||||||
const userApproval = await tool.shouldConfirmExecute(r.args);
|
const userApproval = await tool.shouldConfirmExecute(
|
||||||
|
r.args,
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
if (userApproval) {
|
if (userApproval) {
|
||||||
// Confirmation is needed. Update status to 'awaiting_approval'.
|
// Confirmation is needed. Update status to 'awaiting_approval'.
|
||||||
setToolCalls(
|
setToolCalls(
|
||||||
|
@ -183,7 +186,7 @@ export function useToolScheduler(
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[isRunning, setToolCalls, toolRegistry],
|
[isRunning, setToolCalls, toolRegistry, abortController.signal],
|
||||||
);
|
);
|
||||||
|
|
||||||
const cancel = useCallback(
|
const cancel = useCallback(
|
||||||
|
|
|
@ -157,7 +157,7 @@ export class GeminiClient {
|
||||||
async *sendMessageStream(
|
async *sendMessageStream(
|
||||||
chat: GeminiChat,
|
chat: GeminiChat,
|
||||||
request: PartListUnion,
|
request: PartListUnion,
|
||||||
signal?: AbortSignal,
|
signal: AbortSignal,
|
||||||
turns: number = this.MAX_TURNS,
|
turns: number = this.MAX_TURNS,
|
||||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||||
if (!turns) {
|
if (!turns) {
|
||||||
|
@ -169,8 +169,8 @@ export class GeminiClient {
|
||||||
for await (const event of resultStream) {
|
for await (const event of resultStream) {
|
||||||
yield event;
|
yield event;
|
||||||
}
|
}
|
||||||
if (!turn.pendingToolCalls.length) {
|
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||||
const nextSpeakerCheck = await checkNextSpeaker(chat, this);
|
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
|
||||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||||
const nextRequest = [{ text: 'Please continue.' }];
|
const nextRequest = [{ text: 'Please continue.' }];
|
||||||
yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1);
|
yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1);
|
||||||
|
@ -181,6 +181,7 @@ export class GeminiClient {
|
||||||
async generateJson(
|
async generateJson(
|
||||||
contents: Content[],
|
contents: Content[],
|
||||||
schema: SchemaUnion,
|
schema: SchemaUnion,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
model: string = 'gemini-2.0-flash',
|
model: string = 'gemini-2.0-flash',
|
||||||
config: GenerateContentConfig = {},
|
config: GenerateContentConfig = {},
|
||||||
): Promise<Record<string, unknown>> {
|
): Promise<Record<string, unknown>> {
|
||||||
|
@ -188,6 +189,7 @@ export class GeminiClient {
|
||||||
const userMemory = this.config.getUserMemory();
|
const userMemory = this.config.getUserMemory();
|
||||||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||||
const requestConfig = {
|
const requestConfig = {
|
||||||
|
abortSignal,
|
||||||
...this.generateContentConfig,
|
...this.generateContentConfig,
|
||||||
...config,
|
...config,
|
||||||
};
|
};
|
||||||
|
@ -232,6 +234,11 @@ export class GeminiClient {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (abortSignal.aborted) {
|
||||||
|
// Regular cancellation error, fail normally
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
// Avoid double reporting for the empty response case handled above
|
// Avoid double reporting for the empty response case handled above
|
||||||
if (
|
if (
|
||||||
error instanceof Error &&
|
error instanceof Error &&
|
||||||
|
|
|
@ -155,7 +155,7 @@ export class GeminiChat {
|
||||||
const responsePromise = this.modelsModule.generateContent({
|
const responsePromise = this.modelsModule.generateContent({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
contents: this.getHistory(true).concat(userContent),
|
contents: this.getHistory(true).concat(userContent),
|
||||||
config: params.config ?? this.config,
|
config: { ...this.config, ...params.config },
|
||||||
});
|
});
|
||||||
this.sendPromise = (async () => {
|
this.sendPromise = (async () => {
|
||||||
const response = await responsePromise;
|
const response = await responsePromise;
|
||||||
|
@ -219,7 +219,7 @@ export class GeminiChat {
|
||||||
const streamResponse = this.modelsModule.generateContentStream({
|
const streamResponse = this.modelsModule.generateContentStream({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
contents: this.getHistory(true).concat(userContent),
|
contents: this.getHistory(true).concat(userContent),
|
||||||
config: params.config ?? this.config,
|
config: { ...this.config, ...params.config },
|
||||||
});
|
});
|
||||||
// Resolve the internal tracking of send completion promise - `sendPromise`
|
// Resolve the internal tracking of send completion promise - `sendPromise`
|
||||||
// for both success and failure response. The actual failure is still
|
// for both success and failure response. The actual failure is still
|
||||||
|
|
|
@ -85,11 +85,17 @@ describe('Turn', () => {
|
||||||
|
|
||||||
const events = [];
|
const events = [];
|
||||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||||
for await (const event of turn.run(reqParts)) {
|
for await (const event of turn.run(
|
||||||
|
reqParts,
|
||||||
|
new AbortController().signal,
|
||||||
|
)) {
|
||||||
events.push(event);
|
events.push(event);
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts });
|
expect(mockSendMessageStream).toHaveBeenCalledWith({
|
||||||
|
message: reqParts,
|
||||||
|
config: { abortSignal: expect.any(AbortSignal) },
|
||||||
|
});
|
||||||
expect(events).toEqual([
|
expect(events).toEqual([
|
||||||
{ type: GeminiEventType.Content, value: 'Hello' },
|
{ type: GeminiEventType.Content, value: 'Hello' },
|
||||||
{ type: GeminiEventType.Content, value: ' world' },
|
{ type: GeminiEventType.Content, value: ' world' },
|
||||||
|
@ -110,7 +116,10 @@ describe('Turn', () => {
|
||||||
|
|
||||||
const events = [];
|
const events = [];
|
||||||
const reqParts: Part[] = [{ text: 'Use tools' }];
|
const reqParts: Part[] = [{ text: 'Use tools' }];
|
||||||
for await (const event of turn.run(reqParts)) {
|
for await (const event of turn.run(
|
||||||
|
reqParts,
|
||||||
|
new AbortController().signal,
|
||||||
|
)) {
|
||||||
events.push(event);
|
events.push(event);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,7 +188,10 @@ describe('Turn', () => {
|
||||||
mockGetHistory.mockReturnValue(historyContent);
|
mockGetHistory.mockReturnValue(historyContent);
|
||||||
|
|
||||||
const events = [];
|
const events = [];
|
||||||
for await (const event of turn.run(reqParts)) {
|
for await (const event of turn.run(
|
||||||
|
reqParts,
|
||||||
|
new AbortController().signal,
|
||||||
|
)) {
|
||||||
events.push(event);
|
events.push(event);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,7 +222,10 @@ describe('Turn', () => {
|
||||||
|
|
||||||
const events = [];
|
const events = [];
|
||||||
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
|
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
|
||||||
for await (const event of turn.run(reqParts)) {
|
for await (const event of turn.run(
|
||||||
|
reqParts,
|
||||||
|
new AbortController().signal,
|
||||||
|
)) {
|
||||||
events.push(event);
|
events.push(event);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -261,7 +276,7 @@ describe('Turn', () => {
|
||||||
})();
|
})();
|
||||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||||
for await (const _ of turn.run(reqParts)) {
|
for await (const _ of turn.run(reqParts, new AbortController().signal)) {
|
||||||
// consume stream
|
// consume stream
|
||||||
}
|
}
|
||||||
expect(turn.getDebugResponses()).toEqual([resp1, resp2]);
|
expect(turn.getDebugResponses()).toEqual([resp1, resp2]);
|
||||||
|
|
|
@ -32,6 +32,7 @@ export interface ServerTool {
|
||||||
): Promise<ToolResult>;
|
): Promise<ToolResult>;
|
||||||
shouldConfirmExecute(
|
shouldConfirmExecute(
|
||||||
params: Record<string, unknown>,
|
params: Record<string, unknown>,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<ToolCallConfirmationDetails | false>;
|
): Promise<ToolCallConfirmationDetails | false>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,11 +121,14 @@ export class Turn {
|
||||||
// The run method yields simpler events suitable for server logic
|
// The run method yields simpler events suitable for server logic
|
||||||
async *run(
|
async *run(
|
||||||
req: PartListUnion,
|
req: PartListUnion,
|
||||||
signal?: AbortSignal,
|
signal: AbortSignal,
|
||||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||||
try {
|
try {
|
||||||
const responseStream = await this.chat.sendMessageStream({
|
const responseStream = await this.chat.sendMessageStream({
|
||||||
message: req,
|
message: req,
|
||||||
|
config: {
|
||||||
|
abortSignal: signal,
|
||||||
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
for await (const resp of responseStream) {
|
for await (const resp of responseStream) {
|
||||||
|
@ -150,6 +154,12 @@ export class Turn {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (signal.aborted) {
|
||||||
|
yield { type: GeminiEventType.UserCancelled };
|
||||||
|
// Regular cancellation error, fail gracefully.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
|
const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
|
||||||
await reportError(
|
await reportError(
|
||||||
error,
|
error,
|
||||||
|
|
|
@ -223,7 +223,9 @@ describe('EditTool', () => {
|
||||||
old_string: 'old',
|
old_string: 'old',
|
||||||
new_string: 'new',
|
new_string: 'new',
|
||||||
};
|
};
|
||||||
expect(await tool.shouldConfirmExecute(params)).toBe(false);
|
expect(
|
||||||
|
await tool.shouldConfirmExecute(params, new AbortController().signal),
|
||||||
|
).toBe(false);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should request confirmation for valid edit', async () => {
|
it('should request confirmation for valid edit', async () => {
|
||||||
|
@ -235,7 +237,10 @@ describe('EditTool', () => {
|
||||||
};
|
};
|
||||||
// ensureCorrectEdit will be called by shouldConfirmExecute
|
// ensureCorrectEdit will be called by shouldConfirmExecute
|
||||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 1 });
|
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 1 });
|
||||||
const confirmation = await tool.shouldConfirmExecute(params);
|
const confirmation = await tool.shouldConfirmExecute(
|
||||||
|
params,
|
||||||
|
new AbortController().signal,
|
||||||
|
);
|
||||||
expect(confirmation).toEqual(
|
expect(confirmation).toEqual(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
title: `Confirm Edit: ${testFile}`,
|
title: `Confirm Edit: ${testFile}`,
|
||||||
|
@ -253,7 +258,9 @@ describe('EditTool', () => {
|
||||||
new_string: 'new',
|
new_string: 'new',
|
||||||
};
|
};
|
||||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
|
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
|
||||||
expect(await tool.shouldConfirmExecute(params)).toBe(false);
|
expect(
|
||||||
|
await tool.shouldConfirmExecute(params, new AbortController().signal),
|
||||||
|
).toBe(false);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return false if multiple occurrences of old_string are found (ensureCorrectEdit returns > 1)', async () => {
|
it('should return false if multiple occurrences of old_string are found (ensureCorrectEdit returns > 1)', async () => {
|
||||||
|
@ -264,7 +271,9 @@ describe('EditTool', () => {
|
||||||
new_string: 'new',
|
new_string: 'new',
|
||||||
};
|
};
|
||||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 2 });
|
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 2 });
|
||||||
expect(await tool.shouldConfirmExecute(params)).toBe(false);
|
expect(
|
||||||
|
await tool.shouldConfirmExecute(params, new AbortController().signal),
|
||||||
|
).toBe(false);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should request confirmation for creating a new file (empty old_string)', async () => {
|
it('should request confirmation for creating a new file (empty old_string)', async () => {
|
||||||
|
@ -279,7 +288,10 @@ describe('EditTool', () => {
|
||||||
// as shouldConfirmExecute handles this for diff generation.
|
// as shouldConfirmExecute handles this for diff generation.
|
||||||
// If it is called, it should return 0 occurrences for a new file.
|
// If it is called, it should return 0 occurrences for a new file.
|
||||||
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
|
mockEnsureCorrectEdit.mockResolvedValueOnce({ params, occurrences: 0 });
|
||||||
const confirmation = await tool.shouldConfirmExecute(params);
|
const confirmation = await tool.shouldConfirmExecute(
|
||||||
|
params,
|
||||||
|
new AbortController().signal,
|
||||||
|
);
|
||||||
expect(confirmation).toEqual(
|
expect(confirmation).toEqual(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
title: `Confirm Edit: ${newFileName}`,
|
title: `Confirm Edit: ${newFileName}`,
|
||||||
|
@ -328,6 +340,7 @@ describe('EditTool', () => {
|
||||||
|
|
||||||
const confirmation = (await tool.shouldConfirmExecute(
|
const confirmation = (await tool.shouldConfirmExecute(
|
||||||
params,
|
params,
|
||||||
|
new AbortController().signal,
|
||||||
)) as FileDiff;
|
)) as FileDiff;
|
||||||
|
|
||||||
expect(mockCalled).toBe(true); // Check if the mock implementation was run
|
expect(mockCalled).toBe(true); // Check if the mock implementation was run
|
||||||
|
|
|
@ -174,7 +174,10 @@ Expectation for parameters:
|
||||||
* @returns An object describing the potential edit outcome
|
* @returns An object describing the potential edit outcome
|
||||||
* @throws File system errors if reading the file fails unexpectedly (e.g., permissions)
|
* @throws File system errors if reading the file fails unexpectedly (e.g., permissions)
|
||||||
*/
|
*/
|
||||||
private async calculateEdit(params: EditToolParams): Promise<CalculatedEdit> {
|
private async calculateEdit(
|
||||||
|
params: EditToolParams,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
|
): Promise<CalculatedEdit> {
|
||||||
const expectedReplacements = 1;
|
const expectedReplacements = 1;
|
||||||
let currentContent: string | null = null;
|
let currentContent: string | null = null;
|
||||||
let fileExists = false;
|
let fileExists = false;
|
||||||
|
@ -210,6 +213,7 @@ Expectation for parameters:
|
||||||
currentContent,
|
currentContent,
|
||||||
params,
|
params,
|
||||||
this.client,
|
this.client,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
finalOldString = correctedEdit.params.old_string;
|
finalOldString = correctedEdit.params.old_string;
|
||||||
finalNewString = correctedEdit.params.new_string;
|
finalNewString = correctedEdit.params.new_string;
|
||||||
|
@ -262,6 +266,7 @@ Expectation for parameters:
|
||||||
*/
|
*/
|
||||||
async shouldConfirmExecute(
|
async shouldConfirmExecute(
|
||||||
params: EditToolParams,
|
params: EditToolParams,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<ToolCallConfirmationDetails | false> {
|
): Promise<ToolCallConfirmationDetails | false> {
|
||||||
if (this.config.getAlwaysSkipModificationConfirmation()) {
|
if (this.config.getAlwaysSkipModificationConfirmation()) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -300,6 +305,7 @@ Expectation for parameters:
|
||||||
currentContent,
|
currentContent,
|
||||||
params,
|
params,
|
||||||
this.client,
|
this.client,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
finalOldString = correctedEdit.params.old_string;
|
finalOldString = correctedEdit.params.old_string;
|
||||||
finalNewString = correctedEdit.params.new_string;
|
finalNewString = correctedEdit.params.new_string;
|
||||||
|
@ -376,7 +382,7 @@ Expectation for parameters:
|
||||||
|
|
||||||
let editData: CalculatedEdit;
|
let editData: CalculatedEdit;
|
||||||
try {
|
try {
|
||||||
editData = await this.calculateEdit(params);
|
editData = await this.calculateEdit(params, _signal);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -98,6 +98,7 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> {
|
||||||
|
|
||||||
async shouldConfirmExecute(
|
async shouldConfirmExecute(
|
||||||
params: ShellToolParams,
|
params: ShellToolParams,
|
||||||
|
_abortSignal: AbortSignal,
|
||||||
): Promise<ToolCallConfirmationDetails | false> {
|
): Promise<ToolCallConfirmationDetails | false> {
|
||||||
if (this.validateToolParams(params)) {
|
if (this.validateToolParams(params)) {
|
||||||
return false; // skip confirmation, execute call will fail immediately
|
return false; // skip confirmation, execute call will fail immediately
|
||||||
|
|
|
@ -57,6 +57,7 @@ export interface Tool<
|
||||||
*/
|
*/
|
||||||
shouldConfirmExecute(
|
shouldConfirmExecute(
|
||||||
params: TParams,
|
params: TParams,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<ToolCallConfirmationDetails | false>;
|
): Promise<ToolCallConfirmationDetails | false>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -137,6 +138,8 @@ export abstract class BaseTool<
|
||||||
shouldConfirmExecute(
|
shouldConfirmExecute(
|
||||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||||
params: TParams,
|
params: TParams,
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<ToolCallConfirmationDetails | false> {
|
): Promise<ToolCallConfirmationDetails | false> {
|
||||||
return Promise.resolve(false);
|
return Promise.resolve(false);
|
||||||
}
|
}
|
||||||
|
|
|
@ -110,18 +110,32 @@ describe('WriteFileTool', () => {
|
||||||
// Default mock implementations that return valid structures
|
// Default mock implementations that return valid structures
|
||||||
mockEnsureCorrectEdit.mockImplementation(
|
mockEnsureCorrectEdit.mockImplementation(
|
||||||
async (
|
async (
|
||||||
currentContent: string,
|
_currentContent: string,
|
||||||
params: EditToolParams,
|
params: EditToolParams,
|
||||||
_client: GeminiClient,
|
_client: GeminiClient,
|
||||||
): Promise<CorrectedEditResult> =>
|
signal?: AbortSignal, // Make AbortSignal optional to match usage
|
||||||
Promise.resolve({
|
): Promise<CorrectedEditResult> => {
|
||||||
|
if (signal?.aborted) {
|
||||||
|
return Promise.reject(new Error('Aborted'));
|
||||||
|
}
|
||||||
|
return Promise.resolve({
|
||||||
params: { ...params, new_string: params.new_string ?? '' },
|
params: { ...params, new_string: params.new_string ?? '' },
|
||||||
occurrences: 1,
|
occurrences: 1,
|
||||||
}),
|
});
|
||||||
|
},
|
||||||
);
|
);
|
||||||
mockEnsureCorrectFileContent.mockImplementation(
|
mockEnsureCorrectFileContent.mockImplementation(
|
||||||
async (content: string, _client: GeminiClient): Promise<string> =>
|
async (
|
||||||
Promise.resolve(content ?? ''),
|
content: string,
|
||||||
|
_client: GeminiClient,
|
||||||
|
signal?: AbortSignal,
|
||||||
|
): Promise<string> => {
|
||||||
|
// Make AbortSignal optional
|
||||||
|
if (signal?.aborted) {
|
||||||
|
return Promise.reject(new Error('Aborted'));
|
||||||
|
}
|
||||||
|
return Promise.resolve(content ?? '');
|
||||||
|
},
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -181,6 +195,7 @@ describe('WriteFileTool', () => {
|
||||||
const filePath = path.join(rootDir, 'new_corrected_file.txt');
|
const filePath = path.join(rootDir, 'new_corrected_file.txt');
|
||||||
const proposedContent = 'Proposed new content.';
|
const proposedContent = 'Proposed new content.';
|
||||||
const correctedContent = 'Corrected new content.';
|
const correctedContent = 'Corrected new content.';
|
||||||
|
const abortSignal = new AbortController().signal;
|
||||||
// Ensure the mock is set for this specific test case if needed, or rely on beforeEach
|
// Ensure the mock is set for this specific test case if needed, or rely on beforeEach
|
||||||
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
|
mockEnsureCorrectFileContent.mockResolvedValue(correctedContent);
|
||||||
|
|
||||||
|
@ -188,11 +203,13 @@ describe('WriteFileTool', () => {
|
||||||
const result = await tool._getCorrectedFileContent(
|
const result = await tool._getCorrectedFileContent(
|
||||||
filePath,
|
filePath,
|
||||||
proposedContent,
|
proposedContent,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||||
proposedContent,
|
proposedContent,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
|
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
|
||||||
expect(result.correctedContent).toBe(correctedContent);
|
expect(result.correctedContent).toBe(correctedContent);
|
||||||
|
@ -206,6 +223,7 @@ describe('WriteFileTool', () => {
|
||||||
const originalContent = 'Original existing content.';
|
const originalContent = 'Original existing content.';
|
||||||
const proposedContent = 'Proposed replacement content.';
|
const proposedContent = 'Proposed replacement content.';
|
||||||
const correctedProposedContent = 'Corrected replacement content.';
|
const correctedProposedContent = 'Corrected replacement content.';
|
||||||
|
const abortSignal = new AbortController().signal;
|
||||||
fs.writeFileSync(filePath, originalContent, 'utf8');
|
fs.writeFileSync(filePath, originalContent, 'utf8');
|
||||||
|
|
||||||
// Ensure this mock is active and returns the correct structure
|
// Ensure this mock is active and returns the correct structure
|
||||||
|
@ -222,6 +240,7 @@ describe('WriteFileTool', () => {
|
||||||
const result = await tool._getCorrectedFileContent(
|
const result = await tool._getCorrectedFileContent(
|
||||||
filePath,
|
filePath,
|
||||||
proposedContent,
|
proposedContent,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||||
|
@ -232,6 +251,7 @@ describe('WriteFileTool', () => {
|
||||||
file_path: filePath,
|
file_path: filePath,
|
||||||
},
|
},
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
|
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
|
||||||
expect(result.correctedContent).toBe(correctedProposedContent);
|
expect(result.correctedContent).toBe(correctedProposedContent);
|
||||||
|
@ -243,6 +263,7 @@ describe('WriteFileTool', () => {
|
||||||
it('should return error if reading an existing file fails (e.g. permissions)', async () => {
|
it('should return error if reading an existing file fails (e.g. permissions)', async () => {
|
||||||
const filePath = path.join(rootDir, 'unreadable_file.txt');
|
const filePath = path.join(rootDir, 'unreadable_file.txt');
|
||||||
const proposedContent = 'some content';
|
const proposedContent = 'some content';
|
||||||
|
const abortSignal = new AbortController().signal;
|
||||||
fs.writeFileSync(filePath, 'content', { mode: 0o000 });
|
fs.writeFileSync(filePath, 'content', { mode: 0o000 });
|
||||||
|
|
||||||
const readError = new Error('Permission denied');
|
const readError = new Error('Permission denied');
|
||||||
|
@ -255,6 +276,7 @@ describe('WriteFileTool', () => {
|
||||||
const result = await tool._getCorrectedFileContent(
|
const result = await tool._getCorrectedFileContent(
|
||||||
filePath,
|
filePath,
|
||||||
proposedContent,
|
proposedContent,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(fs.readFileSync).toHaveBeenCalledWith(filePath, 'utf8');
|
expect(fs.readFileSync).toHaveBeenCalledWith(filePath, 'utf8');
|
||||||
|
@ -274,16 +296,17 @@ describe('WriteFileTool', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('shouldConfirmExecute', () => {
|
describe('shouldConfirmExecute', () => {
|
||||||
|
const abortSignal = new AbortController().signal;
|
||||||
it('should return false if params are invalid (relative path)', async () => {
|
it('should return false if params are invalid (relative path)', async () => {
|
||||||
const params = { file_path: 'relative.txt', content: 'test' };
|
const params = { file_path: 'relative.txt', content: 'test' };
|
||||||
const confirmation = await tool.shouldConfirmExecute(params);
|
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||||
expect(confirmation).toBe(false);
|
expect(confirmation).toBe(false);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return false if params are invalid (outside root)', async () => {
|
it('should return false if params are invalid (outside root)', async () => {
|
||||||
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
||||||
const params = { file_path: outsidePath, content: 'test' };
|
const params = { file_path: outsidePath, content: 'test' };
|
||||||
const confirmation = await tool.shouldConfirmExecute(params);
|
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||||
expect(confirmation).toBe(false);
|
expect(confirmation).toBe(false);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -298,7 +321,7 @@ describe('WriteFileTool', () => {
|
||||||
throw readError;
|
throw readError;
|
||||||
});
|
});
|
||||||
|
|
||||||
const confirmation = await tool.shouldConfirmExecute(params);
|
const confirmation = await tool.shouldConfirmExecute(params, abortSignal);
|
||||||
expect(confirmation).toBe(false);
|
expect(confirmation).toBe(false);
|
||||||
|
|
||||||
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
|
vi.spyOn(fs, 'readFileSync').mockImplementation(originalReadFileSync);
|
||||||
|
@ -314,11 +337,13 @@ describe('WriteFileTool', () => {
|
||||||
const params = { file_path: filePath, content: proposedContent };
|
const params = { file_path: filePath, content: proposedContent };
|
||||||
const confirmation = (await tool.shouldConfirmExecute(
|
const confirmation = (await tool.shouldConfirmExecute(
|
||||||
params,
|
params,
|
||||||
|
abortSignal,
|
||||||
)) as ToolEditConfirmationDetails;
|
)) as ToolEditConfirmationDetails;
|
||||||
|
|
||||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||||
proposedContent,
|
proposedContent,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(confirmation).toEqual(
|
expect(confirmation).toEqual(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
|
@ -343,7 +368,6 @@ describe('WriteFileTool', () => {
|
||||||
'Corrected replacement for confirmation.';
|
'Corrected replacement for confirmation.';
|
||||||
fs.writeFileSync(filePath, originalContent, 'utf8');
|
fs.writeFileSync(filePath, originalContent, 'utf8');
|
||||||
|
|
||||||
// Ensure this mock is active and returns the correct structure
|
|
||||||
mockEnsureCorrectEdit.mockResolvedValue({
|
mockEnsureCorrectEdit.mockResolvedValue({
|
||||||
params: {
|
params: {
|
||||||
file_path: filePath,
|
file_path: filePath,
|
||||||
|
@ -356,6 +380,7 @@ describe('WriteFileTool', () => {
|
||||||
const params = { file_path: filePath, content: proposedContent };
|
const params = { file_path: filePath, content: proposedContent };
|
||||||
const confirmation = (await tool.shouldConfirmExecute(
|
const confirmation = (await tool.shouldConfirmExecute(
|
||||||
params,
|
params,
|
||||||
|
abortSignal,
|
||||||
)) as ToolEditConfirmationDetails;
|
)) as ToolEditConfirmationDetails;
|
||||||
|
|
||||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||||
|
@ -366,6 +391,7 @@ describe('WriteFileTool', () => {
|
||||||
file_path: filePath,
|
file_path: filePath,
|
||||||
},
|
},
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(confirmation).toEqual(
|
expect(confirmation).toEqual(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
|
@ -381,9 +407,10 @@ describe('WriteFileTool', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('execute', () => {
|
describe('execute', () => {
|
||||||
|
const abortSignal = new AbortController().signal;
|
||||||
it('should return error if params are invalid (relative path)', async () => {
|
it('should return error if params are invalid (relative path)', async () => {
|
||||||
const params = { file_path: 'relative.txt', content: 'test' };
|
const params = { file_path: 'relative.txt', content: 'test' };
|
||||||
const result = await tool.execute(params, new AbortController().signal);
|
const result = await tool.execute(params, abortSignal);
|
||||||
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
|
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
|
||||||
expect(result.returnDisplay).toMatch(/Error: File path must be absolute/);
|
expect(result.returnDisplay).toMatch(/Error: File path must be absolute/);
|
||||||
});
|
});
|
||||||
|
@ -391,7 +418,7 @@ describe('WriteFileTool', () => {
|
||||||
it('should return error if params are invalid (path outside root)', async () => {
|
it('should return error if params are invalid (path outside root)', async () => {
|
||||||
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
const outsidePath = path.resolve(tempDir, 'outside-root.txt');
|
||||||
const params = { file_path: outsidePath, content: 'test' };
|
const params = { file_path: outsidePath, content: 'test' };
|
||||||
const result = await tool.execute(params, new AbortController().signal);
|
const result = await tool.execute(params, abortSignal);
|
||||||
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
|
expect(result.llmContent).toMatch(/Error: Invalid parameters provided/);
|
||||||
expect(result.returnDisplay).toMatch(
|
expect(result.returnDisplay).toMatch(
|
||||||
/Error: File path must be within the root directory/,
|
/Error: File path must be within the root directory/,
|
||||||
|
@ -409,7 +436,7 @@ describe('WriteFileTool', () => {
|
||||||
throw readError;
|
throw readError;
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await tool.execute(params, new AbortController().signal);
|
const result = await tool.execute(params, abortSignal);
|
||||||
expect(result.llmContent).toMatch(/Error checking existing file/);
|
expect(result.llmContent).toMatch(/Error checking existing file/);
|
||||||
expect(result.returnDisplay).toMatch(
|
expect(result.returnDisplay).toMatch(
|
||||||
/Error checking existing file: Simulated read error for execute/,
|
/Error checking existing file: Simulated read error for execute/,
|
||||||
|
@ -427,16 +454,20 @@ describe('WriteFileTool', () => {
|
||||||
|
|
||||||
const params = { file_path: filePath, content: proposedContent };
|
const params = { file_path: filePath, content: proposedContent };
|
||||||
|
|
||||||
const confirmDetails = await tool.shouldConfirmExecute(params);
|
const confirmDetails = await tool.shouldConfirmExecute(
|
||||||
|
params,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
|
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
|
||||||
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await tool.execute(params, new AbortController().signal);
|
const result = await tool.execute(params, abortSignal);
|
||||||
|
|
||||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||||
proposedContent,
|
proposedContent,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(result.llmContent).toMatch(
|
expect(result.llmContent).toMatch(
|
||||||
/Successfully created and wrote to new file/,
|
/Successfully created and wrote to new file/,
|
||||||
|
@ -477,12 +508,15 @@ describe('WriteFileTool', () => {
|
||||||
|
|
||||||
const params = { file_path: filePath, content: proposedContent };
|
const params = { file_path: filePath, content: proposedContent };
|
||||||
|
|
||||||
const confirmDetails = await tool.shouldConfirmExecute(params);
|
const confirmDetails = await tool.shouldConfirmExecute(
|
||||||
|
params,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
|
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
|
||||||
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await tool.execute(params, new AbortController().signal);
|
const result = await tool.execute(params, abortSignal);
|
||||||
|
|
||||||
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
expect(mockEnsureCorrectEdit).toHaveBeenCalledWith(
|
||||||
initialContent,
|
initialContent,
|
||||||
|
@ -492,6 +526,7 @@ describe('WriteFileTool', () => {
|
||||||
file_path: filePath,
|
file_path: filePath,
|
||||||
},
|
},
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(result.llmContent).toMatch(/Successfully overwrote file/);
|
expect(result.llmContent).toMatch(/Successfully overwrote file/);
|
||||||
expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedProposedContent);
|
expect(fs.readFileSync(filePath, 'utf8')).toBe(correctedProposedContent);
|
||||||
|
@ -513,12 +548,15 @@ describe('WriteFileTool', () => {
|
||||||
|
|
||||||
const params = { file_path: filePath, content };
|
const params = { file_path: filePath, content };
|
||||||
// Simulate confirmation if your logic requires it before execute, or remove if not needed for this path
|
// Simulate confirmation if your logic requires it before execute, or remove if not needed for this path
|
||||||
const confirmDetails = await tool.shouldConfirmExecute(params);
|
const confirmDetails = await tool.shouldConfirmExecute(
|
||||||
|
params,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
|
if (typeof confirmDetails === 'object' && confirmDetails.onConfirm) {
|
||||||
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
await confirmDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||||
}
|
}
|
||||||
|
|
||||||
await tool.execute(params, new AbortController().signal);
|
await tool.execute(params, abortSignal);
|
||||||
|
|
||||||
expect(fs.existsSync(dirPath)).toBe(true);
|
expect(fs.existsSync(dirPath)).toBe(true);
|
||||||
expect(fs.statSync(dirPath).isDirectory()).toBe(true);
|
expect(fs.statSync(dirPath).isDirectory()).toBe(true);
|
||||||
|
|
|
@ -141,6 +141,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
*/
|
*/
|
||||||
async shouldConfirmExecute(
|
async shouldConfirmExecute(
|
||||||
params: WriteFileToolParams,
|
params: WriteFileToolParams,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<ToolCallConfirmationDetails | false> {
|
): Promise<ToolCallConfirmationDetails | false> {
|
||||||
if (this.config.getAlwaysSkipModificationConfirmation()) {
|
if (this.config.getAlwaysSkipModificationConfirmation()) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -154,6 +155,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
const correctedContentResult = await this._getCorrectedFileContent(
|
const correctedContentResult = await this._getCorrectedFileContent(
|
||||||
params.file_path,
|
params.file_path,
|
||||||
params.content,
|
params.content,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
|
|
||||||
if (correctedContentResult.error) {
|
if (correctedContentResult.error) {
|
||||||
|
@ -193,7 +195,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
|
|
||||||
async execute(
|
async execute(
|
||||||
params: WriteFileToolParams,
|
params: WriteFileToolParams,
|
||||||
_signal: AbortSignal,
|
abortSignal: AbortSignal,
|
||||||
): Promise<ToolResult> {
|
): Promise<ToolResult> {
|
||||||
const validationError = this.validateToolParams(params);
|
const validationError = this.validateToolParams(params);
|
||||||
if (validationError) {
|
if (validationError) {
|
||||||
|
@ -206,6 +208,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
const correctedContentResult = await this._getCorrectedFileContent(
|
const correctedContentResult = await this._getCorrectedFileContent(
|
||||||
params.file_path,
|
params.file_path,
|
||||||
params.content,
|
params.content,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
|
|
||||||
if (correctedContentResult.error) {
|
if (correctedContentResult.error) {
|
||||||
|
@ -277,6 +280,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
private async _getCorrectedFileContent(
|
private async _getCorrectedFileContent(
|
||||||
filePath: string,
|
filePath: string,
|
||||||
proposedContent: string,
|
proposedContent: string,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<GetCorrectedFileContentResult> {
|
): Promise<GetCorrectedFileContentResult> {
|
||||||
let originalContent = '';
|
let originalContent = '';
|
||||||
let fileExists = false;
|
let fileExists = false;
|
||||||
|
@ -316,6 +320,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
file_path: filePath,
|
file_path: filePath,
|
||||||
},
|
},
|
||||||
this.client,
|
this.client,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
correctedContent = correctedParams.new_string;
|
correctedContent = correctedParams.new_string;
|
||||||
} else {
|
} else {
|
||||||
|
@ -323,6 +328,7 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
correctedContent = await ensureCorrectFileContent(
|
correctedContent = await ensureCorrectFileContent(
|
||||||
proposedContent,
|
proposedContent,
|
||||||
this.client,
|
this.client,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return { originalContent, correctedContent, fileExists };
|
return { originalContent, correctedContent, fileExists };
|
||||||
|
|
|
@ -132,6 +132,7 @@ describe('editCorrector', () => {
|
||||||
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
||||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||||
let mockConfigInstance: Config;
|
let mockConfigInstance: Config;
|
||||||
|
const abortSignal = new AbortController().signal;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
|
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
|
||||||
|
@ -187,12 +188,18 @@ describe('editCorrector', () => {
|
||||||
|
|
||||||
callCount = 0;
|
callCount = 0;
|
||||||
mockResponses.length = 0;
|
mockResponses.length = 0;
|
||||||
mockGenerateJson = vi.fn().mockImplementation(() => {
|
mockGenerateJson = vi
|
||||||
const response = mockResponses[callCount];
|
.fn()
|
||||||
callCount++;
|
.mockImplementation((_contents, _schema, signal) => {
|
||||||
if (response === undefined) return Promise.resolve({});
|
// Check if the signal is aborted. If so, throw an error or return a specific response.
|
||||||
return Promise.resolve(response);
|
if (signal && signal.aborted) {
|
||||||
});
|
return Promise.reject(new Error('Aborted')); // Or some other specific error/response
|
||||||
|
}
|
||||||
|
const response = mockResponses[callCount];
|
||||||
|
callCount++;
|
||||||
|
if (response === undefined) return Promise.resolve({});
|
||||||
|
return Promise.resolve(response);
|
||||||
|
});
|
||||||
mockStartChat = vi.fn();
|
mockStartChat = vi.fn();
|
||||||
mockSendMessageStream = vi.fn();
|
mockSendMessageStream = vi.fn();
|
||||||
|
|
||||||
|
@ -217,6 +224,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||||
expect(result.params.new_string).toBe('replace with "this"');
|
expect(result.params.new_string).toBe('replace with "this"');
|
||||||
|
@ -234,6 +242,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||||
expect(result.params.new_string).toBe('replace with this');
|
expect(result.params.new_string).toBe('replace with this');
|
||||||
|
@ -254,6 +263,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||||
expect(result.params.new_string).toBe('replace with "this"');
|
expect(result.params.new_string).toBe('replace with "this"');
|
||||||
|
@ -271,6 +281,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||||
expect(result.params.new_string).toBe('replace with this');
|
expect(result.params.new_string).toBe('replace with this');
|
||||||
|
@ -292,6 +303,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||||
expect(result.params.new_string).toBe('replace with "this"');
|
expect(result.params.new_string).toBe('replace with "this"');
|
||||||
|
@ -309,6 +321,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||||
expect(result.params.new_string).toBe('replace with this');
|
expect(result.params.new_string).toBe('replace with this');
|
||||||
|
@ -329,6 +342,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||||
expect(result.params.new_string).toBe('replace with foobar');
|
expect(result.params.new_string).toBe('replace with foobar');
|
||||||
|
@ -351,6 +365,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||||
expect(result.params.new_string).toBe(llmNewString);
|
expect(result.params.new_string).toBe(llmNewString);
|
||||||
|
@ -372,6 +387,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||||
expect(result.params.new_string).toBe(llmNewString);
|
expect(result.params.new_string).toBe(llmNewString);
|
||||||
|
@ -391,6 +407,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||||
expect(result.params.new_string).toBe('replace with "this"');
|
expect(result.params.new_string).toBe('replace with "this"');
|
||||||
|
@ -412,6 +429,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||||
expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM);
|
expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM);
|
||||||
|
@ -432,6 +450,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||||
expect(result.params).toEqual(originalParams);
|
expect(result.params).toEqual(originalParams);
|
||||||
|
@ -449,6 +468,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||||
expect(result.params).toEqual(originalParams);
|
expect(result.params).toEqual(originalParams);
|
||||||
|
@ -471,6 +491,7 @@ describe('editCorrector', () => {
|
||||||
currentContent,
|
currentContent,
|
||||||
originalParams,
|
originalParams,
|
||||||
mockGeminiClientInstance,
|
mockGeminiClientInstance,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||||
expect(result.params.old_string).toBe(currentContent);
|
expect(result.params.old_string).toBe(currentContent);
|
||||||
|
|
|
@ -63,6 +63,7 @@ export async function ensureCorrectEdit(
|
||||||
currentContent: string,
|
currentContent: string,
|
||||||
originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
|
originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
|
||||||
client: GeminiClient,
|
client: GeminiClient,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<CorrectedEditResult> {
|
): Promise<CorrectedEditResult> {
|
||||||
const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
|
const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
|
||||||
const cachedResult = editCorrectionCache.get(cacheKey);
|
const cachedResult = editCorrectionCache.get(cacheKey);
|
||||||
|
@ -84,6 +85,7 @@ export async function ensureCorrectEdit(
|
||||||
client,
|
client,
|
||||||
finalOldString,
|
finalOldString,
|
||||||
originalParams.new_string,
|
originalParams.new_string,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else if (occurrences > 1) {
|
} else if (occurrences > 1) {
|
||||||
|
@ -108,6 +110,7 @@ export async function ensureCorrectEdit(
|
||||||
originalParams.old_string, // original old
|
originalParams.old_string, // original old
|
||||||
unescapedOldStringAttempt, // corrected old
|
unescapedOldStringAttempt, // corrected old
|
||||||
originalParams.new_string, // original new (which is potentially escaped)
|
originalParams.new_string, // original new (which is potentially escaped)
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else if (occurrences === 0) {
|
} else if (occurrences === 0) {
|
||||||
|
@ -115,6 +118,7 @@ export async function ensureCorrectEdit(
|
||||||
client,
|
client,
|
||||||
currentContent,
|
currentContent,
|
||||||
unescapedOldStringAttempt,
|
unescapedOldStringAttempt,
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
const llmOldOccurrences = countOccurrences(
|
const llmOldOccurrences = countOccurrences(
|
||||||
currentContent,
|
currentContent,
|
||||||
|
@ -134,6 +138,7 @@ export async function ensureCorrectEdit(
|
||||||
originalParams.old_string, // original old
|
originalParams.old_string, // original old
|
||||||
llmCorrectedOldString, // corrected old
|
llmCorrectedOldString, // corrected old
|
||||||
baseNewStringForLLMCorrection, // base new for correction
|
baseNewStringForLLMCorrection, // base new for correction
|
||||||
|
abortSignal,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -180,6 +185,7 @@ export async function ensureCorrectEdit(
|
||||||
export async function ensureCorrectFileContent(
|
export async function ensureCorrectFileContent(
|
||||||
content: string,
|
content: string,
|
||||||
client: GeminiClient,
|
client: GeminiClient,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<string> {
|
): Promise<string> {
|
||||||
const cachedResult = fileContentCorrectionCache.get(content);
|
const cachedResult = fileContentCorrectionCache.get(content);
|
||||||
if (cachedResult) {
|
if (cachedResult) {
|
||||||
|
@ -193,7 +199,11 @@ export async function ensureCorrectFileContent(
|
||||||
return content;
|
return content;
|
||||||
}
|
}
|
||||||
|
|
||||||
const correctedContent = await correctStringEscaping(content, client);
|
const correctedContent = await correctStringEscaping(
|
||||||
|
content,
|
||||||
|
client,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
fileContentCorrectionCache.set(content, correctedContent);
|
fileContentCorrectionCache.set(content, correctedContent);
|
||||||
return correctedContent;
|
return correctedContent;
|
||||||
}
|
}
|
||||||
|
@ -215,6 +225,7 @@ export async function correctOldStringMismatch(
|
||||||
geminiClient: GeminiClient,
|
geminiClient: GeminiClient,
|
||||||
fileContent: string,
|
fileContent: string,
|
||||||
problematicSnippet: string,
|
problematicSnippet: string,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<string> {
|
): Promise<string> {
|
||||||
const prompt = `
|
const prompt = `
|
||||||
Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped.
|
Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped.
|
||||||
|
@ -243,6 +254,7 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
|
||||||
const result = await geminiClient.generateJson(
|
const result = await geminiClient.generateJson(
|
||||||
contents,
|
contents,
|
||||||
OLD_STRING_CORRECTION_SCHEMA,
|
OLD_STRING_CORRECTION_SCHEMA,
|
||||||
|
abortSignal,
|
||||||
EditModel,
|
EditModel,
|
||||||
EditConfig,
|
EditConfig,
|
||||||
);
|
);
|
||||||
|
@ -257,10 +269,15 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
|
||||||
return problematicSnippet;
|
return problematicSnippet;
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (abortSignal.aborted) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
console.error(
|
console.error(
|
||||||
'Error during LLM call for old string snippet correction:',
|
'Error during LLM call for old string snippet correction:',
|
||||||
error,
|
error,
|
||||||
);
|
);
|
||||||
|
|
||||||
return problematicSnippet;
|
return problematicSnippet;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -286,6 +303,7 @@ export async function correctNewString(
|
||||||
originalOldString: string,
|
originalOldString: string,
|
||||||
correctedOldString: string,
|
correctedOldString: string,
|
||||||
originalNewString: string,
|
originalNewString: string,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<string> {
|
): Promise<string> {
|
||||||
if (originalOldString === correctedOldString) {
|
if (originalOldString === correctedOldString) {
|
||||||
return originalNewString;
|
return originalNewString;
|
||||||
|
@ -324,6 +342,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||||
const result = await geminiClient.generateJson(
|
const result = await geminiClient.generateJson(
|
||||||
contents,
|
contents,
|
||||||
NEW_STRING_CORRECTION_SCHEMA,
|
NEW_STRING_CORRECTION_SCHEMA,
|
||||||
|
abortSignal,
|
||||||
EditModel,
|
EditModel,
|
||||||
EditConfig,
|
EditConfig,
|
||||||
);
|
);
|
||||||
|
@ -338,6 +357,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||||
return originalNewString;
|
return originalNewString;
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (abortSignal.aborted) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
console.error('Error during LLM call for new_string correction:', error);
|
console.error('Error during LLM call for new_string correction:', error);
|
||||||
return originalNewString;
|
return originalNewString;
|
||||||
}
|
}
|
||||||
|
@ -359,6 +382,7 @@ export async function correctNewStringEscaping(
|
||||||
geminiClient: GeminiClient,
|
geminiClient: GeminiClient,
|
||||||
oldString: string,
|
oldString: string,
|
||||||
potentiallyProblematicNewString: string,
|
potentiallyProblematicNewString: string,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<string> {
|
): Promise<string> {
|
||||||
const prompt = `
|
const prompt = `
|
||||||
Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
|
Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
|
||||||
|
@ -387,6 +411,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||||
const result = await geminiClient.generateJson(
|
const result = await geminiClient.generateJson(
|
||||||
contents,
|
contents,
|
||||||
CORRECT_NEW_STRING_ESCAPING_SCHEMA,
|
CORRECT_NEW_STRING_ESCAPING_SCHEMA,
|
||||||
|
abortSignal,
|
||||||
EditModel,
|
EditModel,
|
||||||
EditConfig,
|
EditConfig,
|
||||||
);
|
);
|
||||||
|
@ -401,6 +426,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||||
return potentiallyProblematicNewString;
|
return potentiallyProblematicNewString;
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (abortSignal.aborted) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
console.error(
|
console.error(
|
||||||
'Error during LLM call for new_string escaping correction:',
|
'Error during LLM call for new_string escaping correction:',
|
||||||
error,
|
error,
|
||||||
|
@ -424,6 +453,7 @@ const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = {
|
||||||
export async function correctStringEscaping(
|
export async function correctStringEscaping(
|
||||||
potentiallyProblematicString: string,
|
potentiallyProblematicString: string,
|
||||||
client: GeminiClient,
|
client: GeminiClient,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<string> {
|
): Promise<string> {
|
||||||
const prompt = `
|
const prompt = `
|
||||||
Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
|
Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
|
||||||
|
@ -447,6 +477,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||||
const result = await client.generateJson(
|
const result = await client.generateJson(
|
||||||
contents,
|
contents,
|
||||||
CORRECT_STRING_ESCAPING_SCHEMA,
|
CORRECT_STRING_ESCAPING_SCHEMA,
|
||||||
|
abortSignal,
|
||||||
EditModel,
|
EditModel,
|
||||||
EditConfig,
|
EditConfig,
|
||||||
);
|
);
|
||||||
|
@ -461,6 +492,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||||
return potentiallyProblematicString;
|
return potentiallyProblematicString;
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (abortSignal.aborted) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
console.error(
|
console.error(
|
||||||
'Error during LLM call for string escaping correction:',
|
'Error during LLM call for string escaping correction:',
|
||||||
error,
|
error,
|
||||||
|
|
|
@ -44,6 +44,7 @@ describe('checkNextSpeaker', () => {
|
||||||
let chatInstance: GeminiChat;
|
let chatInstance: GeminiChat;
|
||||||
let mockGeminiClient: GeminiClient;
|
let mockGeminiClient: GeminiClient;
|
||||||
let MockConfig: Mock;
|
let MockConfig: Mock;
|
||||||
|
const abortSignal = new AbortController().signal;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
MockConfig = vi.mocked(Config);
|
MockConfig = vi.mocked(Config);
|
||||||
|
@ -71,7 +72,7 @@ describe('checkNextSpeaker', () => {
|
||||||
mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
|
mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
|
||||||
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
|
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
|
||||||
'gemini-pro', // model name
|
'gemini-pro', // model name
|
||||||
{}, // config
|
{},
|
||||||
[], // initial history
|
[], // initial history
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -85,7 +86,11 @@ describe('checkNextSpeaker', () => {
|
||||||
|
|
||||||
it('should return null if history is empty', async () => {
|
it('should return null if history is empty', async () => {
|
||||||
(chatInstance.getHistory as Mock).mockReturnValue([]);
|
(chatInstance.getHistory as Mock).mockReturnValue([]);
|
||||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
const result = await checkNextSpeaker(
|
||||||
|
chatInstance,
|
||||||
|
mockGeminiClient,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
expect(result).toBeNull();
|
expect(result).toBeNull();
|
||||||
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
|
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
@ -94,7 +99,11 @@ describe('checkNextSpeaker', () => {
|
||||||
(chatInstance.getHistory as Mock).mockReturnValue([
|
(chatInstance.getHistory as Mock).mockReturnValue([
|
||||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||||
] as Content[]);
|
] as Content[]);
|
||||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
const result = await checkNextSpeaker(
|
||||||
|
chatInstance,
|
||||||
|
mockGeminiClient,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
expect(result).toBeNull();
|
expect(result).toBeNull();
|
||||||
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
|
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
@ -109,7 +118,11 @@ describe('checkNextSpeaker', () => {
|
||||||
};
|
};
|
||||||
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
|
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
|
||||||
|
|
||||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
const result = await checkNextSpeaker(
|
||||||
|
chatInstance,
|
||||||
|
mockGeminiClient,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
expect(result).toEqual(mockApiResponse);
|
expect(result).toEqual(mockApiResponse);
|
||||||
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
|
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
|
||||||
});
|
});
|
||||||
|
@ -124,7 +137,11 @@ describe('checkNextSpeaker', () => {
|
||||||
};
|
};
|
||||||
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
|
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
|
||||||
|
|
||||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
const result = await checkNextSpeaker(
|
||||||
|
chatInstance,
|
||||||
|
mockGeminiClient,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
expect(result).toEqual(mockApiResponse);
|
expect(result).toEqual(mockApiResponse);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -138,7 +155,11 @@ describe('checkNextSpeaker', () => {
|
||||||
};
|
};
|
||||||
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
|
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
|
||||||
|
|
||||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
const result = await checkNextSpeaker(
|
||||||
|
chatInstance,
|
||||||
|
mockGeminiClient,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
expect(result).toEqual(mockApiResponse);
|
expect(result).toEqual(mockApiResponse);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -153,7 +174,11 @@ describe('checkNextSpeaker', () => {
|
||||||
new Error('API Error'),
|
new Error('API Error'),
|
||||||
);
|
);
|
||||||
|
|
||||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
const result = await checkNextSpeaker(
|
||||||
|
chatInstance,
|
||||||
|
mockGeminiClient,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
expect(result).toBeNull();
|
expect(result).toBeNull();
|
||||||
consoleWarnSpy.mockRestore();
|
consoleWarnSpy.mockRestore();
|
||||||
});
|
});
|
||||||
|
@ -166,7 +191,11 @@ describe('checkNextSpeaker', () => {
|
||||||
reasoning: 'This is incomplete.',
|
reasoning: 'This is incomplete.',
|
||||||
} as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response
|
} as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response
|
||||||
|
|
||||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
const result = await checkNextSpeaker(
|
||||||
|
chatInstance,
|
||||||
|
mockGeminiClient,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
expect(result).toBeNull();
|
expect(result).toBeNull();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -179,7 +208,11 @@ describe('checkNextSpeaker', () => {
|
||||||
next_speaker: 123, // Invalid type
|
next_speaker: 123, // Invalid type
|
||||||
} as unknown as NextSpeakerResponse);
|
} as unknown as NextSpeakerResponse);
|
||||||
|
|
||||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
const result = await checkNextSpeaker(
|
||||||
|
chatInstance,
|
||||||
|
mockGeminiClient,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
expect(result).toBeNull();
|
expect(result).toBeNull();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -192,7 +225,11 @@ describe('checkNextSpeaker', () => {
|
||||||
next_speaker: 'neither', // Invalid enum value
|
next_speaker: 'neither', // Invalid enum value
|
||||||
} as unknown as NextSpeakerResponse);
|
} as unknown as NextSpeakerResponse);
|
||||||
|
|
||||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
const result = await checkNextSpeaker(
|
||||||
|
chatInstance,
|
||||||
|
mockGeminiClient,
|
||||||
|
abortSignal,
|
||||||
|
);
|
||||||
expect(result).toBeNull();
|
expect(result).toBeNull();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -61,6 +61,7 @@ export interface NextSpeakerResponse {
|
||||||
export async function checkNextSpeaker(
|
export async function checkNextSpeaker(
|
||||||
chat: GeminiChat,
|
chat: GeminiChat,
|
||||||
geminiClient: GeminiClient,
|
geminiClient: GeminiClient,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
): Promise<NextSpeakerResponse | null> {
|
): Promise<NextSpeakerResponse | null> {
|
||||||
// We need to capture the curated history because there are many moments when the model will return invalid turns
|
// We need to capture the curated history because there are many moments when the model will return invalid turns
|
||||||
// that when passed back up to the endpoint will break subsequent calls. An example of this is when the model decides
|
// that when passed back up to the endpoint will break subsequent calls. An example of this is when the model decides
|
||||||
|
@ -129,6 +130,7 @@ export async function checkNextSpeaker(
|
||||||
const parsedResponse = (await geminiClient.generateJson(
|
const parsedResponse = (await geminiClient.generateJson(
|
||||||
contents,
|
contents,
|
||||||
RESPONSE_SCHEMA,
|
RESPONSE_SCHEMA,
|
||||||
|
abortSignal,
|
||||||
)) as unknown as NextSpeakerResponse;
|
)) as unknown as NextSpeakerResponse;
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|
Loading…
Reference in New Issue