From 9537ff476219486574fb6a50e54389a78beefe8e Mon Sep 17 00:00:00 2001 From: Taylor Mullen Date: Thu, 29 May 2025 23:53:35 -0700 Subject: [PATCH] feat(server): consolidate adjacent model content in chat history - Consolidates consecutive model messages into a single message in the chat history. - This prevents multiple model messages from being displayed in a row, improving readability. - This may also address some instances of 500 errors that could have been caused by multiple, rapidly succeeding model messages. - Adds tests for the new consolidation logic. Fixes https://b.corp.google.com/issues/421010429 --- packages/server/src/core/geminiChat.test.ts | 282 ++++++++++++++++++++ packages/server/src/core/geminiChat.ts | 40 ++- 2 files changed, 321 insertions(+), 1 deletion(-) create mode 100644 packages/server/src/core/geminiChat.test.ts diff --git a/packages/server/src/core/geminiChat.test.ts b/packages/server/src/core/geminiChat.test.ts new file mode 100644 index 00000000..11e222c9 --- /dev/null +++ b/packages/server/src/core/geminiChat.test.ts @@ -0,0 +1,282 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + Content, + GoogleGenAI, + Models, + GenerateContentConfig, + Part, +} from '@google/genai'; +import { GeminiChat } from './geminiChat.js'; + +// Mocks +const mockModelsModule = { + generateContent: vi.fn(), + generateContentStream: vi.fn(), + countTokens: vi.fn(), + embedContent: vi.fn(), + batchEmbedContents: vi.fn(), +} as unknown as Models; + +const mockGoogleGenAI = { + getGenerativeModel: vi.fn().mockReturnValue(mockModelsModule), +} as unknown as GoogleGenAI; + +describe('GeminiChat', () => { + let chat: GeminiChat; + const model = 'gemini-pro'; + const config: GenerateContentConfig = {}; + + beforeEach(() => { + vi.clearAllMocks(); + // Reset history for each test by creating a new instance + chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, []); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('recordHistory', () => { + const userInput: Content = { + role: 'user', + parts: [{ text: 'User input' }], + }; + + it('should add user input and a single model output to history', () => { + const modelOutput: Content[] = [ + { role: 'model', parts: [{ text: 'Model output' }] }, + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(userInput, modelOutput); + const history = chat.getHistory(); + expect(history).toEqual([userInput, modelOutput[0]]); + }); + + it('should consolidate adjacent model outputs', () => { + const modelOutputParts: Content[] = [ + { role: 'model', parts: [{ text: 'Model part 1' }] }, + { role: 'model', parts: [{ text: 'Model part 2' }] }, + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(userInput, modelOutputParts); + const history = chat.getHistory(); + expect(history.length).toBe(2); + expect(history[0]).toEqual(userInput); + expect(history[1].role).toBe('model'); + expect(history[1].parts).toEqual([ + { text: 'Model part 1' }, + { text: 'Model part 2' }, + ]); + }); + + it('should handle a mix of user and model roles in outputContents (though unusual)', () => { + const mixedOutput: Content[] = [ + { role: 'model', parts: [{ text: 'Model 1' }] }, + { role: 'user', parts: [{ text: 'Unexpected User' }] }, // This should be pushed as is + { role: 'model', parts: [{ text: 'Model 2' }] }, + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(userInput, mixedOutput); + const history = chat.getHistory(); + expect(history.length).toBe(4); // user, model1, user_unexpected, model2 + expect(history[0]).toEqual(userInput); + expect(history[1]).toEqual(mixedOutput[0]); + expect(history[2]).toEqual(mixedOutput[1]); + expect(history[3]).toEqual(mixedOutput[2]); + }); + + it('should consolidate multiple adjacent model outputs correctly', () => { + const modelOutputParts: Content[] = [ + { role: 'model', parts: [{ text: 'M1' }] }, + { role: 'model', parts: [{ text: 'M2' }] }, + { role: 'model', parts: [{ text: 'M3' }] }, + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(userInput, modelOutputParts); + const history = chat.getHistory(); + expect(history.length).toBe(2); + expect(history[1].parts).toEqual([ + { text: 'M1' }, + { text: 'M2' }, + { text: 'M3' }, + ]); + }); + + it('should not consolidate if roles are different between model outputs', () => { + const modelOutputParts: Content[] = [ + { role: 'model', parts: [{ text: 'M1' }] }, + { role: 'user', parts: [{ text: 'Interjecting User' }] }, + { role: 'model', parts: [{ text: 'M2' }] }, + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(userInput, modelOutputParts); + const history = chat.getHistory(); + expect(history.length).toBe(4); // user, M1, Interjecting User, M2 + expect(history[1].parts).toEqual([{ text: 'M1' }]); + expect(history[3].parts).toEqual([{ text: 'M2' }]); + }); + + it('should merge with last history entry if it is also a model output', () => { + // @ts-expect-error Accessing private property for test setup + chat.history = [ + userInput, + { role: 'model', parts: [{ text: 'Initial Model Output' }] }, + ]; // Prime the history + + const newModelOutput: Content[] = [ + { role: 'model', parts: [{ text: 'New Model Part 1' }] }, + { role: 'model', parts: [{ text: 'New Model Part 2' }] }, + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed + + // const history = chat.getHistory(); // Removed unused variable to satisfy linter + // The recordHistory will push the *new* userInput first, then the consolidated newModelOutput. + // However, the consolidation logic for *outputContents* itself should run, and then the merge with *existing* history. + // Let's adjust the test to reflect how recordHistory is used: it adds the current userInput, then the model's response to it. + + // Reset and set up a more realistic scenario for merging with existing history + chat = new GeminiChat( + mockGoogleGenAI, + mockModelsModule, + model, + config, + [], + ); + const firstUserInput: Content = { + role: 'user', + parts: [{ text: 'First user input' }], + }; + const firstModelOutput: Content[] = [ + { role: 'model', parts: [{ text: 'First model response' }] }, + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(firstUserInput, firstModelOutput); + + const secondUserInput: Content = { + role: 'user', + parts: [{ text: 'Second user input' }], + }; + const secondModelOutput: Content[] = [ + { role: 'model', parts: [{ text: 'Second model response part 1' }] }, + { role: 'model', parts: [{ text: 'Second model response part 2' }] }, + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(secondUserInput, secondModelOutput); + + const finalHistory = chat.getHistory(); + expect(finalHistory.length).toBe(4); // user1, model1, user2, model2(consolidated) + expect(finalHistory[0]).toEqual(firstUserInput); + expect(finalHistory[1]).toEqual(firstModelOutput[0]); + expect(finalHistory[2]).toEqual(secondUserInput); + expect(finalHistory[3].role).toBe('model'); + expect(finalHistory[3].parts).toEqual([ + { text: 'Second model response part 1' }, + { text: 'Second model response part 2' }, + ]); + }); + + it('should correctly merge consolidated new output with existing model history', () => { + // Setup: history ends with a model turn + const initialUser: Content = { + role: 'user', + parts: [{ text: 'Initial user query' }], + }; + const initialModel: Content = { + role: 'model', + parts: [{ text: 'Initial model answer.' }], + }; + chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, [ + initialUser, + initialModel, + ]); + + // New interaction + const currentUserInput: Content = { + role: 'user', + parts: [{ text: 'Follow-up question' }], + }; + const newModelParts: Content[] = [ + { role: 'model', parts: [{ text: 'Part A of new answer.' }] }, + { role: 'model', parts: [{ text: 'Part B of new answer.' }] }, + ]; + + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(currentUserInput, newModelParts); + const history = chat.getHistory(); + + // Expected: initialUser, initialModel, currentUserInput, consolidatedNewModelParts + expect(history.length).toBe(4); + expect(history[0]).toEqual(initialUser); + expect(history[1]).toEqual(initialModel); + expect(history[2]).toEqual(currentUserInput); + expect(history[3].role).toBe('model'); + expect(history[3].parts).toEqual([ + { text: 'Part A of new answer.' }, + { text: 'Part B of new answer.' }, + ]); + }); + + it('should handle empty modelOutput array', () => { + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(userInput, []); + const history = chat.getHistory(); + // If modelOutput is empty, it might push a default empty model part depending on isFunctionResponse + // Assuming isFunctionResponse(userInput) is false for this simple text input + expect(history.length).toBe(2); + expect(history[0]).toEqual(userInput); + expect(history[1].role).toBe('model'); + expect(history[1].parts).toEqual([]); + }); + + it('should handle modelOutput with parts being undefined or empty (if they pass initial every check)', () => { + const modelOutputUndefinedParts: Content[] = [ + { role: 'model', parts: [{ text: 'Text part' }] }, + { role: 'model', parts: undefined as unknown as Part[] }, // Test undefined parts + { role: 'model', parts: [] }, // Test empty parts array + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(userInput, modelOutputUndefinedParts); + const history = chat.getHistory(); + expect(history.length).toBe(2); + expect(history[1].role).toBe('model'); + // The consolidation logic should handle undefined/empty parts by spreading `|| []` + expect(history[1].parts).toEqual([{ text: 'Text part' }]); + }); + + it('should correctly handle automaticFunctionCallingHistory', () => { + const afcHistory: Content[] = [ + { role: 'user', parts: [{ text: 'AFC User' }] }, + { role: 'model', parts: [{ text: 'AFC Model' }] }, + ]; + const modelOutput: Content[] = [ + { role: 'model', parts: [{ text: 'Regular Model Output' }] }, + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(userInput, modelOutput, afcHistory); + const history = chat.getHistory(); + expect(history.length).toBe(3); + expect(history[0]).toEqual(afcHistory[0]); + expect(history[1]).toEqual(afcHistory[1]); + expect(history[2]).toEqual(modelOutput[0]); + }); + + it('should add userInput if AFC history is present but empty', () => { + const modelOutput: Content[] = [ + { role: 'model', parts: [{ text: 'Model Output' }] }, + ]; + // @ts-expect-error Accessing private method for testing purposes + chat.recordHistory(userInput, modelOutput, []); // Empty AFC history + const history = chat.getHistory(); + expect(history.length).toBe(2); + expect(history[0]).toEqual(userInput); + expect(history[1]).toEqual(modelOutput[0]); + }); + }); +}); diff --git a/packages/server/src/core/geminiChat.ts b/packages/server/src/core/geminiChat.ts index 5ba8ce2d..877d0825 100644 --- a/packages/server/src/core/geminiChat.ts +++ b/packages/server/src/core/geminiChat.ts @@ -313,6 +313,44 @@ export class GeminiChat { } else { this.history.push(userInput); } - this.history.push(...outputContents); + + // Consolidate adjacent model roles in outputContents + const consolidatedOutputContents: Content[] = []; + for (const content of outputContents) { + const lastContent = + consolidatedOutputContents[consolidatedOutputContents.length - 1]; + if ( + lastContent && + lastContent.role === 'model' && + content.role === 'model' && + lastContent.parts + ) { + lastContent.parts.push(...(content.parts || [])); + } else { + consolidatedOutputContents.push(content); + } + } + + if (consolidatedOutputContents.length > 0) { + const lastHistoryEntry = this.history[this.history.length - 1]; + // Only merge if AFC history was NOT just added, to prevent merging with last AFC model turn. + const canMergeWithLastHistory = + !automaticFunctionCallingHistory || + automaticFunctionCallingHistory.length === 0; + + if ( + canMergeWithLastHistory && + lastHistoryEntry && + lastHistoryEntry.role === 'model' && + lastHistoryEntry.parts && + consolidatedOutputContents[0].role === 'model' + ) { + lastHistoryEntry.parts.push( + ...(consolidatedOutputContents[0].parts || []), + ); + consolidatedOutputContents.shift(); // Remove the first element as it's merged + } + this.history.push(...consolidatedOutputContents); + } } }