Introduce generate content interface (#755)
This commit is contained in:
parent
2ebf2fbc82
commit
e59e18251b
|
@ -32,10 +32,11 @@ import {
|
||||||
logApiResponse,
|
logApiResponse,
|
||||||
logApiError,
|
logApiError,
|
||||||
} from '../telemetry/index.js';
|
} from '../telemetry/index.js';
|
||||||
|
import { ContentGenerator } from './contentGenerator.js';
|
||||||
|
|
||||||
export class GeminiClient {
|
export class GeminiClient {
|
||||||
private chat: Promise<GeminiChat>;
|
private chat: Promise<GeminiChat>;
|
||||||
private client: GoogleGenAI;
|
private contentGenerator: ContentGenerator;
|
||||||
private model: string;
|
private model: string;
|
||||||
private generateContentConfig: GenerateContentConfig = {
|
private generateContentConfig: GenerateContentConfig = {
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
|
@ -48,7 +49,7 @@ export class GeminiClient {
|
||||||
const apiKeyFromConfig = config.getApiKey();
|
const apiKeyFromConfig = config.getApiKey();
|
||||||
const vertexaiFlag = config.getVertexAI();
|
const vertexaiFlag = config.getVertexAI();
|
||||||
|
|
||||||
this.client = new GoogleGenAI({
|
const googleGenAI = new GoogleGenAI({
|
||||||
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
|
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
|
||||||
vertexai: vertexaiFlag,
|
vertexai: vertexaiFlag,
|
||||||
httpOptions: {
|
httpOptions: {
|
||||||
|
@ -57,6 +58,7 @@ export class GeminiClient {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
this.contentGenerator = googleGenAI.models;
|
||||||
this.model = config.getModel();
|
this.model = config.getModel();
|
||||||
this.chat = this.startChat();
|
this.chat = this.startChat();
|
||||||
}
|
}
|
||||||
|
@ -148,8 +150,7 @@ export class GeminiClient {
|
||||||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||||
|
|
||||||
return new GeminiChat(
|
return new GeminiChat(
|
||||||
this.client,
|
this.contentGenerator,
|
||||||
this.client.models,
|
|
||||||
this.model,
|
this.model,
|
||||||
{
|
{
|
||||||
systemInstruction,
|
systemInstruction,
|
||||||
|
@ -285,7 +286,7 @@ export class GeminiClient {
|
||||||
|
|
||||||
let inputTokenCount = 0;
|
let inputTokenCount = 0;
|
||||||
try {
|
try {
|
||||||
const { totalTokens } = await this.client.models.countTokens({
|
const { totalTokens } = await this.contentGenerator.countTokens({
|
||||||
model,
|
model,
|
||||||
contents,
|
contents,
|
||||||
});
|
});
|
||||||
|
@ -300,7 +301,7 @@ export class GeminiClient {
|
||||||
this._logApiRequest(model, inputTokenCount);
|
this._logApiRequest(model, inputTokenCount);
|
||||||
|
|
||||||
const apiCall = () =>
|
const apiCall = () =>
|
||||||
this.client.models.generateContent({
|
this.contentGenerator.generateContent({
|
||||||
model,
|
model,
|
||||||
config: {
|
config: {
|
||||||
...requestConfig,
|
...requestConfig,
|
||||||
|
@ -400,7 +401,7 @@ export class GeminiClient {
|
||||||
|
|
||||||
let inputTokenCount = 0;
|
let inputTokenCount = 0;
|
||||||
try {
|
try {
|
||||||
const { totalTokens } = await this.client.models.countTokens({
|
const { totalTokens } = await this.contentGenerator.countTokens({
|
||||||
model: modelToUse,
|
model: modelToUse,
|
||||||
contents,
|
contents,
|
||||||
});
|
});
|
||||||
|
@ -415,7 +416,7 @@ export class GeminiClient {
|
||||||
this._logApiRequest(modelToUse, inputTokenCount);
|
this._logApiRequest(modelToUse, inputTokenCount);
|
||||||
|
|
||||||
const apiCall = () =>
|
const apiCall = () =>
|
||||||
this.client.models.generateContent({
|
this.contentGenerator.generateContent({
|
||||||
model: modelToUse,
|
model: modelToUse,
|
||||||
config: requestConfig,
|
config: requestConfig,
|
||||||
contents,
|
contents,
|
||||||
|
@ -453,8 +454,7 @@ export class GeminiClient {
|
||||||
const chat = await this.chat;
|
const chat = await this.chat;
|
||||||
const history = chat.getHistory(true); // Get curated history
|
const history = chat.getHistory(true); // Get curated history
|
||||||
|
|
||||||
// Count tokens using the models module from the GoogleGenAI client instance
|
const { totalTokens } = await this.contentGenerator.countTokens({
|
||||||
const { totalTokens } = await this.client.models.countTokens({
|
|
||||||
model: this.model,
|
model: this.model,
|
||||||
contents: history,
|
contents: history,
|
||||||
});
|
});
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {
|
||||||
|
CountTokensResponse,
|
||||||
|
GenerateContentResponse,
|
||||||
|
GenerateContentParameters,
|
||||||
|
CountTokensParameters,
|
||||||
|
} from '@google/genai';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface abstracting the core functionalities for generating content and counting tokens.
|
||||||
|
*/
|
||||||
|
export interface ContentGenerator {
|
||||||
|
generateContent(
|
||||||
|
request: GenerateContentParameters,
|
||||||
|
): Promise<GenerateContentResponse>;
|
||||||
|
|
||||||
|
generateContentStream(
|
||||||
|
request: GenerateContentParameters,
|
||||||
|
): Promise<AsyncGenerator<GenerateContentResponse>>;
|
||||||
|
|
||||||
|
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
|
||||||
|
}
|
|
@ -5,13 +5,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||||
import {
|
import { Content, Models, GenerateContentConfig, Part } from '@google/genai';
|
||||||
Content,
|
|
||||||
GoogleGenAI,
|
|
||||||
Models,
|
|
||||||
GenerateContentConfig,
|
|
||||||
Part,
|
|
||||||
} from '@google/genai';
|
|
||||||
import { GeminiChat } from './geminiChat.js';
|
import { GeminiChat } from './geminiChat.js';
|
||||||
|
|
||||||
// Mocks
|
// Mocks
|
||||||
|
@ -23,10 +17,6 @@ const mockModelsModule = {
|
||||||
batchEmbedContents: vi.fn(),
|
batchEmbedContents: vi.fn(),
|
||||||
} as unknown as Models;
|
} as unknown as Models;
|
||||||
|
|
||||||
const mockGoogleGenAI = {
|
|
||||||
getGenerativeModel: vi.fn().mockReturnValue(mockModelsModule),
|
|
||||||
} as unknown as GoogleGenAI;
|
|
||||||
|
|
||||||
describe('GeminiChat', () => {
|
describe('GeminiChat', () => {
|
||||||
let chat: GeminiChat;
|
let chat: GeminiChat;
|
||||||
const model = 'gemini-pro';
|
const model = 'gemini-pro';
|
||||||
|
@ -35,7 +25,7 @@ describe('GeminiChat', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
// Reset history for each test by creating a new instance
|
// Reset history for each test by creating a new instance
|
||||||
chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, []);
|
chat = new GeminiChat(mockModelsModule, model, config, []);
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
@ -129,19 +119,8 @@ describe('GeminiChat', () => {
|
||||||
// @ts-expect-error Accessing private method for testing purposes
|
// @ts-expect-error Accessing private method for testing purposes
|
||||||
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
|
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
|
||||||
|
|
||||||
// const history = chat.getHistory(); // Removed unused variable to satisfy linter
|
|
||||||
// The recordHistory will push the *new* userInput first, then the consolidated newModelOutput.
|
|
||||||
// However, the consolidation logic for *outputContents* itself should run, and then the merge with *existing* history.
|
|
||||||
// Let's adjust the test to reflect how recordHistory is used: it adds the current userInput, then the model's response to it.
|
|
||||||
|
|
||||||
// Reset and set up a more realistic scenario for merging with existing history
|
// Reset and set up a more realistic scenario for merging with existing history
|
||||||
chat = new GeminiChat(
|
chat = new GeminiChat(mockModelsModule, model, config, []);
|
||||||
mockGoogleGenAI,
|
|
||||||
mockModelsModule,
|
|
||||||
model,
|
|
||||||
config,
|
|
||||||
[],
|
|
||||||
);
|
|
||||||
const firstUserInput: Content = {
|
const firstUserInput: Content = {
|
||||||
role: 'user',
|
role: 'user',
|
||||||
parts: [{ text: 'First user input' }],
|
parts: [{ text: 'First user input' }],
|
||||||
|
@ -184,7 +163,7 @@ describe('GeminiChat', () => {
|
||||||
role: 'model',
|
role: 'model',
|
||||||
parts: [{ text: 'Initial model answer.' }],
|
parts: [{ text: 'Initial model answer.' }],
|
||||||
};
|
};
|
||||||
chat = new GeminiChat(mockGoogleGenAI, mockModelsModule, model, config, [
|
chat = new GeminiChat(mockModelsModule, model, config, [
|
||||||
initialUser,
|
initialUser,
|
||||||
initialModel,
|
initialModel,
|
||||||
]);
|
]);
|
||||||
|
|
|
@ -10,15 +10,14 @@
|
||||||
import {
|
import {
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
Content,
|
Content,
|
||||||
Models,
|
|
||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
SendMessageParameters,
|
SendMessageParameters,
|
||||||
GoogleGenAI,
|
|
||||||
createUserContent,
|
createUserContent,
|
||||||
Part,
|
Part,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { retryWithBackoff } from '../utils/retry.js';
|
import { retryWithBackoff } from '../utils/retry.js';
|
||||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||||
|
import { ContentGenerator } from './contentGenerator.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns true if the response is valid, false otherwise.
|
* Returns true if the response is valid, false otherwise.
|
||||||
|
@ -120,8 +119,7 @@ export class GeminiChat {
|
||||||
private sendPromise: Promise<void> = Promise.resolve();
|
private sendPromise: Promise<void> = Promise.resolve();
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
private readonly apiClient: GoogleGenAI,
|
private readonly contentGenerator: ContentGenerator,
|
||||||
private readonly modelsModule: Models,
|
|
||||||
private readonly model: string,
|
private readonly model: string,
|
||||||
private readonly config: GenerateContentConfig = {},
|
private readonly config: GenerateContentConfig = {},
|
||||||
private history: Content[] = [],
|
private history: Content[] = [],
|
||||||
|
@ -156,7 +154,7 @@ export class GeminiChat {
|
||||||
const userContent = createUserContent(params.message);
|
const userContent = createUserContent(params.message);
|
||||||
|
|
||||||
const apiCall = () =>
|
const apiCall = () =>
|
||||||
this.modelsModule.generateContent({
|
this.contentGenerator.generateContent({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
contents: this.getHistory(true).concat(userContent),
|
contents: this.getHistory(true).concat(userContent),
|
||||||
config: { ...this.config, ...params.config },
|
config: { ...this.config, ...params.config },
|
||||||
|
@ -225,7 +223,7 @@ export class GeminiChat {
|
||||||
const userContent = createUserContent(params.message);
|
const userContent = createUserContent(params.message);
|
||||||
|
|
||||||
const apiCall = () =>
|
const apiCall = () =>
|
||||||
this.modelsModule.generateContentStream({
|
this.contentGenerator.generateContentStream({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
contents: this.getHistory(true).concat(userContent),
|
contents: this.getHistory(true).concat(userContent),
|
||||||
config: { ...this.config, ...params.config },
|
config: { ...this.config, ...params.config },
|
||||||
|
|
|
@ -69,7 +69,6 @@ describe('checkNextSpeaker', () => {
|
||||||
|
|
||||||
// GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
|
// GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
|
||||||
chatInstance = new GeminiChat(
|
chatInstance = new GeminiChat(
|
||||||
mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
|
|
||||||
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
|
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
|
||||||
'gemini-pro', // model name
|
'gemini-pro', // model name
|
||||||
{},
|
{},
|
||||||
|
|
Loading…
Reference in New Issue