Introduce generate content interface (#755)
This commit is contained in:
parent
2ebf2fbc82
commit
e59e18251b
|
@ -32,10 +32,11 @@ import {
|
|||
logApiResponse,
|
||||
logApiError,
|
||||
} from '../telemetry/index.js';
|
||||
import { ContentGenerator } from './contentGenerator.js';
|
||||
|
||||
export class GeminiClient {
|
||||
private chat: Promise<GeminiChat>;
|
||||
private client: GoogleGenAI;
|
||||
private contentGenerator: ContentGenerator;
|
||||
private model: string;
|
||||
private generateContentConfig: GenerateContentConfig = {
|
||||
temperature: 0,
|
||||
|
@ -48,7 +49,7 @@ export class GeminiClient {
|
|||
const apiKeyFromConfig = config.getApiKey();
|
||||
const vertexaiFlag = config.getVertexAI();
|
||||
|
||||
this.client = new GoogleGenAI({
|
||||
const googleGenAI = new GoogleGenAI({
|
||||
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
|
||||
vertexai: vertexaiFlag,
|
||||
httpOptions: {
|
||||
|
@ -57,6 +58,7 @@ export class GeminiClient {
|
|||
},
|
||||
},
|
||||
});
|
||||
this.contentGenerator = googleGenAI.models;
|
||||
this.model = config.getModel();
|
||||
this.chat = this.startChat();
|
||||
}
|
||||
|
@ -148,8 +150,7 @@ export class GeminiClient {
|
|||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||
|
||||
return new GeminiChat(
|
||||
this.client,
|
||||
this.client.models,
|
||||
this.contentGenerator,
|
||||
this.model,
|
||||
{
|
||||
systemInstruction,
|
||||
|
@ -285,7 +286,7 @@ export class GeminiClient {
|
|||
|
||||
let inputTokenCount = 0;
|
||||
try {
|
||||
const { totalTokens } = await this.client.models.countTokens({
|
||||
const { totalTokens } = await this.contentGenerator.countTokens({
|
||||
model,
|
||||
contents,
|
||||
});
|
||||
|
@ -300,7 +301,7 @@ export class GeminiClient {
|
|||
this._logApiRequest(model, inputTokenCount);
|
||||
|
||||
const apiCall = () =>
|
||||
this.client.models.generateContent({
|
||||
this.contentGenerator.generateContent({
|
||||
model,
|
||||
config: {
|
||||
...requestConfig,
|
||||
|
@ -400,7 +401,7 @@ export class GeminiClient {
|
|||
|
||||
let inputTokenCount = 0;
|
||||
try {
|
||||
const { totalTokens } = await this.client.models.countTokens({
|
||||
const { totalTokens } = await this.contentGenerator.countTokens({
|
||||
model: modelToUse,
|
||||
contents,
|
||||
});
|
||||
|
@ -415,7 +416,7 @@ export class GeminiClient {
|
|||
this._logApiRequest(modelToUse, inputTokenCount);
|
||||
|
||||
const apiCall = () =>
|
||||
this.client.models.generateContent({
|
||||
this.contentGenerator.generateContent({
|
||||
model: modelToUse,
|
||||
config: requestConfig,
|
||||
contents,
|
||||
|
@ -453,8 +454,7 @@ export class GeminiClient {
|
|||
const chat = await this.chat;
|
||||
const history = chat.getHistory(true); // Get curated history
|
||||
|
||||
// Count tokens using the models module from the GoogleGenAI client instance
|
||||
const { totalTokens } = await this.client.models.countTokens({
|
||||
const { totalTokens } = await this.contentGenerator.countTokens({
|
||||
model: this.model,
|
||||
contents: history,
|
||||
});
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
CountTokensResponse,
|
||||
GenerateContentResponse,
|
||||
GenerateContentParameters,
|
||||
CountTokensParameters,
|
||||
} from '@google/genai';
|
||||
|
||||
/**
|
||||
* Interface abstracting the core functionalities for generating content and counting tokens.
|
||||
*/
|
||||
export interface ContentGenerator {
|
||||
generateContent(
|
||||
request: GenerateContentParameters,
|
||||
): Promise<GenerateContentResponse>;
|
||||
|
||||
generateContentStream(
|
||||
request: GenerateContentParameters,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>>;
|
||||
|
||||
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
|
||||
}
|
|
@ -5,13 +5,7 @@
|
|||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
Content,
|
||||
GoogleGenAI,
|
||||
Models,
|
||||
GenerateContentConfig,
|
||||
Part,
|
||||
} from '@google/genai';
|
||||
import { Content, Models, GenerateContentConfig, Part } from '@google/genai';
|
||||
import { GeminiChat } from './geminiChat.js';
|
||||
|
||||
// Mocks
|
||||
|
@ -23,10 +17,6 @@ const mockModelsModule = {
|
|||
batchEmbedContents: vi.fn(),
|
||||
} as unknown as Models;
|
||||
|
||||
const mockGoogleGenAI = {
|
||||
getGenerativeModel: vi.fn().mockReturnValue(mockModelsModule),
|
||||
} as unknown as GoogleGenAI;
|
||||
|
||||
describe('GeminiChat', () => {
|
||||
let chat: GeminiChat;
|
||||
const model = 'gemini-pro';
|
||||
|
@ -35,7 +25,7 @@ describe('GeminiChat', () => {
|
|||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
// Reset history for each test by creating a new instance
|
||||
chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, []);
|
||||
chat = new GeminiChat(mockModelsModule, model, config, []);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
@ -129,19 +119,8 @@ describe('GeminiChat', () => {
|
|||
// @ts-expect-error Accessing private method for testing purposes
|
||||
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
|
||||
|
||||
// const history = chat.getHistory(); // Removed unused variable to satisfy linter
|
||||
// The recordHistory will push the *new* userInput first, then the consolidated newModelOutput.
|
||||
// However, the consolidation logic for *outputContents* itself should run, and then the merge with *existing* history.
|
||||
// Let's adjust the test to reflect how recordHistory is used: it adds the current userInput, then the model's response to it.
|
||||
|
||||
// Reset and set up a more realistic scenario for merging with existing history
|
||||
chat = new GeminiChat(
|
||||
mockGoogleGenAI,
|
||||
mockModelsModule,
|
||||
model,
|
||||
config,
|
||||
[],
|
||||
);
|
||||
chat = new GeminiChat(mockModelsModule, model, config, []);
|
||||
const firstUserInput: Content = {
|
||||
role: 'user',
|
||||
parts: [{ text: 'First user input' }],
|
||||
|
@ -184,7 +163,7 @@ describe('GeminiChat', () => {
|
|||
role: 'model',
|
||||
parts: [{ text: 'Initial model answer.' }],
|
||||
};
|
||||
chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, [
|
||||
chat = new GeminiChat(mockModelsModule, model, config, [
|
||||
initialUser,
|
||||
initialModel,
|
||||
]);
|
||||
|
|
|
@ -10,15 +10,14 @@
|
|||
import {
|
||||
GenerateContentResponse,
|
||||
Content,
|
||||
Models,
|
||||
GenerateContentConfig,
|
||||
SendMessageParameters,
|
||||
GoogleGenAI,
|
||||
createUserContent,
|
||||
Part,
|
||||
} from '@google/genai';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { ContentGenerator } from './contentGenerator.js';
|
||||
|
||||
/**
|
||||
* Returns true if the response is valid, false otherwise.
|
||||
|
@ -120,8 +119,7 @@ export class GeminiChat {
|
|||
private sendPromise: Promise<void> = Promise.resolve();
|
||||
|
||||
constructor(
|
||||
private readonly apiClient: GoogleGenAI,
|
||||
private readonly modelsModule: Models,
|
||||
private readonly contentGenerator: ContentGenerator,
|
||||
private readonly model: string,
|
||||
private readonly config: GenerateContentConfig = {},
|
||||
private history: Content[] = [],
|
||||
|
@ -156,7 +154,7 @@ export class GeminiChat {
|
|||
const userContent = createUserContent(params.message);
|
||||
|
||||
const apiCall = () =>
|
||||
this.modelsModule.generateContent({
|
||||
this.contentGenerator.generateContent({
|
||||
model: this.model,
|
||||
contents: this.getHistory(true).concat(userContent),
|
||||
config: { ...this.config, ...params.config },
|
||||
|
@ -225,7 +223,7 @@ export class GeminiChat {
|
|||
const userContent = createUserContent(params.message);
|
||||
|
||||
const apiCall = () =>
|
||||
this.modelsModule.generateContentStream({
|
||||
this.contentGenerator.generateContentStream({
|
||||
model: this.model,
|
||||
contents: this.getHistory(true).concat(userContent),
|
||||
config: { ...this.config, ...params.config },
|
||||
|
|
|
@ -69,7 +69,6 @@ describe('checkNextSpeaker', () => {
|
|||
|
||||
// GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
|
||||
chatInstance = new GeminiChat(
|
||||
mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
|
||||
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
|
||||
'gemini-pro', // model name
|
||||
{},
|
||||
|
|
Loading…
Reference in New Issue