Introduce generate content interface (#755)

This commit is contained in:
Tommaso Sciortino 2025-06-05 13:26:38 -07:00 committed by GitHub
parent 2ebf2fbc82
commit e59e18251b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 45 additions and 42 deletions

View File

@ -32,10 +32,11 @@ import {
logApiResponse, logApiResponse,
logApiError, logApiError,
} from '../telemetry/index.js'; } from '../telemetry/index.js';
import { ContentGenerator } from './contentGenerator.js';
export class GeminiClient { export class GeminiClient {
private chat: Promise<GeminiChat>; private chat: Promise<GeminiChat>;
private client: GoogleGenAI; private contentGenerator: ContentGenerator;
private model: string; private model: string;
private generateContentConfig: GenerateContentConfig = { private generateContentConfig: GenerateContentConfig = {
temperature: 0, temperature: 0,
@ -48,7 +49,7 @@ export class GeminiClient {
const apiKeyFromConfig = config.getApiKey(); const apiKeyFromConfig = config.getApiKey();
const vertexaiFlag = config.getVertexAI(); const vertexaiFlag = config.getVertexAI();
this.client = new GoogleGenAI({ const googleGenAI = new GoogleGenAI({
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig, apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
vertexai: vertexaiFlag, vertexai: vertexaiFlag,
httpOptions: { httpOptions: {
@ -57,6 +58,7 @@ export class GeminiClient {
}, },
}, },
}); });
this.contentGenerator = googleGenAI.models;
this.model = config.getModel(); this.model = config.getModel();
this.chat = this.startChat(); this.chat = this.startChat();
} }
@ -148,8 +150,7 @@ export class GeminiClient {
const systemInstruction = getCoreSystemPrompt(userMemory); const systemInstruction = getCoreSystemPrompt(userMemory);
return new GeminiChat( return new GeminiChat(
this.client, this.contentGenerator,
this.client.models,
this.model, this.model,
{ {
systemInstruction, systemInstruction,
@ -285,7 +286,7 @@ export class GeminiClient {
let inputTokenCount = 0; let inputTokenCount = 0;
try { try {
const { totalTokens } = await this.client.models.countTokens({ const { totalTokens } = await this.contentGenerator.countTokens({
model, model,
contents, contents,
}); });
@ -300,7 +301,7 @@ export class GeminiClient {
this._logApiRequest(model, inputTokenCount); this._logApiRequest(model, inputTokenCount);
const apiCall = () => const apiCall = () =>
this.client.models.generateContent({ this.contentGenerator.generateContent({
model, model,
config: { config: {
...requestConfig, ...requestConfig,
@ -400,7 +401,7 @@ export class GeminiClient {
let inputTokenCount = 0; let inputTokenCount = 0;
try { try {
const { totalTokens } = await this.client.models.countTokens({ const { totalTokens } = await this.contentGenerator.countTokens({
model: modelToUse, model: modelToUse,
contents, contents,
}); });
@ -415,7 +416,7 @@ export class GeminiClient {
this._logApiRequest(modelToUse, inputTokenCount); this._logApiRequest(modelToUse, inputTokenCount);
const apiCall = () => const apiCall = () =>
this.client.models.generateContent({ this.contentGenerator.generateContent({
model: modelToUse, model: modelToUse,
config: requestConfig, config: requestConfig,
contents, contents,
@ -453,8 +454,7 @@ export class GeminiClient {
const chat = await this.chat; const chat = await this.chat;
const history = chat.getHistory(true); // Get curated history const history = chat.getHistory(true); // Get curated history
// Count tokens using the models module from the GoogleGenAI client instance const { totalTokens } = await this.contentGenerator.countTokens({
const { totalTokens } = await this.client.models.countTokens({
model: this.model, model: this.model,
contents: history, contents: history,
}); });

View File

@ -0,0 +1,27 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
CountTokensResponse,
GenerateContentResponse,
GenerateContentParameters,
CountTokensParameters,
} from '@google/genai';
/**
* Interface abstracting the core functionalities for generating content and counting tokens.
*/
export interface ContentGenerator {
  /**
   * Generates a single, complete model response for the given request.
   *
   * @param request - Model name, contents, and generation config.
   * @returns A promise resolving to the full {@link GenerateContentResponse}.
   */
  generateContent(
    request: GenerateContentParameters,
  ): Promise<GenerateContentResponse>;

  /**
   * Generates a model response as a stream of incremental chunks.
   *
   * Note the double wrapping: the outer promise resolves once the stream is
   * established; the async generator then yields partial
   * {@link GenerateContentResponse} chunks as they arrive.
   *
   * @param request - Model name, contents, and generation config.
   * @returns A promise resolving to an async generator of response chunks.
   */
  generateContentStream(
    request: GenerateContentParameters,
  ): Promise<AsyncGenerator<GenerateContentResponse>>;

  /**
   * Counts the tokens that the given request's contents would consume.
   *
   * @param request - Model name and contents to tokenize.
   * @returns A promise resolving to a {@link CountTokensResponse}
   *   (e.g. its `totalTokens` field, as used by callers in this commit).
   */
  countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
}

View File

@ -5,13 +5,7 @@
*/ */
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { import { Content, Models, GenerateContentConfig, Part } from '@google/genai';
Content,
GoogleGenAI,
Models,
GenerateContentConfig,
Part,
} from '@google/genai';
import { GeminiChat } from './geminiChat.js'; import { GeminiChat } from './geminiChat.js';
// Mocks // Mocks
@ -23,10 +17,6 @@ const mockModelsModule = {
batchEmbedContents: vi.fn(), batchEmbedContents: vi.fn(),
} as unknown as Models; } as unknown as Models;
const mockGoogleGenAI = {
getGenerativeModel: vi.fn().mockReturnValue(mockModelsModule),
} as unknown as GoogleGenAI;
describe('GeminiChat', () => { describe('GeminiChat', () => {
let chat: GeminiChat; let chat: GeminiChat;
const model = 'gemini-pro'; const model = 'gemini-pro';
@ -35,7 +25,7 @@ describe('GeminiChat', () => {
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks(); vi.clearAllMocks();
// Reset history for each test by creating a new instance // Reset history for each test by creating a new instance
chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, []); chat = new GeminiChat(mockModelsModule, model, config, []);
}); });
afterEach(() => { afterEach(() => {
@ -129,19 +119,8 @@ describe('GeminiChat', () => {
// @ts-expect-error Accessing private method for testing purposes // @ts-expect-error Accessing private method for testing purposes
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
// const history = chat.getHistory(); // Removed unused variable to satisfy linter
// The recordHistory will push the *new* userInput first, then the consolidated newModelOutput.
// However, the consolidation logic for *outputContents* itself should run, and then the merge with *existing* history.
// Let's adjust the test to reflect how recordHistory is used: it adds the current userInput, then the model's response to it.
// Reset and set up a more realistic scenario for merging with existing history // Reset and set up a more realistic scenario for merging with existing history
chat = new GeminiChat( chat = new GeminiChat(mockModelsModule, model, config, []);
mockGoogleGenAI,
mockModelsModule,
model,
config,
[],
);
const firstUserInput: Content = { const firstUserInput: Content = {
role: 'user', role: 'user',
parts: [{ text: 'First user input' }], parts: [{ text: 'First user input' }],
@ -184,7 +163,7 @@ describe('GeminiChat', () => {
role: 'model', role: 'model',
parts: [{ text: 'Initial model answer.' }], parts: [{ text: 'Initial model answer.' }],
}; };
chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, [ chat = new GeminiChat(mockModelsModule, model, config, [
initialUser, initialUser,
initialModel, initialModel,
]); ]);

View File

@ -10,15 +10,14 @@
import { import {
GenerateContentResponse, GenerateContentResponse,
Content, Content,
Models,
GenerateContentConfig, GenerateContentConfig,
SendMessageParameters, SendMessageParameters,
GoogleGenAI,
createUserContent, createUserContent,
Part, Part,
} from '@google/genai'; } from '@google/genai';
import { retryWithBackoff } from '../utils/retry.js'; import { retryWithBackoff } from '../utils/retry.js';
import { isFunctionResponse } from '../utils/messageInspectors.js'; import { isFunctionResponse } from '../utils/messageInspectors.js';
import { ContentGenerator } from './contentGenerator.js';
/** /**
* Returns true if the response is valid, false otherwise. * Returns true if the response is valid, false otherwise.
@ -120,8 +119,7 @@ export class GeminiChat {
private sendPromise: Promise<void> = Promise.resolve(); private sendPromise: Promise<void> = Promise.resolve();
constructor( constructor(
private readonly apiClient: GoogleGenAI, private readonly contentGenerator: ContentGenerator,
private readonly modelsModule: Models,
private readonly model: string, private readonly model: string,
private readonly config: GenerateContentConfig = {}, private readonly config: GenerateContentConfig = {},
private history: Content[] = [], private history: Content[] = [],
@ -156,7 +154,7 @@ export class GeminiChat {
const userContent = createUserContent(params.message); const userContent = createUserContent(params.message);
const apiCall = () => const apiCall = () =>
this.modelsModule.generateContent({ this.contentGenerator.generateContent({
model: this.model, model: this.model,
contents: this.getHistory(true).concat(userContent), contents: this.getHistory(true).concat(userContent),
config: { ...this.config, ...params.config }, config: { ...this.config, ...params.config },
@ -225,7 +223,7 @@ export class GeminiChat {
const userContent = createUserContent(params.message); const userContent = createUserContent(params.message);
const apiCall = () => const apiCall = () =>
this.modelsModule.generateContentStream({ this.contentGenerator.generateContentStream({
model: this.model, model: this.model,
contents: this.getHistory(true).concat(userContent), contents: this.getHistory(true).concat(userContent),
config: { ...this.config, ...params.config }, config: { ...this.config, ...params.config },

View File

@ -69,7 +69,6 @@ describe('checkNextSpeaker', () => {
// GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor // GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
chatInstance = new GeminiChat( chatInstance = new GeminiChat(
mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
'gemini-pro', // model name 'gemini-pro', // model name
{}, {},