Fix(server): Ensure debug responses are not recorded after cancellation (#491)

This commit is contained in:
Allen Hutchison 2025-05-22 16:34:32 -07:00 committed by GitHub
parent 6d3af7b97f
commit 1d0856dcc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 271 additions and 1 deletions

View File

@ -0,0 +1,269 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
Turn,
GeminiEventType,
ServerGeminiToolCallRequestEvent,
ServerGeminiErrorEvent,
} from './turn.js';
import { Chat, GenerateContentResponse, Part, Content } from '@google/genai';
import { reportError } from '../utils/errorReporting.js';
const mockSendMessageStream = vi.fn();
const mockGetHistory = vi.fn();
vi.mock('@google/genai', async (importOriginal) => {
const actual = await importOriginal<typeof import('@google/genai')>();
const MockChat = vi.fn().mockImplementation(() => ({
sendMessageStream: mockSendMessageStream,
getHistory: mockGetHistory,
}));
return {
...actual,
Chat: MockChat,
};
});
vi.mock('../utils/errorReporting', () => ({
reportError: vi.fn(),
}));
vi.mock('../utils/generateContentResponseUtilities', () => ({
getResponseText: (resp: GenerateContentResponse) =>
resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
undefined,
}));
describe('Turn', () => {
let turn: Turn;
// Define a type for the mocked Chat instance for clarity
type MockedChatInstance = {
sendMessageStream: typeof mockSendMessageStream;
getHistory: typeof mockGetHistory;
};
let mockChatInstance: MockedChatInstance;
beforeEach(() => {
vi.resetAllMocks();
mockChatInstance = {
sendMessageStream: mockSendMessageStream,
getHistory: mockGetHistory,
};
turn = new Turn(mockChatInstance as unknown as Chat);
mockGetHistory.mockReturnValue([]);
mockSendMessageStream.mockResolvedValue((async function* () {})());
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('constructor', () => {
it('should initialize pendingToolCalls and debugResponses', () => {
expect(turn.pendingToolCalls).toEqual([]);
expect(turn.getDebugResponses()).toEqual([]);
});
});
describe('run', () => {
it('should yield content events for text parts', async () => {
const mockResponseStream = (async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
} as unknown as GenerateContentResponse;
yield {
candidates: [{ content: { parts: [{ text: ' world' }] } }],
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
const reqParts: Part[] = [{ text: 'Hi' }];
for await (const event of turn.run(reqParts)) {
events.push(event);
}
expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts });
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Hello' },
{ type: GeminiEventType.Content, value: ' world' },
]);
expect(turn.getDebugResponses().length).toBe(2);
});
it('should yield tool_call_request events for function calls', async () => {
const mockResponseStream = (async function* () {
yield {
functionCalls: [
{ id: 'fc1', name: 'tool1', args: { arg1: 'val1' } },
{ name: 'tool2', args: { arg2: 'val2' } }, // No ID
],
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
const reqParts: Part[] = [{ text: 'Use tools' }];
for await (const event of turn.run(reqParts)) {
events.push(event);
}
expect(events.length).toBe(2);
const event1 = events[0] as ServerGeminiToolCallRequestEvent;
expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
expect(event1.value).toEqual(
expect.objectContaining({
callId: 'fc1',
name: 'tool1',
args: { arg1: 'val1' },
}),
);
expect(turn.pendingToolCalls[0]).toEqual(event1.value);
const event2 = events[1] as ServerGeminiToolCallRequestEvent;
expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
expect(event2.value).toEqual(
expect.objectContaining({ name: 'tool2', args: { arg2: 'val2' } }),
);
expect(event2.value.callId).toEqual(
expect.stringMatching(/^tool2-\d{13}-\w{10,}$/),
);
expect(turn.pendingToolCalls[1]).toEqual(event2.value);
expect(turn.getDebugResponses().length).toBe(1);
});
it('should yield UserCancelled event if signal is aborted', async () => {
const abortController = new AbortController();
const mockResponseStream = (async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'First part' }] } }],
} as unknown as GenerateContentResponse;
abortController.abort();
yield {
candidates: [
{
content: {
parts: [{ text: 'Second part - should not be processed' }],
},
},
],
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
const reqParts: Part[] = [{ text: 'Test abort' }];
for await (const event of turn.run(reqParts, abortController.signal)) {
events.push(event);
}
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'First part' },
{ type: GeminiEventType.UserCancelled },
]);
expect(turn.getDebugResponses().length).toBe(1);
});
it('should yield Error event and report if sendMessageStream throws', async () => {
const error = new Error('API Error');
mockSendMessageStream.mockRejectedValue(error);
const reqParts: Part[] = [{ text: 'Trigger error' }];
const historyContent: Content[] = [
{ role: 'model', parts: [{ text: 'Previous history' }] },
];
mockGetHistory.mockReturnValue(historyContent);
const events = [];
for await (const event of turn.run(reqParts)) {
events.push(event);
}
expect(events.length).toBe(1);
const errorEvent = events[0] as ServerGeminiErrorEvent;
expect(errorEvent.type).toBe(GeminiEventType.Error);
expect(errorEvent.value).toEqual({ message: 'API Error' });
expect(turn.getDebugResponses().length).toBe(0);
expect(reportError).toHaveBeenCalledWith(
error,
'Error when talking to Gemini API',
[...historyContent, reqParts],
'Turn.run-sendMessageStream',
);
});
it('should handle function calls with undefined name or args', async () => {
const mockResponseStream = (async function* () {
yield {
functionCalls: [
{ id: 'fc1', name: undefined, args: { arg1: 'val1' } },
{ id: 'fc2', name: 'tool2', args: undefined },
{ id: 'fc3', name: undefined, args: undefined },
],
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
for await (const event of turn.run(reqParts)) {
events.push(event);
}
expect(events.length).toBe(3);
const event1 = events[0] as ServerGeminiToolCallRequestEvent;
expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
expect(event1.value).toEqual(
expect.objectContaining({
callId: 'fc1',
name: 'undefined_tool_name',
args: { arg1: 'val1' },
}),
);
expect(turn.pendingToolCalls[0]).toEqual(event1.value);
const event2 = events[1] as ServerGeminiToolCallRequestEvent;
expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
expect(event2.value).toEqual(
expect.objectContaining({ callId: 'fc2', name: 'tool2', args: {} }),
);
expect(turn.pendingToolCalls[1]).toEqual(event2.value);
const event3 = events[2] as ServerGeminiToolCallRequestEvent;
expect(event3.type).toBe(GeminiEventType.ToolCallRequest);
expect(event3.value).toEqual(
expect.objectContaining({
callId: 'fc3',
name: 'undefined_tool_name',
args: {},
}),
);
expect(turn.pendingToolCalls[2]).toEqual(event3.value);
expect(turn.getDebugResponses().length).toBe(1);
});
});
describe('getDebugResponses', () => {
it('should return collected debug responses', async () => {
const resp1 = {
candidates: [{ content: { parts: [{ text: 'Debug 1' }] } }],
} as unknown as GenerateContentResponse;
const resp2 = {
functionCalls: [{ name: 'debugTool' }],
} as unknown as GenerateContentResponse;
const mockResponseStream = (async function* () {
yield resp1;
yield resp2;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const reqParts: Part[] = [{ text: 'Hi' }];
for await (const _ of turn.run(reqParts)) {
// consume stream
}
expect(turn.getDebugResponses()).toEqual([resp1, resp2]);
});
});
});

View File

@ -128,11 +128,12 @@ export class Turn {
});
for await (const resp of responseStream) {
this.debugResponses.push(resp);
if (signal?.aborted) {
yield { type: GeminiEventType.UserCancelled };
// Do not add resp to debugResponses if aborted before processing
return;
}
this.debugResponses.push(resp);
const text = getResponseText(resp);
if (text) {