Introduce generate content interface (#755)

This commit is contained in:
Tommaso Sciortino 2025-06-05 13:26:38 -07:00 committed by GitHub
parent 2ebf2fbc82
commit e59e18251b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 45 additions and 42 deletions

View File

@ -32,10 +32,11 @@ import {
logApiResponse,
logApiError,
} from '../telemetry/index.js';
import { ContentGenerator } from './contentGenerator.js';
export class GeminiClient {
private chat: Promise<GeminiChat>;
private client: GoogleGenAI;
private contentGenerator: ContentGenerator;
private model: string;
private generateContentConfig: GenerateContentConfig = {
temperature: 0,
@ -48,7 +49,7 @@ export class GeminiClient {
const apiKeyFromConfig = config.getApiKey();
const vertexaiFlag = config.getVertexAI();
this.client = new GoogleGenAI({
const googleGenAI = new GoogleGenAI({
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
vertexai: vertexaiFlag,
httpOptions: {
@ -57,6 +58,7 @@ export class GeminiClient {
},
},
});
this.contentGenerator = googleGenAI.models;
this.model = config.getModel();
this.chat = this.startChat();
}
@ -148,8 +150,7 @@ export class GeminiClient {
const systemInstruction = getCoreSystemPrompt(userMemory);
return new GeminiChat(
this.client,
this.client.models,
this.contentGenerator,
this.model,
{
systemInstruction,
@ -285,7 +286,7 @@ export class GeminiClient {
let inputTokenCount = 0;
try {
const { totalTokens } = await this.client.models.countTokens({
const { totalTokens } = await this.contentGenerator.countTokens({
model,
contents,
});
@ -300,7 +301,7 @@ export class GeminiClient {
this._logApiRequest(model, inputTokenCount);
const apiCall = () =>
this.client.models.generateContent({
this.contentGenerator.generateContent({
model,
config: {
...requestConfig,
@ -400,7 +401,7 @@ export class GeminiClient {
let inputTokenCount = 0;
try {
const { totalTokens } = await this.client.models.countTokens({
const { totalTokens } = await this.contentGenerator.countTokens({
model: modelToUse,
contents,
});
@ -415,7 +416,7 @@ export class GeminiClient {
this._logApiRequest(modelToUse, inputTokenCount);
const apiCall = () =>
this.client.models.generateContent({
this.contentGenerator.generateContent({
model: modelToUse,
config: requestConfig,
contents,
@ -453,8 +454,7 @@ export class GeminiClient {
const chat = await this.chat;
const history = chat.getHistory(true); // Get curated history
// Count tokens for the curated history via the content generator
const { totalTokens } = await this.client.models.countTokens({
const { totalTokens } = await this.contentGenerator.countTokens({
model: this.model,
contents: history,
});

View File

@ -0,0 +1,27 @@
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */
import {
  CountTokensResponse,
  GenerateContentResponse,
  GenerateContentParameters,
  CountTokensParameters,
} from '@google/genai';
/**
 * Interface abstracting the core functionalities for generating content and counting tokens.
 *
 * Structurally satisfied by `GoogleGenAI.models` (GeminiClient assigns
 * `googleGenAI.models` to its `contentGenerator` field), and consumed by
 * GeminiChat in place of a concrete `Models` instance so alternative
 * backends can be substituted.
 */
export interface ContentGenerator {
  /** Produces a single, complete model response for the given request. */
  generateContent(
    request: GenerateContentParameters,
  ): Promise<GenerateContentResponse>;
  /**
   * Produces a streaming model response; resolves to an async generator
   * that yields incremental {@link GenerateContentResponse} chunks.
   */
  generateContentStream(
    request: GenerateContentParameters,
  ): Promise<AsyncGenerator<GenerateContentResponse>>;
  /** Counts the tokens the given request would consume, without generating. */
  countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
}

View File

@ -5,13 +5,7 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
Content,
GoogleGenAI,
Models,
GenerateContentConfig,
Part,
} from '@google/genai';
import { Content, Models, GenerateContentConfig, Part } from '@google/genai';
import { GeminiChat } from './geminiChat.js';
// Mocks
@ -23,10 +17,6 @@ const mockModelsModule = {
batchEmbedContents: vi.fn(),
} as unknown as Models;
const mockGoogleGenAI = {
getGenerativeModel: vi.fn().mockReturnValue(mockModelsModule),
} as unknown as GoogleGenAI;
describe('GeminiChat', () => {
let chat: GeminiChat;
const model = 'gemini-pro';
@ -35,7 +25,7 @@ describe('GeminiChat', () => {
beforeEach(() => {
vi.clearAllMocks();
// Reset history for each test by creating a new instance
chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, []);
chat = new GeminiChat(mockModelsModule, model, config, []);
});
afterEach(() => {
@ -129,19 +119,8 @@ describe('GeminiChat', () => {
// @ts-expect-error Accessing private method for testing purposes
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
// const history = chat.getHistory(); // Removed unused variable to satisfy linter
// The recordHistory will push the *new* userInput first, then the consolidated newModelOutput.
// However, the consolidation logic for *outputContents* itself should run, and then the merge with *existing* history.
// Let's adjust the test to reflect how recordHistory is used: it adds the current userInput, then the model's response to it.
// Reset and set up a more realistic scenario for merging with existing history
chat = new GeminiChat(
mockGoogleGenAI,
mockModelsModule,
model,
config,
[],
);
chat = new GeminiChat(mockModelsModule, model, config, []);
const firstUserInput: Content = {
role: 'user',
parts: [{ text: 'First user input' }],
@ -184,7 +163,7 @@ describe('GeminiChat', () => {
role: 'model',
parts: [{ text: 'Initial model answer.' }],
};
chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, [
chat = new GeminiChat(mockModelsModule, model, config, [
initialUser,
initialModel,
]);

View File

@ -10,15 +10,14 @@
import {
GenerateContentResponse,
Content,
Models,
GenerateContentConfig,
SendMessageParameters,
GoogleGenAI,
createUserContent,
Part,
} from '@google/genai';
import { retryWithBackoff } from '../utils/retry.js';
import { isFunctionResponse } from '../utils/messageInspectors.js';
import { ContentGenerator } from './contentGenerator.js';
/**
* Returns true if the response is valid, false otherwise.
@ -120,8 +119,7 @@ export class GeminiChat {
private sendPromise: Promise<void> = Promise.resolve();
constructor(
private readonly apiClient: GoogleGenAI,
private readonly modelsModule: Models,
private readonly contentGenerator: ContentGenerator,
private readonly model: string,
private readonly config: GenerateContentConfig = {},
private history: Content[] = [],
@ -156,7 +154,7 @@ export class GeminiChat {
const userContent = createUserContent(params.message);
const apiCall = () =>
this.modelsModule.generateContent({
this.contentGenerator.generateContent({
model: this.model,
contents: this.getHistory(true).concat(userContent),
config: { ...this.config, ...params.config },
@ -225,7 +223,7 @@ export class GeminiChat {
const userContent = createUserContent(params.message);
const apiCall = () =>
this.modelsModule.generateContentStream({
this.contentGenerator.generateContentStream({
model: this.model,
contents: this.getHistory(true).concat(userContent),
config: { ...this.config, ...params.config },

View File

@ -69,7 +69,6 @@ describe('checkNextSpeaker', () => {
// GeminiChat receives the mocked content generator directly via its constructor
chatInstance = new GeminiChat(
mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
'gemini-pro', // model name
{},