diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index bce2c5e4..d1a59eb1 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -32,10 +32,11 @@ import { logApiResponse, logApiError, } from '../telemetry/index.js'; +import { ContentGenerator } from './contentGenerator.js'; export class GeminiClient { private chat: Promise; - private client: GoogleGenAI; + private contentGenerator: ContentGenerator; private model: string; private generateContentConfig: GenerateContentConfig = { temperature: 0, @@ -48,7 +49,7 @@ export class GeminiClient { const apiKeyFromConfig = config.getApiKey(); const vertexaiFlag = config.getVertexAI(); - this.client = new GoogleGenAI({ + const googleGenAI = new GoogleGenAI({ apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig, vertexai: vertexaiFlag, httpOptions: { @@ -57,6 +58,7 @@ export class GeminiClient { }, }, }); + this.contentGenerator = googleGenAI.models; this.model = config.getModel(); this.chat = this.startChat(); } @@ -148,8 +150,7 @@ export class GeminiClient { const systemInstruction = getCoreSystemPrompt(userMemory); return new GeminiChat( - this.client, - this.client.models, + this.contentGenerator, this.model, { systemInstruction, @@ -285,7 +286,7 @@ export class GeminiClient { let inputTokenCount = 0; try { - const { totalTokens } = await this.client.models.countTokens({ + const { totalTokens } = await this.contentGenerator.countTokens({ model, contents, }); @@ -300,7 +301,7 @@ export class GeminiClient { this._logApiRequest(model, inputTokenCount); const apiCall = () => - this.client.models.generateContent({ + this.contentGenerator.generateContent({ model, config: { ...requestConfig, @@ -400,7 +401,7 @@ export class GeminiClient { let inputTokenCount = 0; try { - const { totalTokens } = await this.client.models.countTokens({ + const { totalTokens } = await this.contentGenerator.countTokens({ model: modelToUse, contents, }); @@ -415,7 +416,7 @@ export class GeminiClient { this._logApiRequest(modelToUse, inputTokenCount); const apiCall = () => - this.client.models.generateContent({ + this.contentGenerator.generateContent({ model: modelToUse, config: requestConfig, contents, @@ -453,8 +454,7 @@ export class GeminiClient { const chat = await this.chat; const history = chat.getHistory(true); // Get curated history - // Count tokens using the models module from the GoogleGenAI client instance - const { totalTokens } = await this.client.models.countTokens({ + const { totalTokens } = await this.contentGenerator.countTokens({ model: this.model, contents: history, }); diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts new file mode 100644 index 00000000..32b48c5c --- /dev/null +++ b/packages/core/src/core/contentGenerator.ts @@ -0,0 +1,27 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + CountTokensResponse, + GenerateContentResponse, + GenerateContentParameters, + CountTokensParameters, +} from '@google/genai'; + +/** + * Interface abstracting the core functionalities for generating content and counting tokens. + */ +export interface ContentGenerator { + generateContent( + request: GenerateContentParameters, + ): Promise; + + generateContentStream( + request: GenerateContentParameters, + ): Promise>; + + countTokens(request: CountTokensParameters): Promise; +} diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 3a6fb10c..6d18ebd9 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -5,13 +5,7 @@ */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { - Content, - GoogleGenAI, - Models, - GenerateContentConfig, - Part, -} from '@google/genai'; +import { Content, Models, GenerateContentConfig, Part } from '@google/genai'; import { GeminiChat } from './geminiChat.js'; // Mocks @@ -23,10 +17,6 @@ const mockModelsModule = { batchEmbedContents: vi.fn(), } as unknown as Models; -const mockGoogleGenAI = { - getGenerativeModel: vi.fn().mockReturnValue(mockModelsModule), -} as unknown as GoogleGenAI; - describe('GeminiChat', () => { let chat: GeminiChat; const model = 'gemini-pro'; @@ -35,7 +25,7 @@ describe('GeminiChat', () => { beforeEach(() => { vi.clearAllMocks(); // Reset history for each test by creating a new instance - chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, []); + chat = new GeminiChat(mockModelsModule, model, config, []); }); afterEach(() => { @@ -129,19 +119,8 @@ describe('GeminiChat', () => { // @ts-expect-error Accessing private method for testing purposes chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed - // const history = chat.getHistory(); // Removed unused variable to satisfy linter - // The recordHistory will push the *new* userInput first, then the consolidated newModelOutput. - // However, the consolidation logic for *outputContents* itself should run, and then the merge with *existing* history. - // Let's adjust the test to reflect how recordHistory is used: it adds the current userInput, then the model's response to it. - // Reset and set up a more realistic scenario for merging with existing history - chat = new GeminiChat( - mockGoogleGenAI, - mockModelsModule, - model, - config, - [], - ); + chat = new GeminiChat(mockModelsModule, model, config, []); const firstUserInput: Content = { role: 'user', parts: [{ text: 'First user input' }], @@ -184,7 +163,7 @@ describe('GeminiChat', () => { role: 'model', parts: [{ text: 'Initial model answer.' }], }; - chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, [ + chat = new GeminiChat(mockModelsModule, model, config, [ initialUser, initialModel, ]); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index b4844499..54f74102 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -10,15 +10,14 @@ import { GenerateContentResponse, Content, - Models, GenerateContentConfig, SendMessageParameters, - GoogleGenAI, createUserContent, Part, } from '@google/genai'; import { retryWithBackoff } from '../utils/retry.js'; import { isFunctionResponse } from '../utils/messageInspectors.js'; +import { ContentGenerator } from './contentGenerator.js'; /** * Returns true if the response is valid, false otherwise. @@ -120,8 +119,7 @@ export class GeminiChat { private sendPromise: Promise = Promise.resolve(); constructor( - private readonly apiClient: GoogleGenAI, - private readonly modelsModule: Models, + private readonly contentGenerator: ContentGenerator, private readonly model: string, private readonly config: GenerateContentConfig = {}, private history: Content[] = [], @@ -156,7 +154,7 @@ export class GeminiChat { const userContent = createUserContent(params.message); const apiCall = () => - this.modelsModule.generateContent({ + this.contentGenerator.generateContent({ model: this.model, contents: this.getHistory(true).concat(userContent), config: { ...this.config, ...params.config }, @@ -225,7 +223,7 @@ export class GeminiChat { const userContent = createUserContent(params.message); const apiCall = () => - this.modelsModule.generateContentStream({ + this.contentGenerator.generateContentStream({ model: this.model, contents: this.getHistory(true).concat(userContent), config: { ...this.config, ...params.config }, diff --git a/packages/core/src/utils/nextSpeakerChecker.test.ts b/packages/core/src/utils/nextSpeakerChecker.test.ts index 872e00f6..2514c99d 100644 --- a/packages/core/src/utils/nextSpeakerChecker.test.ts +++ b/packages/core/src/utils/nextSpeakerChecker.test.ts @@ -69,7 +69,6 @@ describe('checkNextSpeaker', () => { // GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor chatInstance = new GeminiChat( - mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel 'gemini-pro', // model name {},