From 0c70a99b567d46cb3a90774d41a72670d6727dfb Mon Sep 17 00:00:00 2001 From: Tommaso Sciortino Date: Mon, 7 Jul 2025 23:32:09 -0700 Subject: [PATCH] Preserve recent history when compressing. (#3049) Co-authored-by: Scott Densmore --- packages/core/src/core/client.test.ts | 53 ++++++++++++++++++- packages/core/src/core/client.ts | 73 ++++++++++++++++++++++++--- 2 files changed, 118 insertions(+), 8 deletions(-) diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index dc3b8455..9d3791fd 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -8,11 +8,12 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { Chat, + Content, EmbedContentResponse, GenerateContentResponse, GoogleGenAI, } from '@google/genai'; -import { GeminiClient } from './client.js'; +import { findIndexAfterFraction, GeminiClient } from './client.js'; import { AuthType, ContentGenerator } from './contentGenerator.js'; import { GeminiChat } from './geminiChat.js'; import { Config } from '../config/config.js'; @@ -65,6 +66,54 @@ vi.mock('../telemetry/index.js', () => ({ logApiError: vi.fn(), })); +describe('findIndexAfterFraction', () => { + const history: Content[] = [ + { role: 'user', parts: [{ text: 'This is the first message.' }] }, + { role: 'model', parts: [{ text: 'This is the second message.' }] }, + { role: 'user', parts: [{ text: 'This is the third message.' }] }, + { role: 'model', parts: [{ text: 'This is the fourth message.' }] }, + { role: 'user', parts: [{ text: 'This is the fifth message.' 
}] }, + ]; + + it('should throw an error for non-positive numbers', () => { + expect(() => findIndexAfterFraction(history, 0)).toThrow( + 'Fraction must be between 0 and 1', + ); + }); + + it('should throw an error for a fraction greater than or equal to 1', () => { + expect(() => findIndexAfterFraction(history, 1)).toThrow( + 'Fraction must be between 0 and 1', + ); + }); + + it('should handle a fraction in the middle', () => { + // Total JSON.stringify length is 319 (63 + 65 + 63 + 65 + 63). 319 * 0.5 = 159.5 + // 0: 63 + // 1: 63 + 65 = 128 + // 2: 128 + 63 = 191 + // 191 >= 159.5, so index is 2 + expect(findIndexAfterFraction(history, 0.5)).toBe(2); + }); + + it('should handle an empty history', () => { + expect(findIndexAfterFraction([], 0.5)).toBe(0); + }); + + it('should handle a history with only one item', () => { + expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(0); + }); + + it('should handle history with weird parts', () => { + const historyWithEmptyParts: Content[] = [ + { role: 'user', parts: [{ text: 'Message 1' }] }, + { role: 'model', parts: [{ fileData: { fileUri: 'derp' } }] }, + { role: 'user', parts: [{ text: 'Message 2' }] }, + ]; + expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(1); + }); +}); + describe('Gemini Client (client.ts)', () => { let client: GeminiClient; beforeEach(async () => { @@ -384,6 +433,7 @@ describe('Gemini Client (client.ts)', () => { { role: 'user', parts: [{ text: '...history...' 
}] }, ]), addHistory: vi.fn(), + setHistory: vi.fn(), sendMessage: mockSendMessage, }; client['chat'] = mockChat as GeminiChat; @@ -735,6 +785,7 @@ describe('Gemini Client (client.ts)', () => { const mockChat: Partial = { getHistory: vi.fn().mockReturnValue(mockChatHistory), + setHistory: vi.fn(), sendMessage: mockSendMessage, }; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 69ed0dff..6cfcd407 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -45,6 +45,39 @@ function isThinkingSupported(model: string) { return false; } +/** + * Returns the index of the first history entry at which the running total of JSON.stringify character counts reaches `fraction` of the total for the whole history, or `history.length` if no entry reaches it (0 for an empty history). Throws an Error if `fraction` is not strictly between 0 and 1. + * + * Exported for testing purposes. + */ +export function findIndexAfterFraction( + history: Content[], + fraction: number, +): number { + if (fraction <= 0 || fraction >= 1) { + throw new Error('Fraction must be between 0 and 1'); + } + + const contentLengths = history.map( + (content) => JSON.stringify(content).length, + ); + + const totalCharacters = contentLengths.reduce( + (sum, length) => sum + length, + 0, + ); + const targetCharacters = totalCharacters * fraction; + + let charactersSoFar = 0; + for (let i = 0; i < contentLengths.length; i++) { + charactersSoFar += contentLengths[i]; + if (charactersSoFar >= targetCharacters) { + return i; + } + } + return contentLengths.length; +} + export class GeminiClient { private chat?: GeminiChat; private contentGenerator?: ContentGenerator; @@ -54,7 +87,16 @@ export class GeminiClient { topP: 1, }; private readonly MAX_TURNS = 100; - private readonly TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7; + /** + * Threshold for compression token count as a fraction of the model's token limit. + * If the chat history exceeds this threshold, it will be compressed. + */ + private readonly COMPRESSION_TOKEN_THRESHOLD = 0.7; + /** + * The fraction of the latest chat history to keep. 
A value of 0.3 + * means that only the last 30% of the chat history will be kept after compression. + */ + private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3; constructor(private config: Config) { if (config.getProxy()) { @@ -90,11 +132,11 @@ export class GeminiClient { return this.chat; } - async getHistory(): Promise { + getHistory(): Content[] { return this.getChat().getHistory(); } - async setHistory(history: Content[]): Promise { + setHistory(history: Content[]) { this.getChat().setHistory(history); } @@ -441,25 +483,41 @@ export class GeminiClient { const model = this.config.getModel(); - let { totalTokens: originalTokenCount } = + const { totalTokens: originalTokenCount } = await this.getContentGenerator().countTokens({ model, contents: curatedHistory, }); if (originalTokenCount === undefined) { console.warn(`Could not determine token count for model ${model}.`); - originalTokenCount = 0; + return null; } // Don't compress if not forced and we are under the limit. if ( !force && - originalTokenCount < - this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(model) + originalTokenCount < this.COMPRESSION_TOKEN_THRESHOLD * tokenLimit(model) ) { return null; } + let compressBeforeIndex = findIndexAfterFraction( + curatedHistory, + 1 - this.COMPRESSION_PRESERVE_THRESHOLD, + ); + // Find the first user message after the index. This is the start of the next turn. + while ( + compressBeforeIndex < curatedHistory.length && + curatedHistory[compressBeforeIndex]?.role !== 'user' + ) { + compressBeforeIndex++; + } + + const historyToCompress = curatedHistory.slice(0, compressBeforeIndex); + const historyToKeep = curatedHistory.slice(compressBeforeIndex); + + this.getChat().setHistory(historyToCompress); + const { text: summary } = await this.getChat().sendMessage({ message: { text: 'First, reason in your scratchpad. Then, generate the .', @@ -477,6 +535,7 @@ export class GeminiClient { role: 'model', parts: [{ text: 'Got it. Thanks for the additional context!' 
}], }, + ...historyToKeep, ]); const { totalTokens: newTokenCount } =