Preserve recent history when compressing. (#3049)

Co-authored-by: Scott Densmore <scottdensmore@mac.com>
This commit is contained in:
Tommaso Sciortino 2025-07-07 23:32:09 -07:00 committed by GitHub
parent 23e3c7d6ec
commit 0c70a99b56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 118 additions and 8 deletions

View File

@ -8,11 +8,12 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { import {
Chat, Chat,
Content,
EmbedContentResponse, EmbedContentResponse,
GenerateContentResponse, GenerateContentResponse,
GoogleGenAI, GoogleGenAI,
} from '@google/genai'; } from '@google/genai';
import { GeminiClient } from './client.js'; import { findIndexAfterFraction, GeminiClient } from './client.js';
import { AuthType, ContentGenerator } from './contentGenerator.js'; import { AuthType, ContentGenerator } from './contentGenerator.js';
import { GeminiChat } from './geminiChat.js'; import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
@ -65,6 +66,54 @@ vi.mock('../telemetry/index.js', () => ({
logApiError: vi.fn(), logApiError: vi.fn(),
})); }));
describe('findIndexAfterFraction', () => {
const history: Content[] = [
{ role: 'user', parts: [{ text: 'This is the first message.' }] },
{ role: 'model', parts: [{ text: 'This is the second message.' }] },
{ role: 'user', parts: [{ text: 'This is the third message.' }] },
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] },
{ role: 'user', parts: [{ text: 'This is the fifth message.' }] },
];
it('should throw an error for non-positive numbers', () => {
expect(() => findIndexAfterFraction(history, 0)).toThrow(
'Fraction must be between 0 and 1',
);
});
it('should throw an error for a fraction greater than or equal to 1', () => {
expect(() => findIndexAfterFraction(history, 1)).toThrow(
'Fraction must be between 0 and 1',
);
});
it('should handle a fraction in the middle', () => {
// Total length is 257. 257 * 0.5 = 128.5
// 0: 53
// 1: 53 + 54 = 107
// 2: 107 + 53 = 160
// 160 >= 128.5, so index is 2
expect(findIndexAfterFraction(history, 0.5)).toBe(2);
});
it('should handle an empty history', () => {
expect(findIndexAfterFraction([], 0.5)).toBe(0);
});
it('should handle a history with only one item', () => {
expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(0);
});
it('should handle history with weird parts', () => {
const historyWithEmptyParts: Content[] = [
{ role: 'user', parts: [{ text: 'Message 1' }] },
{ role: 'model', parts: [{ fileData: { fileUri: 'derp' } }] },
{ role: 'user', parts: [{ text: 'Message 2' }] },
];
expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(1);
});
});
describe('Gemini Client (client.ts)', () => { describe('Gemini Client (client.ts)', () => {
let client: GeminiClient; let client: GeminiClient;
beforeEach(async () => { beforeEach(async () => {
@ -384,6 +433,7 @@ describe('Gemini Client (client.ts)', () => {
{ role: 'user', parts: [{ text: '...history...' }] }, { role: 'user', parts: [{ text: '...history...' }] },
]), ]),
addHistory: vi.fn(), addHistory: vi.fn(),
setHistory: vi.fn(),
sendMessage: mockSendMessage, sendMessage: mockSendMessage,
}; };
client['chat'] = mockChat as GeminiChat; client['chat'] = mockChat as GeminiChat;
@ -735,6 +785,7 @@ describe('Gemini Client (client.ts)', () => {
const mockChat: Partial<GeminiChat> = { const mockChat: Partial<GeminiChat> = {
getHistory: vi.fn().mockReturnValue(mockChatHistory), getHistory: vi.fn().mockReturnValue(mockChatHistory),
setHistory: vi.fn(),
sendMessage: mockSendMessage, sendMessage: mockSendMessage,
}; };

View File

@ -45,6 +45,39 @@ function isThinkingSupported(model: string) {
return false; return false;
} }
/**
 * Finds the index of the first history entry at which the running total of
 * serialized (JSON.stringify) character counts reaches the given fraction of
 * the overall character count.
 *
 * Exported for testing purposes.
 */
export function findIndexAfterFraction(
  history: Content[],
  fraction: number,
): number {
  if (fraction <= 0 || fraction >= 1) {
    throw new Error('Fraction must be between 0 and 1');
  }

  // Measure each entry by the size of its serialized form.
  const sizes = history.map((entry) => JSON.stringify(entry).length);

  // Start from the character budget and count down; the index where the
  // budget is exhausted is the answer.
  let remaining = fraction * sizes.reduce((total, size) => total + size, 0);
  for (const [index, size] of sizes.entries()) {
    remaining -= size;
    if (remaining <= 0) {
      return index;
    }
  }
  return sizes.length;
}
export class GeminiClient { export class GeminiClient {
private chat?: GeminiChat; private chat?: GeminiChat;
private contentGenerator?: ContentGenerator; private contentGenerator?: ContentGenerator;
@ -54,7 +87,16 @@ export class GeminiClient {
topP: 1, topP: 1,
}; };
private readonly MAX_TURNS = 100; private readonly MAX_TURNS = 100;
private readonly TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7; /**
* Threshold for compression token count as a fraction of the model's token limit.
* If the chat history exceeds this threshold, it will be compressed.
*/
private readonly COMPRESSION_TOKEN_THRESHOLD = 0.7;
/**
* The fraction of the latest chat history to keep. A value of 0.3
* means that only the last 30% of the chat history will be kept after compression.
*/
private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3;
constructor(private config: Config) { constructor(private config: Config) {
if (config.getProxy()) { if (config.getProxy()) {
@ -90,11 +132,11 @@ export class GeminiClient {
return this.chat; return this.chat;
} }
async getHistory(): Promise<Content[]> { getHistory(): Content[] {
return this.getChat().getHistory(); return this.getChat().getHistory();
} }
async setHistory(history: Content[]): Promise<void> { setHistory(history: Content[]) {
this.getChat().setHistory(history); this.getChat().setHistory(history);
} }
@ -441,25 +483,41 @@ export class GeminiClient {
const model = this.config.getModel(); const model = this.config.getModel();
let { totalTokens: originalTokenCount } = const { totalTokens: originalTokenCount } =
await this.getContentGenerator().countTokens({ await this.getContentGenerator().countTokens({
model, model,
contents: curatedHistory, contents: curatedHistory,
}); });
if (originalTokenCount === undefined) { if (originalTokenCount === undefined) {
console.warn(`Could not determine token count for model ${model}.`); console.warn(`Could not determine token count for model ${model}.`);
originalTokenCount = 0; return null;
} }
// Don't compress if not forced and we are under the limit. // Don't compress if not forced and we are under the limit.
if ( if (
!force && !force &&
originalTokenCount < originalTokenCount < this.COMPRESSION_TOKEN_THRESHOLD * tokenLimit(model)
this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(model)
) { ) {
return null; return null;
} }
let compressBeforeIndex = findIndexAfterFraction(
curatedHistory,
1 - this.COMPRESSION_PRESERVE_THRESHOLD,
);
// Find the first user message after the index. This is the start of the next turn.
while (
compressBeforeIndex < curatedHistory.length &&
curatedHistory[compressBeforeIndex]?.role !== 'user'
) {
compressBeforeIndex++;
}
const historyToCompress = curatedHistory.slice(0, compressBeforeIndex);
const historyToKeep = curatedHistory.slice(compressBeforeIndex);
this.getChat().setHistory(historyToCompress);
const { text: summary } = await this.getChat().sendMessage({ const { text: summary } = await this.getChat().sendMessage({
message: { message: {
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.', text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
@ -477,6 +535,7 @@ export class GeminiClient {
role: 'model', role: 'model',
parts: [{ text: 'Got it. Thanks for the additional context!' }], parts: [{ text: 'Got it. Thanks for the additional context!' }],
}, },
...historyToKeep,
]); ]);
const { totalTokens: newTokenCount } = const { totalTokens: newTokenCount } =