Add embedder (#818)

This commit is contained in:
Eddie Santos 2025-06-07 13:38:05 -07:00 committed by GitHub
parent 51cd5ffd91
commit 27fdd1b6e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 206 additions and 21 deletions

View File

@ -33,6 +33,7 @@ const logger = {
/** Default Gemini model identifier exported by this module. */
export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-06-05';
/** Default Gemini "flash" model identifier exported by this module. */
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash-preview-05-20';
/**
 * Default embedding model identifier. `loadCliConfig` passes this value as
 * `embeddingModel` when it builds the `ConfigParameters` object.
 */
export const DEFAULT_GEMINI_EMBEDDING_MODEL = 'gemini-embedding-001';
interface CliArgs {
model: string | undefined;
@ -177,6 +178,7 @@ export async function loadCliConfig(
const configParams: ConfigParameters = {
apiKey: apiKeyForServer,
model: modelToUse,
embeddingModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
sandbox: argv.sandbox ?? settings.sandbox ?? argv.yolo ?? false,
targetDir: process.cwd(),
debugMode,

View File

@ -38,6 +38,7 @@ interface MockServerConfig {
vertexai?: boolean;
showMemoryUsage?: boolean;
accessibility?: AccessibilitySettings;
embeddingModel: string;
getApiKey: Mock<() => string>;
getModel: Mock<() => string>;
@ -92,6 +93,7 @@ vi.mock('@gemini-code/core', async (importOriginal) => {
vertexai: opts.vertexai,
showMemoryUsage: opts.showMemoryUsage ?? false,
accessibility: opts.accessibility ?? {},
embeddingModel: opts.embeddingModel || 'test-embedding-model',
getApiKey: vi.fn(() => opts.apiKey || 'test-key'),
getModel: vi.fn(() => opts.model || 'test-model-in-mock-factory'),
@ -178,7 +180,8 @@ describe('App UI', () => {
const ServerConfigMocked = vi.mocked(ServerConfig, true);
mockConfig = new ServerConfigMocked({
apiKey: 'test-key',
model: 'test-model-in-options',
model: 'test-model',
embeddingModel: 'test-embedding-model',
sandbox: false,
targetDir: '/test/dir',
debugMode: false,

View File

@ -48,9 +48,11 @@ describe('Server Config (config.ts)', () => {
const USER_AGENT = 'ServerTestAgent/1.0';
const USER_MEMORY = 'Test User Memory';
const TELEMETRY = false;
const EMBEDDING_MODEL = 'gemini-embedding';
const baseParams: ConfigParameters = {
apiKey: API_KEY,
model: MODEL,
embeddingModel: EMBEDDING_MODEL,
sandbox: SANDBOX,
targetDir: TARGET_DIR,
debugMode: DEBUG_MODE,

View File

@ -55,6 +55,7 @@ export class MCPServerConfig {
export interface ConfigParameters {
apiKey: string;
model: string;
embeddingModel: string;
sandbox: boolean | string;
targetDir: string;
debugMode: boolean;
@ -84,6 +85,7 @@ export class Config {
private toolRegistry: Promise<ToolRegistry>;
private readonly apiKey: string;
private readonly model: string;
private readonly embeddingModel: string;
private readonly sandbox: boolean | string;
private readonly targetDir: string;
private readonly debugMode: boolean;
@ -113,6 +115,7 @@ export class Config {
constructor(params: ConfigParameters) {
this.apiKey = params.apiKey;
this.model = params.model;
this.embeddingModel = params.embeddingModel;
this.sandbox = params.sandbox;
this.targetDir = path.resolve(params.targetDir);
this.debugMode = params.debugMode;
@ -163,6 +166,10 @@ export class Config {
return this.model;
}
/**
 * Returns the embedding model identifier this Config was constructed with
 * (copied from `ConfigParameters.embeddingModel` in the constructor).
 */
getEmbeddingModel(): string {
return this.embeddingModel;
}
/**
 * Returns the sandbox setting this Config was constructed with
 * (copied from `ConfigParameters.sandbox` in the constructor).
 * The value is either a boolean or a string.
 */
getSandbox(): boolean | string {
return this.sandbox;
}

View File

@ -6,29 +6,23 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { Chat, GenerateContentResponse } from '@google/genai';
import {
Chat,
EmbedContentResponse,
GenerateContentResponse,
GoogleGenAI,
} from '@google/genai';
import { GeminiClient } from './client.js';
import { Config } from '../config/config.js';
// --- Mocks ---
const mockChatCreateFn = vi.fn();
const mockGenerateContentFn = vi.fn();
const mockEmbedContentFn = vi.fn();
vi.mock('@google/genai', async (importOriginal) => {
const actual = await importOriginal<typeof import('@google/genai')>();
const MockedGoogleGenerativeAI = vi
.fn()
.mockImplementation((/*...args*/) => ({
chats: { create: mockChatCreateFn },
models: { generateContent: mockGenerateContentFn },
}));
return {
...actual,
GoogleGenerativeAI: MockedGoogleGenerativeAI,
Chat: vi.fn(),
Type: actual.Type ?? { OBJECT: 'OBJECT', STRING: 'STRING' },
};
});
vi.mock('@google/genai');
vi.mock('../config/config');
vi.mock('../config/config.js');
vi.mock('./prompts');
vi.mock('../utils/getFolderStructure', () => ({
getFolderStructure: vi.fn().mockResolvedValue('Mock Folder Structure'),
@ -44,8 +38,24 @@ vi.mock('../utils/generateContentResponseUtilities', () => ({
}));
describe('Gemini Client (client.ts)', () => {
let client: GeminiClient;
beforeEach(() => {
vi.resetAllMocks();
// Set up the mock for GoogleGenAI constructor and its methods
const MockedGoogleGenAI = vi.mocked(GoogleGenAI);
MockedGoogleGenAI.mockImplementation(() => {
const mock = {
chats: { create: mockChatCreateFn },
models: {
generateContent: mockGenerateContentFn,
embedContent: mockEmbedContentFn,
},
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return mock as any;
});
mockChatCreateFn.mockResolvedValue({} as Chat);
mockGenerateContentFn.mockResolvedValue({
candidates: [
@ -56,6 +66,35 @@ describe('Gemini Client (client.ts)', () => {
},
],
} as unknown as GenerateContentResponse);
// Because the GeminiClient constructor kicks off an async process (startChat)
// that depends on a fully-formed Config object, we need to mock the
// entire implementation of Config for these tests.
const mockToolRegistry = {
getFunctionDeclarations: vi.fn().mockReturnValue([]),
getTool: vi.fn().mockReturnValue(null),
};
const MockedConfig = vi.mocked(Config, true);
MockedConfig.mockImplementation(() => {
const mock = {
getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
getModel: vi.fn().mockReturnValue('test-model'),
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
getApiKey: vi.fn().mockReturnValue('test-key'),
getVertexAI: vi.fn().mockReturnValue(false),
getUserAgent: vi.fn().mockReturnValue('test-agent'),
getUserMemory: vi.fn().mockReturnValue(''),
getFullContext: vi.fn().mockReturnValue(false),
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return mock as any;
});
// We can instantiate the client here since Config is mocked
// and the constructor will use the mocked GoogleGenAI
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const mockConfig = new Config({} as any);
client = new GeminiClient(mockConfig);
});
afterEach(() => {
@ -82,8 +121,97 @@ describe('Gemini Client (client.ts)', () => {
// it('generateJson should call getCoreSystemPrompt with userMemory and pass to generateContent', async () => { ... });
// it('generateJson should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... });
// Add a placeholder test to keep the suite valid
it('should have a placeholder test', () => {
expect(true).toBe(true);
// Unit tests for GeminiClient.generateEmbedding. The underlying embedContent
// call is replaced by mockEmbedContentFn, so each test controls the API
// response (or rejection) and asserts on the client's validation behaviour.
describe('generateEmbedding', () => {
// Two inputs, so count-mismatch and per-index error cases can be exercised.
const texts = ['hello world', 'goodbye world'];
// Must match the value the mocked Config returns from getEmbeddingModel().
const testEmbeddingModel = 'test-embedding-model';
// Happy path: one request carrying all texts, vectors returned unchanged.
it('should call embedContent with correct parameters and return embeddings', async () => {
const mockEmbeddings = [
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
];
const mockResponse: EmbedContentResponse = {
embeddings: [
{ values: mockEmbeddings[0] },
{ values: mockEmbeddings[1] },
],
};
mockEmbedContentFn.mockResolvedValue(mockResponse);
const result = await client.generateEmbedding(texts);
expect(mockEmbedContentFn).toHaveBeenCalledTimes(1);
expect(mockEmbedContentFn).toHaveBeenCalledWith({
model: testEmbeddingModel,
contents: texts,
});
expect(result).toEqual(mockEmbeddings);
});
// Empty input short-circuits: no API call is made at all.
it('should return an empty array if an empty array is passed', async () => {
const result = await client.generateEmbedding([]);
expect(result).toEqual([]);
expect(mockEmbedContentFn).not.toHaveBeenCalled();
});
// A response without an `embeddings` key is rejected.
it('should throw an error if API response has no embeddings array', async () => {
mockEmbedContentFn.mockResolvedValue({} as EmbedContentResponse); // No `embeddings` key
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'No embeddings found in API response.',
);
});
// An empty `embeddings` array is reported with the same message as a missing one.
it('should throw an error if API response has an empty embeddings array', async () => {
const mockResponse: EmbedContentResponse = {
embeddings: [],
};
mockEmbedContentFn.mockResolvedValue(mockResponse);
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'No embeddings found in API response.',
);
});
// The number of returned embeddings must equal the number of input texts.
it('should throw an error if API returns a mismatched number of embeddings', async () => {
const mockResponse: EmbedContentResponse = {
embeddings: [{ values: [1, 2, 3] }], // Only one for two texts
};
mockEmbedContentFn.mockResolvedValue(mockResponse);
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned a mismatched number of embeddings. Expected 2, got 1.',
);
});
// Per-item validation: an undefined `values` names the offending index and text.
it('should throw an error if any embedding has nullish values', async () => {
const mockResponse: EmbedContentResponse = {
embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad
};
mockEmbedContentFn.mockResolvedValue(mockResponse);
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned an empty embedding for input text at index 1: "goodbye world"',
);
});
// Per-item validation: a zero-length `values` array is treated like a missing one.
it('should throw an error if any embedding has an empty values array', async () => {
const mockResponse: EmbedContentResponse = {
embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad
};
mockEmbedContentFn.mockResolvedValue(mockResponse);
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API returned an empty embedding for input text at index 0: "hello world"',
);
});
// Rejections from embedContent surface to the caller with their original message.
it('should propagate errors from the API call', async () => {
const apiError = new Error('API Failure');
mockEmbedContentFn.mockRejectedValue(apiError);
await expect(client.generateEmbedding(texts)).rejects.toThrow(
'API Failure',
);
});
});
});

View File

@ -5,6 +5,8 @@
*/
import {
EmbedContentResponse,
EmbedContentParameters,
GenerateContentConfig,
GoogleGenAI,
Part,
@ -38,6 +40,7 @@ export class GeminiClient {
private chat: Promise<GeminiChat>;
private contentGenerator: ContentGenerator;
private model: string;
private embeddingModel: string;
private generateContentConfig: GenerateContentConfig = {
temperature: 0,
topP: 1,
@ -60,6 +63,7 @@ export class GeminiClient {
});
this.contentGenerator = googleGenAI.models;
this.model = config.getModel();
this.embeddingModel = config.getEmbeddingModel();
this.chat = this.startChat();
}
@ -450,6 +454,40 @@ export class GeminiClient {
}
}
/**
 * Requests embeddings for a batch of texts from the configured embedding
 * model, using a single `embedContent` call for the whole batch.
 *
 * @param texts - Input strings to embed. An empty (or nullish) array yields
 *   an empty result without contacting the API.
 * @returns One numeric vector per entry of the API response, in response
 *   order. The response is validated to contain exactly `texts.length`
 *   entries.
 * @throws Error if the response has no embeddings, if the number of
 *   embeddings differs from the number of inputs, or if any individual
 *   embedding has missing or empty `values`. Errors raised by
 *   `embedContent` itself are not caught here.
 */
async generateEmbedding(texts: string[]): Promise<number[][]> {
  // Nothing to embed: skip the API round trip entirely.
  if (!texts?.length) {
    return [];
  }

  const request: EmbedContentParameters = {
    model: this.embeddingModel,
    contents: texts,
  };
  const response: EmbedContentResponse =
    await this.contentGenerator.embedContent(request);

  const received = response.embeddings;
  if (!received?.length) {
    throw new Error('No embeddings found in API response.');
  }

  const expectedCount = texts.length;
  if (received.length !== expectedCount) {
    throw new Error(
      `API returned a mismatched number of embeddings. Expected ${expectedCount}, got ${received.length}.`,
    );
  }

  // Validate each vector; the error reports the input text at the same position.
  return received.map(({ values }, position) => {
    if (!values?.length) {
      throw new Error(
        `API returned an empty embedding for input text at index ${position}: "${texts[position]}"`,
      );
    }
    return values;
  });
}
private async tryCompressChat(): Promise<boolean> {
const chat = await this.chat;
const history = chat.getHistory(true); // Get curated history

View File

@ -9,6 +9,8 @@ import {
GenerateContentResponse,
GenerateContentParameters,
CountTokensParameters,
EmbedContentResponse,
EmbedContentParameters,
} from '@google/genai';
/**
@ -24,4 +26,6 @@ export interface ContentGenerator {
): Promise<AsyncGenerator<GenerateContentResponse>>;
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
}

View File

@ -126,6 +126,7 @@ class MockTool extends BaseTool<{ param: string }, ToolResult> {
const baseConfigParams: ConfigParameters = {
apiKey: 'test-api-key',
model: 'test-model',
embeddingModel: 'test-embedding-model',
sandbox: false,
targetDir: '/test/dir',
debugMode: false,