Add embedder (#818)
This commit is contained in:
parent
51cd5ffd91
commit
27fdd1b6e6
|
@ -33,6 +33,7 @@ const logger = {
|
|||
|
||||
export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-06-05';
|
||||
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash-preview-05-20';
|
||||
export const DEFAULT_GEMINI_EMBEDDING_MODEL = 'gemini-embedding-001';
|
||||
|
||||
interface CliArgs {
|
||||
model: string | undefined;
|
||||
|
@ -177,6 +178,7 @@ export async function loadCliConfig(
|
|||
const configParams: ConfigParameters = {
|
||||
apiKey: apiKeyForServer,
|
||||
model: modelToUse,
|
||||
embeddingModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
sandbox: argv.sandbox ?? settings.sandbox ?? argv.yolo ?? false,
|
||||
targetDir: process.cwd(),
|
||||
debugMode,
|
||||
|
|
|
@ -38,6 +38,7 @@ interface MockServerConfig {
|
|||
vertexai?: boolean;
|
||||
showMemoryUsage?: boolean;
|
||||
accessibility?: AccessibilitySettings;
|
||||
embeddingModel: string;
|
||||
|
||||
getApiKey: Mock<() => string>;
|
||||
getModel: Mock<() => string>;
|
||||
|
@ -92,6 +93,7 @@ vi.mock('@gemini-code/core', async (importOriginal) => {
|
|||
vertexai: opts.vertexai,
|
||||
showMemoryUsage: opts.showMemoryUsage ?? false,
|
||||
accessibility: opts.accessibility ?? {},
|
||||
embeddingModel: opts.embeddingModel || 'test-embedding-model',
|
||||
|
||||
getApiKey: vi.fn(() => opts.apiKey || 'test-key'),
|
||||
getModel: vi.fn(() => opts.model || 'test-model-in-mock-factory'),
|
||||
|
@ -178,7 +180,8 @@ describe('App UI', () => {
|
|||
const ServerConfigMocked = vi.mocked(ServerConfig, true);
|
||||
mockConfig = new ServerConfigMocked({
|
||||
apiKey: 'test-key',
|
||||
model: 'test-model-in-options',
|
||||
model: 'test-model',
|
||||
embeddingModel: 'test-embedding-model',
|
||||
sandbox: false,
|
||||
targetDir: '/test/dir',
|
||||
debugMode: false,
|
||||
|
|
|
@ -48,9 +48,11 @@ describe('Server Config (config.ts)', () => {
|
|||
const USER_AGENT = 'ServerTestAgent/1.0';
|
||||
const USER_MEMORY = 'Test User Memory';
|
||||
const TELEMETRY = false;
|
||||
const EMBEDDING_MODEL = 'gemini-embedding';
|
||||
const baseParams: ConfigParameters = {
|
||||
apiKey: API_KEY,
|
||||
model: MODEL,
|
||||
embeddingModel: EMBEDDING_MODEL,
|
||||
sandbox: SANDBOX,
|
||||
targetDir: TARGET_DIR,
|
||||
debugMode: DEBUG_MODE,
|
||||
|
|
|
@ -55,6 +55,7 @@ export class MCPServerConfig {
|
|||
export interface ConfigParameters {
|
||||
apiKey: string;
|
||||
model: string;
|
||||
embeddingModel: string;
|
||||
sandbox: boolean | string;
|
||||
targetDir: string;
|
||||
debugMode: boolean;
|
||||
|
@ -84,6 +85,7 @@ export class Config {
|
|||
private toolRegistry: Promise<ToolRegistry>;
|
||||
private readonly apiKey: string;
|
||||
private readonly model: string;
|
||||
private readonly embeddingModel: string;
|
||||
private readonly sandbox: boolean | string;
|
||||
private readonly targetDir: string;
|
||||
private readonly debugMode: boolean;
|
||||
|
@ -113,6 +115,7 @@ export class Config {
|
|||
constructor(params: ConfigParameters) {
|
||||
this.apiKey = params.apiKey;
|
||||
this.model = params.model;
|
||||
this.embeddingModel = params.embeddingModel;
|
||||
this.sandbox = params.sandbox;
|
||||
this.targetDir = path.resolve(params.targetDir);
|
||||
this.debugMode = params.debugMode;
|
||||
|
@ -163,6 +166,10 @@ export class Config {
|
|||
return this.model;
|
||||
}
|
||||
|
||||
getEmbeddingModel(): string {
|
||||
return this.embeddingModel;
|
||||
}
|
||||
|
||||
getSandbox(): boolean | string {
|
||||
return this.sandbox;
|
||||
}
|
||||
|
|
|
@ -6,29 +6,23 @@
|
|||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
|
||||
import { Chat, GenerateContentResponse } from '@google/genai';
|
||||
import {
|
||||
Chat,
|
||||
EmbedContentResponse,
|
||||
GenerateContentResponse,
|
||||
GoogleGenAI,
|
||||
} from '@google/genai';
|
||||
import { GeminiClient } from './client.js';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
// --- Mocks ---
|
||||
const mockChatCreateFn = vi.fn();
|
||||
const mockGenerateContentFn = vi.fn();
|
||||
const mockEmbedContentFn = vi.fn();
|
||||
|
||||
vi.mock('@google/genai', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@google/genai')>();
|
||||
const MockedGoogleGenerativeAI = vi
|
||||
.fn()
|
||||
.mockImplementation((/*...args*/) => ({
|
||||
chats: { create: mockChatCreateFn },
|
||||
models: { generateContent: mockGenerateContentFn },
|
||||
}));
|
||||
return {
|
||||
...actual,
|
||||
GoogleGenerativeAI: MockedGoogleGenerativeAI,
|
||||
Chat: vi.fn(),
|
||||
Type: actual.Type ?? { OBJECT: 'OBJECT', STRING: 'STRING' },
|
||||
};
|
||||
});
|
||||
vi.mock('@google/genai');
|
||||
|
||||
vi.mock('../config/config');
|
||||
vi.mock('../config/config.js');
|
||||
vi.mock('./prompts');
|
||||
vi.mock('../utils/getFolderStructure', () => ({
|
||||
getFolderStructure: vi.fn().mockResolvedValue('Mock Folder Structure'),
|
||||
|
@ -44,8 +38,24 @@ vi.mock('../utils/generateContentResponseUtilities', () => ({
|
|||
}));
|
||||
|
||||
describe('Gemini Client (client.ts)', () => {
|
||||
let client: GeminiClient;
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
|
||||
// Set up the mock for GoogleGenAI constructor and its methods
|
||||
const MockedGoogleGenAI = vi.mocked(GoogleGenAI);
|
||||
MockedGoogleGenAI.mockImplementation(() => {
|
||||
const mock = {
|
||||
chats: { create: mockChatCreateFn },
|
||||
models: {
|
||||
generateContent: mockGenerateContentFn,
|
||||
embedContent: mockEmbedContentFn,
|
||||
},
|
||||
};
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return mock as any;
|
||||
});
|
||||
|
||||
mockChatCreateFn.mockResolvedValue({} as Chat);
|
||||
mockGenerateContentFn.mockResolvedValue({
|
||||
candidates: [
|
||||
|
@ -56,6 +66,35 @@ describe('Gemini Client (client.ts)', () => {
|
|||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
// Because the GeminiClient constructor kicks off an async process (startChat)
|
||||
// that depends on a fully-formed Config object, we need to mock the
|
||||
// entire implementation of Config for these tests.
|
||||
const mockToolRegistry = {
|
||||
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
||||
getTool: vi.fn().mockReturnValue(null),
|
||||
};
|
||||
const MockedConfig = vi.mocked(Config, true);
|
||||
MockedConfig.mockImplementation(() => {
|
||||
const mock = {
|
||||
getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
|
||||
getModel: vi.fn().mockReturnValue('test-model'),
|
||||
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
|
||||
getApiKey: vi.fn().mockReturnValue('test-key'),
|
||||
getVertexAI: vi.fn().mockReturnValue(false),
|
||||
getUserAgent: vi.fn().mockReturnValue('test-agent'),
|
||||
getUserMemory: vi.fn().mockReturnValue(''),
|
||||
getFullContext: vi.fn().mockReturnValue(false),
|
||||
};
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return mock as any;
|
||||
});
|
||||
|
||||
// We can instantiate the client here since Config is mocked
|
||||
// and the constructor will use the mocked GoogleGenAI
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const mockConfig = new Config({} as any);
|
||||
client = new GeminiClient(mockConfig);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
@ -82,8 +121,97 @@ describe('Gemini Client (client.ts)', () => {
|
|||
// it('generateJson should call getCoreSystemPrompt with userMemory and pass to generateContent', async () => { ... });
|
||||
// it('generateJson should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... });
|
||||
|
||||
// Add a placeholder test to keep the suite valid
|
||||
it('should have a placeholder test', () => {
|
||||
expect(true).toBe(true);
|
||||
describe('generateEmbedding', () => {
|
||||
const texts = ['hello world', 'goodbye world'];
|
||||
const testEmbeddingModel = 'test-embedding-model';
|
||||
|
||||
it('should call embedContent with correct parameters and return embeddings', async () => {
|
||||
const mockEmbeddings = [
|
||||
[0.1, 0.2, 0.3],
|
||||
[0.4, 0.5, 0.6],
|
||||
];
|
||||
const mockResponse: EmbedContentResponse = {
|
||||
embeddings: [
|
||||
{ values: mockEmbeddings[0] },
|
||||
{ values: mockEmbeddings[1] },
|
||||
],
|
||||
};
|
||||
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||
|
||||
const result = await client.generateEmbedding(texts);
|
||||
|
||||
expect(mockEmbedContentFn).toHaveBeenCalledTimes(1);
|
||||
expect(mockEmbedContentFn).toHaveBeenCalledWith({
|
||||
model: testEmbeddingModel,
|
||||
contents: texts,
|
||||
});
|
||||
expect(result).toEqual(mockEmbeddings);
|
||||
});
|
||||
|
||||
it('should return an empty array if an empty array is passed', async () => {
|
||||
const result = await client.generateEmbedding([]);
|
||||
expect(result).toEqual([]);
|
||||
expect(mockEmbedContentFn).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should throw an error if API response has no embeddings array', async () => {
|
||||
mockEmbedContentFn.mockResolvedValue({} as EmbedContentResponse); // No `embeddings` key
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'No embeddings found in API response.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if API response has an empty embeddings array', async () => {
|
||||
const mockResponse: EmbedContentResponse = {
|
||||
embeddings: [],
|
||||
};
|
||||
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'No embeddings found in API response.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if API returns a mismatched number of embeddings', async () => {
|
||||
const mockResponse: EmbedContentResponse = {
|
||||
embeddings: [{ values: [1, 2, 3] }], // Only one for two texts
|
||||
};
|
||||
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'API returned a mismatched number of embeddings. Expected 2, got 1.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if any embedding has nullish values', async () => {
|
||||
const mockResponse: EmbedContentResponse = {
|
||||
embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad
|
||||
};
|
||||
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'API returned an empty embedding for input text at index 1: "goodbye world"',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw an error if any embedding has an empty values array', async () => {
|
||||
const mockResponse: EmbedContentResponse = {
|
||||
embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad
|
||||
};
|
||||
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'API returned an empty embedding for input text at index 0: "hello world"',
|
||||
);
|
||||
});
|
||||
|
||||
it('should propagate errors from the API call', async () => {
|
||||
const apiError = new Error('API Failure');
|
||||
mockEmbedContentFn.mockRejectedValue(apiError);
|
||||
|
||||
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||
'API Failure',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
*/
|
||||
|
||||
import {
|
||||
EmbedContentResponse,
|
||||
EmbedContentParameters,
|
||||
GenerateContentConfig,
|
||||
GoogleGenAI,
|
||||
Part,
|
||||
|
@ -38,6 +40,7 @@ export class GeminiClient {
|
|||
private chat: Promise<GeminiChat>;
|
||||
private contentGenerator: ContentGenerator;
|
||||
private model: string;
|
||||
private embeddingModel: string;
|
||||
private generateContentConfig: GenerateContentConfig = {
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
|
@ -60,6 +63,7 @@ export class GeminiClient {
|
|||
});
|
||||
this.contentGenerator = googleGenAI.models;
|
||||
this.model = config.getModel();
|
||||
this.embeddingModel = config.getEmbeddingModel();
|
||||
this.chat = this.startChat();
|
||||
}
|
||||
|
||||
|
@ -450,6 +454,40 @@ export class GeminiClient {
|
|||
}
|
||||
}
|
||||
|
||||
async generateEmbedding(texts: string[]): Promise<number[][]> {
|
||||
if (!texts || texts.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const embedModelParams: EmbedContentParameters = {
|
||||
model: this.embeddingModel,
|
||||
contents: texts,
|
||||
};
|
||||
const embedContentResponse: EmbedContentResponse =
|
||||
await this.contentGenerator.embedContent(embedModelParams);
|
||||
if (
|
||||
!embedContentResponse.embeddings ||
|
||||
embedContentResponse.embeddings.length === 0
|
||||
) {
|
||||
throw new Error('No embeddings found in API response.');
|
||||
}
|
||||
|
||||
if (embedContentResponse.embeddings.length !== texts.length) {
|
||||
throw new Error(
|
||||
`API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`,
|
||||
);
|
||||
}
|
||||
|
||||
return embedContentResponse.embeddings.map((embedding, index) => {
|
||||
const values = embedding.values;
|
||||
if (!values || values.length === 0) {
|
||||
throw new Error(
|
||||
`API returned an empty embedding for input text at index ${index}: "${texts[index]}"`,
|
||||
);
|
||||
}
|
||||
return values;
|
||||
});
|
||||
}
|
||||
|
||||
private async tryCompressChat(): Promise<boolean> {
|
||||
const chat = await this.chat;
|
||||
const history = chat.getHistory(true); // Get curated history
|
||||
|
|
|
@ -9,6 +9,8 @@ import {
|
|||
GenerateContentResponse,
|
||||
GenerateContentParameters,
|
||||
CountTokensParameters,
|
||||
EmbedContentResponse,
|
||||
EmbedContentParameters,
|
||||
} from '@google/genai';
|
||||
|
||||
/**
|
||||
|
@ -24,4 +26,6 @@ export interface ContentGenerator {
|
|||
): Promise<AsyncGenerator<GenerateContentResponse>>;
|
||||
|
||||
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
|
||||
|
||||
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
|
||||
}
|
||||
|
|
|
@ -126,6 +126,7 @@ class MockTool extends BaseTool<{ param: string }, ToolResult> {
|
|||
const baseConfigParams: ConfigParameters = {
|
||||
apiKey: 'test-api-key',
|
||||
model: 'test-model',
|
||||
embeddingModel: 'test-embedding-model',
|
||||
sandbox: false,
|
||||
targetDir: '/test/dir',
|
||||
debugMode: false,
|
||||
|
|
Loading…
Reference in New Issue