Preserve recent history when compressing. (#3049)
Co-authored-by: Scott Densmore <scottdensmore@mac.com>
This commit is contained in:
parent
23e3c7d6ec
commit
0c70a99b56
|
@ -8,11 +8,12 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
Chat,
|
Chat,
|
||||||
|
Content,
|
||||||
EmbedContentResponse,
|
EmbedContentResponse,
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
GoogleGenAI,
|
GoogleGenAI,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { GeminiClient } from './client.js';
|
import { findIndexAfterFraction, GeminiClient } from './client.js';
|
||||||
import { AuthType, ContentGenerator } from './contentGenerator.js';
|
import { AuthType, ContentGenerator } from './contentGenerator.js';
|
||||||
import { GeminiChat } from './geminiChat.js';
|
import { GeminiChat } from './geminiChat.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
|
@ -65,6 +66,54 @@ vi.mock('../telemetry/index.js', () => ({
|
||||||
logApiError: vi.fn(),
|
logApiError: vi.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
// Unit tests for findIndexAfterFraction: argument validation, a worked
// mid-history example, and degenerate histories (empty, single entry,
// entries whose parts carry no text).
describe('findIndexAfterFraction', () => {
  const history: Content[] = [
    { role: 'user', parts: [{ text: 'This is the first message.' }] },
    { role: 'model', parts: [{ text: 'This is the second message.' }] },
    { role: 'user', parts: [{ text: 'This is the third message.' }] },
    { role: 'model', parts: [{ text: 'This is the fourth message.' }] },
    { role: 'user', parts: [{ text: 'This is the fifth message.' }] },
  ];

  it('should throw an error for non-positive numbers', () => {
    expect(() => findIndexAfterFraction(history, 0)).toThrow(
      'Fraction must be between 0 and 1',
    );
  });

  it('should throw an error for a fraction greater than or equal to 1', () => {
    expect(() => findIndexAfterFraction(history, 1)).toThrow(
      'Fraction must be between 0 and 1',
    );
  });

  it('should handle a fraction in the middle', () => {
    // Each entry is measured by the length of its JSON.stringify output,
    // so the role and structural characters count, not just the text.
    // Total length is 319. 319 * 0.5 = 159.5
    // 0: 63
    // 1: 63 + 65 = 128
    // 2: 128 + 63 = 191
    // 191 >= 159.5, so index is 2
    expect(findIndexAfterFraction(history, 0.5)).toBe(2);
  });

  it('should handle an empty history', () => {
    expect(findIndexAfterFraction([], 0.5)).toBe(0);
  });

  it('should handle a history with only one item', () => {
    expect(findIndexAfterFraction(history.slice(0, 1), 0.5)).toBe(0);
  });

  it('should handle history with weird parts', () => {
    // The middle entry has a fileData part instead of text; it still
    // contributes its serialized length (46 + 58 = 104 >= 150 * 0.5).
    const historyWithEmptyParts: Content[] = [
      { role: 'user', parts: [{ text: 'Message 1' }] },
      { role: 'model', parts: [{ fileData: { fileUri: 'derp' } }] },
      { role: 'user', parts: [{ text: 'Message 2' }] },
    ];
    expect(findIndexAfterFraction(historyWithEmptyParts, 0.5)).toBe(1);
  });
});
|
||||||
|
|
||||||
describe('Gemini Client (client.ts)', () => {
|
describe('Gemini Client (client.ts)', () => {
|
||||||
let client: GeminiClient;
|
let client: GeminiClient;
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
|
@ -384,6 +433,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
{ role: 'user', parts: [{ text: '...history...' }] },
|
{ role: 'user', parts: [{ text: '...history...' }] },
|
||||||
]),
|
]),
|
||||||
addHistory: vi.fn(),
|
addHistory: vi.fn(),
|
||||||
|
setHistory: vi.fn(),
|
||||||
sendMessage: mockSendMessage,
|
sendMessage: mockSendMessage,
|
||||||
};
|
};
|
||||||
client['chat'] = mockChat as GeminiChat;
|
client['chat'] = mockChat as GeminiChat;
|
||||||
|
@ -735,6 +785,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
|
|
||||||
const mockChat: Partial<GeminiChat> = {
|
const mockChat: Partial<GeminiChat> = {
|
||||||
getHistory: vi.fn().mockReturnValue(mockChatHistory),
|
getHistory: vi.fn().mockReturnValue(mockChatHistory),
|
||||||
|
setHistory: vi.fn(),
|
||||||
sendMessage: mockSendMessage,
|
sendMessage: mockSendMessage,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -45,6 +45,39 @@ function isThinkingSupported(model: string) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Returns the index of the first history entry at which the running total of
 * serialized characters reaches `fraction` of the whole history.
 *
 * Each entry is measured by the length of its `JSON.stringify` output, so
 * every kind of part (not only text) contributes to the total.
 *
 * Exported for testing purposes.
 *
 * @param history The chat history to measure.
 * @param fraction The share of the total characters to cover; must be
 *   strictly between 0 and 1.
 * @returns The index of the entry whose cumulative size first meets the
 *   target, or `history.length` when no entry does (e.g. an empty history).
 * @throws Error if `fraction` is not a number strictly between 0 and 1.
 */
export function findIndexAfterFraction(
  history: Content[],
  fraction: number,
): number {
  // Expressed as a negated in-range check so that NaN, for which every
  // comparison is false, is rejected rather than slipping past validation.
  if (!(fraction > 0 && fraction < 1)) {
    throw new Error('Fraction must be between 0 and 1');
  }

  const contentLengths = history.map(
    (content) => JSON.stringify(content).length,
  );

  const totalCharacters = contentLengths.reduce(
    (sum, length) => sum + length,
    0,
  );
  const targetCharacters = totalCharacters * fraction;

  let charactersSoFar = 0;
  for (const [index, length] of contentLengths.entries()) {
    charactersSoFar += length;
    if (charactersSoFar >= targetCharacters) {
      return index;
    }
  }
  return contentLengths.length;
}
|
||||||
|
|
||||||
export class GeminiClient {
|
export class GeminiClient {
|
||||||
private chat?: GeminiChat;
|
private chat?: GeminiChat;
|
||||||
private contentGenerator?: ContentGenerator;
|
private contentGenerator?: ContentGenerator;
|
||||||
|
@ -54,7 +87,16 @@ export class GeminiClient {
|
||||||
topP: 1,
|
topP: 1,
|
||||||
};
|
};
|
||||||
private readonly MAX_TURNS = 100;
|
private readonly MAX_TURNS = 100;
|
||||||
private readonly TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7;
|
/**
|
||||||
|
* Threshold for compression token count as a fraction of the model's token limit.
|
||||||
|
* If the chat history exceeds this threshold, it will be compressed.
|
||||||
|
*/
|
||||||
|
private readonly COMPRESSION_TOKEN_THRESHOLD = 0.7;
|
||||||
|
/**
|
||||||
|
* The fraction of the latest chat history to keep. A value of 0.3
|
||||||
|
* means that only the last 30% of the chat history will be kept after compression.
|
||||||
|
*/
|
||||||
|
private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3;
|
||||||
|
|
||||||
constructor(private config: Config) {
|
constructor(private config: Config) {
|
||||||
if (config.getProxy()) {
|
if (config.getProxy()) {
|
||||||
|
@ -90,11 +132,11 @@ export class GeminiClient {
|
||||||
return this.chat;
|
return this.chat;
|
||||||
}
|
}
|
||||||
|
|
||||||
async getHistory(): Promise<Content[]> {
|
getHistory(): Content[] {
|
||||||
return this.getChat().getHistory();
|
return this.getChat().getHistory();
|
||||||
}
|
}
|
||||||
|
|
||||||
async setHistory(history: Content[]): Promise<void> {
|
setHistory(history: Content[]) {
|
||||||
this.getChat().setHistory(history);
|
this.getChat().setHistory(history);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -441,25 +483,41 @@ export class GeminiClient {
|
||||||
|
|
||||||
const model = this.config.getModel();
|
const model = this.config.getModel();
|
||||||
|
|
||||||
let { totalTokens: originalTokenCount } =
|
const { totalTokens: originalTokenCount } =
|
||||||
await this.getContentGenerator().countTokens({
|
await this.getContentGenerator().countTokens({
|
||||||
model,
|
model,
|
||||||
contents: curatedHistory,
|
contents: curatedHistory,
|
||||||
});
|
});
|
||||||
if (originalTokenCount === undefined) {
|
if (originalTokenCount === undefined) {
|
||||||
console.warn(`Could not determine token count for model ${model}.`);
|
console.warn(`Could not determine token count for model ${model}.`);
|
||||||
originalTokenCount = 0;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't compress if not forced and we are under the limit.
|
// Don't compress if not forced and we are under the limit.
|
||||||
if (
|
if (
|
||||||
!force &&
|
!force &&
|
||||||
originalTokenCount <
|
originalTokenCount < this.COMPRESSION_TOKEN_THRESHOLD * tokenLimit(model)
|
||||||
this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(model)
|
|
||||||
) {
|
) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let compressBeforeIndex = findIndexAfterFraction(
|
||||||
|
curatedHistory,
|
||||||
|
1 - this.COMPRESSION_PRESERVE_THRESHOLD,
|
||||||
|
);
|
||||||
|
// Find the first user message after the index. This is the start of the next turn.
|
||||||
|
while (
|
||||||
|
compressBeforeIndex < curatedHistory.length &&
|
||||||
|
curatedHistory[compressBeforeIndex]?.role !== 'user'
|
||||||
|
) {
|
||||||
|
compressBeforeIndex++;
|
||||||
|
}
|
||||||
|
|
||||||
|
const historyToCompress = curatedHistory.slice(0, compressBeforeIndex);
|
||||||
|
const historyToKeep = curatedHistory.slice(compressBeforeIndex);
|
||||||
|
|
||||||
|
this.getChat().setHistory(historyToCompress);
|
||||||
|
|
||||||
const { text: summary } = await this.getChat().sendMessage({
|
const { text: summary } = await this.getChat().sendMessage({
|
||||||
message: {
|
message: {
|
||||||
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
|
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
|
||||||
|
@ -477,6 +535,7 @@ export class GeminiClient {
|
||||||
role: 'model',
|
role: 'model',
|
||||||
parts: [{ text: 'Got it. Thanks for the additional context!' }],
|
parts: [{ text: 'Got it. Thanks for the additional context!' }],
|
||||||
},
|
},
|
||||||
|
...historyToKeep,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
const { totalTokens: newTokenCount } =
|
const { totalTokens: newTokenCount } =
|
||||||
|
|
Loading…
Reference in New Issue