Don't start uncompressed history with a function response (#4141)
This commit is contained in:
parent
c313c3dee1
commit
2f1d6234de
|
@ -470,34 +470,31 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
describe('tryCompressChat', () => {
|
describe('tryCompressChat', () => {
|
||||||
const mockCountTokens = vi.fn();
|
const mockCountTokens = vi.fn();
|
||||||
const mockSendMessage = vi.fn();
|
const mockSendMessage = vi.fn();
|
||||||
|
const mockGetHistory = vi.fn();
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.mock('./tokenLimits', () => ({
|
vi.mock('./tokenLimits', () => ({
|
||||||
tokenLimit: vi.fn(),
|
tokenLimit: vi.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
const mockGenerator: Partial<ContentGenerator> = {
|
client['contentGenerator'] = {
|
||||||
countTokens: mockCountTokens,
|
countTokens: mockCountTokens,
|
||||||
};
|
} as unknown as ContentGenerator;
|
||||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
|
||||||
|
|
||||||
// Mock the chat's sendMessage method
|
client['chat'] = {
|
||||||
const mockChat: Partial<GeminiChat> = {
|
getHistory: mockGetHistory,
|
||||||
getHistory: vi
|
|
||||||
.fn()
|
|
||||||
.mockReturnValue([
|
|
||||||
{ role: 'user', parts: [{ text: '...history...' }] },
|
|
||||||
]),
|
|
||||||
addHistory: vi.fn(),
|
addHistory: vi.fn(),
|
||||||
setHistory: vi.fn(),
|
setHistory: vi.fn(),
|
||||||
sendMessage: mockSendMessage,
|
sendMessage: mockSendMessage,
|
||||||
};
|
} as unknown as GeminiChat;
|
||||||
client['chat'] = mockChat as GeminiChat;
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should not trigger summarization if token count is below threshold', async () => {
|
it('should not trigger summarization if token count is below threshold', async () => {
|
||||||
const MOCKED_TOKEN_LIMIT = 1000;
|
const MOCKED_TOKEN_LIMIT = 1000;
|
||||||
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
|
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
|
||||||
|
mockGetHistory.mockReturnValue([
|
||||||
|
{ role: 'user', parts: [{ text: '...history...' }] },
|
||||||
|
]);
|
||||||
|
|
||||||
mockCountTokens.mockResolvedValue({
|
mockCountTokens.mockResolvedValue({
|
||||||
totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7
|
totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7
|
||||||
|
@ -515,6 +512,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
it('should trigger summarization if token count is at threshold', async () => {
|
it('should trigger summarization if token count is at threshold', async () => {
|
||||||
const MOCKED_TOKEN_LIMIT = 1000;
|
const MOCKED_TOKEN_LIMIT = 1000;
|
||||||
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
|
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
|
||||||
|
mockGetHistory.mockReturnValue([
|
||||||
|
{ role: 'user', parts: [{ text: '...history...' }] },
|
||||||
|
]);
|
||||||
|
|
||||||
const originalTokenCount = 1000 * 0.7;
|
const originalTokenCount = 1000 * 0.7;
|
||||||
const newTokenCount = 100;
|
const newTokenCount = 100;
|
||||||
|
@ -546,7 +546,69 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
expect(newChat).not.toBe(initialChat);
|
expect(newChat).not.toBe(initialChat);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should not compress across a function call response', async () => {
|
||||||
|
const MOCKED_TOKEN_LIMIT = 1000;
|
||||||
|
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
|
||||||
|
mockGetHistory.mockReturnValue([
|
||||||
|
{ role: 'user', parts: [{ text: '...history 1...' }] },
|
||||||
|
{ role: 'model', parts: [{ text: '...history 2...' }] },
|
||||||
|
{ role: 'user', parts: [{ text: '...history 3...' }] },
|
||||||
|
{ role: 'model', parts: [{ text: '...history 4...' }] },
|
||||||
|
{ role: 'user', parts: [{ text: '...history 5...' }] },
|
||||||
|
{ role: 'model', parts: [{ text: '...history 6...' }] },
|
||||||
|
{ role: 'user', parts: [{ text: '...history 7...' }] },
|
||||||
|
{ role: 'model', parts: [{ text: '...history 8...' }] },
|
||||||
|
// Normally the history would be split here, but the next user message is a function response.
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ functionResponse: { name: '...history 8...' } }],
|
||||||
|
},
|
||||||
|
{ role: 'model', parts: [{ text: '...history 10...' }] },
|
||||||
|
// Instead the split happens here, at the next user message that is not a function response.
|
||||||
|
{ role: 'user', parts: [{ text: '...history 10...' }] },
|
||||||
|
]);
|
||||||
|
|
||||||
|
const originalTokenCount = 1000 * 0.7;
|
||||||
|
const newTokenCount = 100;
|
||||||
|
|
||||||
|
mockCountTokens
|
||||||
|
.mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
|
||||||
|
.mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
|
||||||
|
|
||||||
|
// Mock the summary response from the chat
|
||||||
|
mockSendMessage.mockResolvedValue({
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'This is a summary.' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
const initialChat = client.getChat();
|
||||||
|
const result = await client.tryCompressChat('prompt-id-3');
|
||||||
|
const newChat = client.getChat();
|
||||||
|
|
||||||
|
expect(tokenLimit).toHaveBeenCalled();
|
||||||
|
expect(mockSendMessage).toHaveBeenCalled();
|
||||||
|
|
||||||
|
// Assert that summarization happened and returned the correct stats
|
||||||
|
expect(result).toEqual({
|
||||||
|
originalTokenCount,
|
||||||
|
newTokenCount,
|
||||||
|
});
|
||||||
|
// Assert that the chat was reset
|
||||||
|
expect(newChat).not.toBe(initialChat);
|
||||||
|
|
||||||
|
// 1. standard start context message
|
||||||
|
// 2. standard canned user start message
|
||||||
|
// 3. compressed summary message
|
||||||
|
// 4. standard canned user summary message
|
||||||
|
// 5. The last user message (not the last 3 because that would start with a function response)
|
||||||
|
expect(newChat.getHistory().length).toEqual(5);
|
||||||
|
});
|
||||||
|
|
||||||
it('should always trigger summarization when force is true, regardless of token count', async () => {
|
it('should always trigger summarization when force is true, regardless of token count', async () => {
|
||||||
|
mockGetHistory.mockReturnValue([
|
||||||
|
{ role: 'user', parts: [{ text: '...history...' }] },
|
||||||
|
]);
|
||||||
|
|
||||||
const originalTokenCount = 10; // Well below threshold
|
const originalTokenCount = 10; // Well below threshold
|
||||||
const newTokenCount = 5;
|
const newTokenCount = 5;
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ import { reportError } from '../utils/errorReporting.js';
|
||||||
import { GeminiChat } from './geminiChat.js';
|
import { GeminiChat } from './geminiChat.js';
|
||||||
import { retryWithBackoff } from '../utils/retry.js';
|
import { retryWithBackoff } from '../utils/retry.js';
|
||||||
import { getErrorMessage } from '../utils/errors.js';
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
|
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||||
import { tokenLimit } from './tokenLimits.js';
|
import { tokenLimit } from './tokenLimits.js';
|
||||||
import {
|
import {
|
||||||
AuthType,
|
AuthType,
|
||||||
|
@ -547,7 +548,8 @@ export class GeminiClient {
|
||||||
// Find the first user message after the index. This is the start of the next turn.
|
// Find the first user message at or after the index that is not a function response. This is the start of the next turn.
|
||||||
while (
|
while (
|
||||||
compressBeforeIndex < curatedHistory.length &&
|
compressBeforeIndex < curatedHistory.length &&
|
||||||
curatedHistory[compressBeforeIndex]?.role !== 'user'
|
(curatedHistory[compressBeforeIndex]?.role === 'model' ||
|
||||||
|
isFunctionResponse(curatedHistory[compressBeforeIndex]))
|
||||||
) {
|
) {
|
||||||
compressBeforeIndex++;
|
compressBeforeIndex++;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue