Don't start uncompressed history with a function response (#4141)

Tommaso Sciortino 2025-07-14 10:09:11 -07:00 committed by GitHub
parent c313c3dee1
commit 2f1d6234de
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
2 changed files with 77 additions and 13 deletions
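
Why this matters: in the Gemini API, a functionResponse part is only valid immediately after the model turn that issued the matching functionCall, so a post-compression history that opens with a function response is malformed. The fix extends the compression boundary search so the retained tail of the history always begins with a plain user text message. Below is a minimal self-contained sketch of the fixed walk; the Part/Content types are simplified stand-ins for the @google/genai types, and isFunctionResponse mimics the helper imported from messageInspectors.js (see the sketch after the client.ts hunks):

// Sketch only: simplified stand-ins for the real @google/genai types.
type Part = { text?: string; functionResponse?: { name: string } };
type Content = { role: 'user' | 'model'; parts: Part[] };

// Assumed behavior of the helper from messageInspectors.js: a user turn
// whose parts are all function responses.
function isFunctionResponse(content: Content): boolean {
  return (
    content.role === 'user' &&
    content.parts.length > 0 &&
    content.parts.every((part) => part.functionResponse !== undefined)
  );
}

// The fixed boundary walk: starting from a candidate split point, advance
// past model turns AND function-response user turns, so that everything
// kept after the summary begins with a plain user message.
function findCompressBoundary(history: Content[], startIndex: number): number {
  let compressBeforeIndex = startIndex;
  while (
    compressBeforeIndex < history.length &&
    (history[compressBeforeIndex].role === 'model' ||
      isFunctionResponse(history[compressBeforeIndex]))
  ) {
    compressBeforeIndex++;
  }
  return compressBeforeIndex;
}

// Mirroring the fixture in the new test below (indexes 0..10, with the
// functionResponse turn at index 8): the old condition `role !== 'user'`
// stopped at index 8; the fixed walk continues to the user turn at index 10.
const history: Content[] = [
  { role: 'user', parts: [{ text: '...history 1...' }] },
  { role: 'model', parts: [{ text: '...history 2...' }] },
  { role: 'user', parts: [{ text: '...history 3...' }] },
  { role: 'model', parts: [{ text: '...history 4...' }] },
  { role: 'user', parts: [{ text: '...history 5...' }] },
  { role: 'model', parts: [{ text: '...history 6...' }] },
  { role: 'user', parts: [{ text: '...history 7...' }] },
  { role: 'model', parts: [{ text: '...history 8...' }] },
  { role: 'user', parts: [{ functionResponse: { name: '...history 8...' } }] },
  { role: 'model', parts: [{ text: '...history 10...' }] },
  { role: 'user', parts: [{ text: '...history 10...' }] },
];
console.log(findCompressBoundary(history, 8)); // -> 10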

@@ -470,34 +470,31 @@ describe('Gemini Client (client.ts)', () => {
   describe('tryCompressChat', () => {
     const mockCountTokens = vi.fn();
     const mockSendMessage = vi.fn();
+    const mockGetHistory = vi.fn();
 
     beforeEach(() => {
       vi.mock('./tokenLimits', () => ({
         tokenLimit: vi.fn(),
       }));
-      const mockGenerator: Partial<ContentGenerator> = {
+      client['contentGenerator'] = {
         countTokens: mockCountTokens,
-      };
-      client['contentGenerator'] = mockGenerator as ContentGenerator;
+      } as unknown as ContentGenerator;
 
-      // Mock the chat's sendMessage method
-      const mockChat: Partial<GeminiChat> = {
-        getHistory: vi
-          .fn()
-          .mockReturnValue([
-            { role: 'user', parts: [{ text: '...history...' }] },
-          ]),
+      client['chat'] = {
+        getHistory: mockGetHistory,
         addHistory: vi.fn(),
         setHistory: vi.fn(),
         sendMessage: mockSendMessage,
-      };
-      client['chat'] = mockChat as GeminiChat;
+      } as unknown as GeminiChat;
     });
 
     it('should not trigger summarization if token count is below threshold', async () => {
       const MOCKED_TOKEN_LIMIT = 1000;
       vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history...' }] },
+      ]);
 
       mockCountTokens.mockResolvedValue({
         totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7
@@ -515,6 +512,9 @@ describe('Gemini Client (client.ts)', () => {
     it('should trigger summarization if token count is at threshold', async () => {
       const MOCKED_TOKEN_LIMIT = 1000;
       vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history...' }] },
+      ]);
 
       const originalTokenCount = 1000 * 0.7;
       const newTokenCount = 100;
@@ -546,7 +546,69 @@ describe('Gemini Client (client.ts)', () => {
       expect(newChat).not.toBe(initialChat);
     });
 
+    it('should not compress across a function call response', async () => {
+      const MOCKED_TOKEN_LIMIT = 1000;
+      vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history 1...' }] },
+        { role: 'model', parts: [{ text: '...history 2...' }] },
+        { role: 'user', parts: [{ text: '...history 3...' }] },
+        { role: 'model', parts: [{ text: '...history 4...' }] },
+        { role: 'user', parts: [{ text: '...history 5...' }] },
+        { role: 'model', parts: [{ text: '...history 6...' }] },
+        { role: 'user', parts: [{ text: '...history 7...' }] },
+        { role: 'model', parts: [{ text: '...history 8...' }] },
+        // Normally we would break here, but we have a function response.
+        {
+          role: 'user',
+          parts: [{ functionResponse: { name: '...history 8...' } }],
+        },
+        { role: 'model', parts: [{ text: '...history 10...' }] },
+        // Instead we will break here.
+        { role: 'user', parts: [{ text: '...history 10...' }] },
+      ]);
+
+      const originalTokenCount = 1000 * 0.7;
+      const newTokenCount = 100;
+
+      mockCountTokens
+        .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
+        .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
+
+      // Mock the summary response from the chat
+      mockSendMessage.mockResolvedValue({
+        role: 'model',
+        parts: [{ text: 'This is a summary.' }],
+      });
+
+      const initialChat = client.getChat();
+      const result = await client.tryCompressChat('prompt-id-3');
+      const newChat = client.getChat();
+
+      expect(tokenLimit).toHaveBeenCalled();
+      expect(mockSendMessage).toHaveBeenCalled();
+
+      // Assert that summarization happened and returned the correct stats
+      expect(result).toEqual({
+        originalTokenCount,
+        newTokenCount,
+      });
+
+      // Assert that the chat was reset
+      expect(newChat).not.toBe(initialChat);
+      // 1. standard start context message
+      // 2. standard canned user start message
+      // 3. compressed summary message
+      // 4. standard canned user summary message
+      // 5. The last user message (not the last 3 because that would start with a function response)
+      expect(newChat.getHistory().length).toEqual(5);
+    });
+
     it('should always trigger summarization when force is true, regardless of token count', async () => {
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history...' }] },
+      ]);
       const originalTokenCount = 10; // Well below threshold
       const newTokenCount = 5;

@@ -30,6 +30,7 @@ import { reportError } from '../utils/errorReporting.js';
 import { GeminiChat } from './geminiChat.js';
 import { retryWithBackoff } from '../utils/retry.js';
 import { getErrorMessage } from '../utils/errors.js';
+import { isFunctionResponse } from '../utils/messageInspectors.js';
 import { tokenLimit } from './tokenLimits.js';
 import {
   AuthType,
@@ -547,7 +548,8 @@ export class GeminiClient {
     // Find the first user message after the index. This is the start of the next turn.
     while (
       compressBeforeIndex < curatedHistory.length &&
-      curatedHistory[compressBeforeIndex]?.role !== 'user'
+      (curatedHistory[compressBeforeIndex]?.role === 'model' ||
+        isFunctionResponse(curatedHistory[compressBeforeIndex]))
     ) {
       compressBeforeIndex++;
     }
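
The helper imported above is not shown in this diff. A plausible reconstruction of isFunctionResponse from messageInspectors.js, inferred from its usage here rather than from the file's actual contents, might look like:

import { Content } from '@google/genai';

// Assumed implementation: a message counts as a function response when it
// is a user turn and every part carries a functionResponse payload.
export function isFunctionResponse(content: Content): boolean {
  return (
    content.role === 'user' &&
    !!content.parts &&
    content.parts.every((part) => !!part.functionResponse)
  );
}

Under this definition the loop treats a function-response turn like a model turn: it cannot open the retained history, so the boundary search keeps walking until it finds a plain user message.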