Fix(server): Ensure debug responses are not recorded after cancellation (#491)
This commit is contained in:
parent
6d3af7b97f
commit
1d0856dcc8
|
@ -0,0 +1,269 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
Turn,
|
||||
GeminiEventType,
|
||||
ServerGeminiToolCallRequestEvent,
|
||||
ServerGeminiErrorEvent,
|
||||
} from './turn.js';
|
||||
import { Chat, GenerateContentResponse, Part, Content } from '@google/genai';
|
||||
import { reportError } from '../utils/errorReporting.js';
|
||||
|
||||
const mockSendMessageStream = vi.fn();
|
||||
const mockGetHistory = vi.fn();
|
||||
|
||||
vi.mock('@google/genai', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@google/genai')>();
|
||||
const MockChat = vi.fn().mockImplementation(() => ({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: mockGetHistory,
|
||||
}));
|
||||
return {
|
||||
...actual,
|
||||
Chat: MockChat,
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../utils/errorReporting', () => ({
|
||||
reportError: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/generateContentResponseUtilities', () => ({
|
||||
getResponseText: (resp: GenerateContentResponse) =>
|
||||
resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
|
||||
undefined,
|
||||
}));
|
||||
|
||||
describe('Turn', () => {
|
||||
let turn: Turn;
|
||||
// Define a type for the mocked Chat instance for clarity
|
||||
type MockedChatInstance = {
|
||||
sendMessageStream: typeof mockSendMessageStream;
|
||||
getHistory: typeof mockGetHistory;
|
||||
};
|
||||
let mockChatInstance: MockedChatInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
mockChatInstance = {
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: mockGetHistory,
|
||||
};
|
||||
turn = new Turn(mockChatInstance as unknown as Chat);
|
||||
mockGetHistory.mockReturnValue([]);
|
||||
mockSendMessageStream.mockResolvedValue((async function* () {})());
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize pendingToolCalls and debugResponses', () => {
|
||||
expect(turn.pendingToolCalls).toEqual([]);
|
||||
expect(turn.getDebugResponses()).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('run', () => {
|
||||
it('should yield content events for text parts', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: ' world' }] } }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||
for await (const event of turn.run(reqParts)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts });
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Content, value: 'Hello' },
|
||||
{ type: GeminiEventType.Content, value: ' world' },
|
||||
]);
|
||||
expect(turn.getDebugResponses().length).toBe(2);
|
||||
});
|
||||
|
||||
it('should yield tool_call_request events for function calls', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
functionCalls: [
|
||||
{ id: 'fc1', name: 'tool1', args: { arg1: 'val1' } },
|
||||
{ name: 'tool2', args: { arg2: 'val2' } }, // No ID
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Use tools' }];
|
||||
for await (const event of turn.run(reqParts)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events.length).toBe(2);
|
||||
const event1 = events[0] as ServerGeminiToolCallRequestEvent;
|
||||
expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
|
||||
expect(event1.value).toEqual(
|
||||
expect.objectContaining({
|
||||
callId: 'fc1',
|
||||
name: 'tool1',
|
||||
args: { arg1: 'val1' },
|
||||
}),
|
||||
);
|
||||
expect(turn.pendingToolCalls[0]).toEqual(event1.value);
|
||||
|
||||
const event2 = events[1] as ServerGeminiToolCallRequestEvent;
|
||||
expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
|
||||
expect(event2.value).toEqual(
|
||||
expect.objectContaining({ name: 'tool2', args: { arg2: 'val2' } }),
|
||||
);
|
||||
expect(event2.value.callId).toEqual(
|
||||
expect.stringMatching(/^tool2-\d{13}-\w{10,}$/),
|
||||
);
|
||||
expect(turn.pendingToolCalls[1]).toEqual(event2.value);
|
||||
expect(turn.getDebugResponses().length).toBe(1);
|
||||
});
|
||||
|
||||
it('should yield UserCancelled event if signal is aborted', async () => {
|
||||
const abortController = new AbortController();
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: 'First part' }] } }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
abortController.abort();
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'Second part - should not be processed' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test abort' }];
|
||||
for await (const event of turn.run(reqParts, abortController.signal)) {
|
||||
events.push(event);
|
||||
}
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Content, value: 'First part' },
|
||||
{ type: GeminiEventType.UserCancelled },
|
||||
]);
|
||||
expect(turn.getDebugResponses().length).toBe(1);
|
||||
});
|
||||
|
||||
it('should yield Error event and report if sendMessageStream throws', async () => {
|
||||
const error = new Error('API Error');
|
||||
mockSendMessageStream.mockRejectedValue(error);
|
||||
const reqParts: Part[] = [{ text: 'Trigger error' }];
|
||||
const historyContent: Content[] = [
|
||||
{ role: 'model', parts: [{ text: 'Previous history' }] },
|
||||
];
|
||||
mockGetHistory.mockReturnValue(historyContent);
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(reqParts)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events.length).toBe(1);
|
||||
const errorEvent = events[0] as ServerGeminiErrorEvent;
|
||||
expect(errorEvent.type).toBe(GeminiEventType.Error);
|
||||
expect(errorEvent.value).toEqual({ message: 'API Error' });
|
||||
expect(turn.getDebugResponses().length).toBe(0);
|
||||
expect(reportError).toHaveBeenCalledWith(
|
||||
error,
|
||||
'Error when talking to Gemini API',
|
||||
[...historyContent, reqParts],
|
||||
'Turn.run-sendMessageStream',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle function calls with undefined name or args', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
functionCalls: [
|
||||
{ id: 'fc1', name: undefined, args: { arg1: 'val1' } },
|
||||
{ id: 'fc2', name: 'tool2', args: undefined },
|
||||
{ id: 'fc3', name: undefined, args: undefined },
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
|
||||
for await (const event of turn.run(reqParts)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events.length).toBe(3);
|
||||
const event1 = events[0] as ServerGeminiToolCallRequestEvent;
|
||||
expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
|
||||
expect(event1.value).toEqual(
|
||||
expect.objectContaining({
|
||||
callId: 'fc1',
|
||||
name: 'undefined_tool_name',
|
||||
args: { arg1: 'val1' },
|
||||
}),
|
||||
);
|
||||
expect(turn.pendingToolCalls[0]).toEqual(event1.value);
|
||||
|
||||
const event2 = events[1] as ServerGeminiToolCallRequestEvent;
|
||||
expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
|
||||
expect(event2.value).toEqual(
|
||||
expect.objectContaining({ callId: 'fc2', name: 'tool2', args: {} }),
|
||||
);
|
||||
expect(turn.pendingToolCalls[1]).toEqual(event2.value);
|
||||
|
||||
const event3 = events[2] as ServerGeminiToolCallRequestEvent;
|
||||
expect(event3.type).toBe(GeminiEventType.ToolCallRequest);
|
||||
expect(event3.value).toEqual(
|
||||
expect.objectContaining({
|
||||
callId: 'fc3',
|
||||
name: 'undefined_tool_name',
|
||||
args: {},
|
||||
}),
|
||||
);
|
||||
expect(turn.pendingToolCalls[2]).toEqual(event3.value);
|
||||
expect(turn.getDebugResponses().length).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getDebugResponses', () => {
|
||||
it('should return collected debug responses', async () => {
|
||||
const resp1 = {
|
||||
candidates: [{ content: { parts: [{ text: 'Debug 1' }] } }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
const resp2 = {
|
||||
functionCalls: [{ name: 'debugTool' }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
const mockResponseStream = (async function* () {
|
||||
yield resp1;
|
||||
yield resp2;
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||
for await (const _ of turn.run(reqParts)) {
|
||||
// consume stream
|
||||
}
|
||||
expect(turn.getDebugResponses()).toEqual([resp1, resp2]);
|
||||
});
|
||||
});
|
||||
});
|
|
@ -128,11 +128,12 @@ export class Turn {
|
|||
});
|
||||
|
||||
for await (const resp of responseStream) {
|
||||
this.debugResponses.push(resp);
|
||||
if (signal?.aborted) {
|
||||
yield { type: GeminiEventType.UserCancelled };
|
||||
// Do not add resp to debugResponses if aborted before processing
|
||||
return;
|
||||
}
|
||||
this.debugResponses.push(resp);
|
||||
|
||||
const text = getResponseText(resp);
|
||||
if (text) {
|
||||
|
|
Loading…
Reference in New Issue