diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index bbcb549b..03793bda 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -470,34 +470,31 @@ describe('Gemini Client (client.ts)', () => {
   describe('tryCompressChat', () => {
     const mockCountTokens = vi.fn();
     const mockSendMessage = vi.fn();
+    const mockGetHistory = vi.fn();
 
     beforeEach(() => {
       vi.mock('./tokenLimits', () => ({
         tokenLimit: vi.fn(),
       }));
 
-      const mockGenerator: Partial<ContentGenerator> = {
+      client['contentGenerator'] = {
         countTokens: mockCountTokens,
-      };
-      client['contentGenerator'] = mockGenerator as ContentGenerator;
+      } as unknown as ContentGenerator;
 
-      // Mock the chat's sendMessage method
-      const mockChat: Partial<GeminiChat> = {
-        getHistory: vi
-          .fn()
-          .mockReturnValue([
-            { role: 'user', parts: [{ text: '...history...' }] },
-          ]),
+      client['chat'] = {
+        getHistory: mockGetHistory,
         addHistory: vi.fn(),
         setHistory: vi.fn(),
         sendMessage: mockSendMessage,
-      };
-      client['chat'] = mockChat as GeminiChat;
+      } as unknown as GeminiChat;
     });
 
     it('should not trigger summarization if token count is below threshold', async () => {
       const MOCKED_TOKEN_LIMIT = 1000;
       vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history...' }] },
+      ]);
 
       mockCountTokens.mockResolvedValue({
         totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7
@@ -515,6 +512,9 @@ describe('Gemini Client (client.ts)', () => {
     it('should trigger summarization if token count is at threshold', async () => {
       const MOCKED_TOKEN_LIMIT = 1000;
       vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history...' }] },
+      ]);
 
       const originalTokenCount = 1000 * 0.7;
       const newTokenCount = 100;
@@ -546,7 +546,69 @@ describe('Gemini Client (client.ts)', () => {
       expect(newChat).not.toBe(initialChat);
     });
 
+    it('should not compress across a function call response', async () => {
+      const MOCKED_TOKEN_LIMIT = 1000;
+      vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history 1...' }] },
+        { role: 'model', parts: [{ text: '...history 2...' }] },
+        { role: 'user', parts: [{ text: '...history 3...' }] },
+        { role: 'model', parts: [{ text: '...history 4...' }] },
+        { role: 'user', parts: [{ text: '...history 5...' }] },
+        { role: 'model', parts: [{ text: '...history 6...' }] },
+        { role: 'user', parts: [{ text: '...history 7...' }] },
+        { role: 'model', parts: [{ text: '...history 8...' }] },
+        // Normally we would break here, but we have a function response.
+        {
+          role: 'user',
+          parts: [{ functionResponse: { name: '...history 8...' } }],
+        },
+        { role: 'model', parts: [{ text: '...history 10...' }] },
+        // Instead we will break here.
+        { role: 'user', parts: [{ text: '...history 10...' }] },
+      ]);
+
+      const originalTokenCount = 1000 * 0.7;
+      const newTokenCount = 100;
+
+      mockCountTokens
+        .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check
+        .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history
+
+      // Mock the summary response from the chat
+      mockSendMessage.mockResolvedValue({
+        role: 'model',
+        parts: [{ text: 'This is a summary.' }],
+      });
+
+      const initialChat = client.getChat();
+      const result = await client.tryCompressChat('prompt-id-3');
+      const newChat = client.getChat();
+
+      expect(tokenLimit).toHaveBeenCalled();
+      expect(mockSendMessage).toHaveBeenCalled();
+
+      // Assert that summarization happened and returned the correct stats
+      expect(result).toEqual({
+        originalTokenCount,
+        newTokenCount,
+      });
+      // Assert that the chat was reset
+      expect(newChat).not.toBe(initialChat);
+
+      // 1. standard start context message
+      // 2. standard canned user start message
+      // 3. compressed summary message
+      // 4. standard canned user summary message
+      // 5. The last user message (not the last 3 because that would start with a function response)
+      expect(newChat.getHistory().length).toEqual(5);
+    });
+
     it('should always trigger summarization when force is true, regardless of token count', async () => {
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history...' }] },
+      ]);
+
       const originalTokenCount = 10; // Well below threshold
       const newTokenCount = 5;
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index ed903788..d8143d05 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -30,6 +30,7 @@ import { reportError } from '../utils/errorReporting.js';
 import { GeminiChat } from './geminiChat.js';
 import { retryWithBackoff } from '../utils/retry.js';
 import { getErrorMessage } from '../utils/errors.js';
+import { isFunctionResponse } from '../utils/messageInspectors.js';
 import { tokenLimit } from './tokenLimits.js';
 import {
   AuthType,
@@ -547,7 +548,8 @@ export class GeminiClient {
     // Find the first user message after the index. This is the start of the next turn.
     while (
       compressBeforeIndex < curatedHistory.length &&
-      curatedHistory[compressBeforeIndex]?.role !== 'user'
+      (curatedHistory[compressBeforeIndex]?.role === 'model' ||
+        isFunctionResponse(curatedHistory[compressBeforeIndex]))
     ) {
       compressBeforeIndex++;
     }
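
Note on `isFunctionResponse`: the helper is imported from `../utils/messageInspectors.js`, but that file is not part of this diff. For reviewers, here is a minimal sketch of what the predicate presumably checks, assuming the `Content` type from `@google/genai`; the actual implementation in the repo may differ.

```ts
import type { Content } from '@google/genai';

// Sketch (assumption): a "function response" turn is a user-role message
// whose parts all carry a functionResponse, i.e. tool output being fed
// back to the model rather than a genuine user prompt.
export function isFunctionResponse(content: Content): boolean {
  return (
    content.role === 'user' &&
    !!content.parts &&
    content.parts.every((part) => !!part.functionResponse)
  );
}
```

Under that reading, the updated `while` loop advances `compressBeforeIndex` past model turns and past user turns that are really tool results, so the retained tail of the history always begins at a genuine user message, which is exactly the behavior the new `should not compress across a function call response` test pins down.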