From 27fdd1b6e6e50882ee9a17c85c5f6b845d4973ce Mon Sep 17 00:00:00 2001 From: Eddie Santos <9561596+eddie-santos@users.noreply.github.com> Date: Sat, 7 Jun 2025 13:38:05 -0700 Subject: [PATCH] Add embedder (#818) --- packages/cli/src/config/config.ts | 2 + packages/cli/src/ui/App.test.tsx | 5 +- packages/core/src/config/config.test.ts | 2 + packages/core/src/config/config.ts | 7 + packages/core/src/core/client.test.ts | 168 +++++++++++++++--- packages/core/src/core/client.ts | 38 ++++ packages/core/src/core/contentGenerator.ts | 4 + packages/core/src/tools/tool-registry.test.ts | 1 + 8 files changed, 206 insertions(+), 21 deletions(-) diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 001d17d5..6ab1453f 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -33,6 +33,7 @@ const logger = { export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-06-05'; export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash-preview-05-20'; +export const DEFAULT_GEMINI_EMBEDDING_MODEL = 'gemini-embedding-001'; interface CliArgs { model: string | undefined; @@ -177,6 +178,7 @@ export async function loadCliConfig( const configParams: ConfigParameters = { apiKey: apiKeyForServer, model: modelToUse, + embeddingModel: DEFAULT_GEMINI_EMBEDDING_MODEL, sandbox: argv.sandbox ?? settings.sandbox ?? argv.yolo ?? false, targetDir: process.cwd(), debugMode, diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx index 98d82be8..f4ada985 100644 --- a/packages/cli/src/ui/App.test.tsx +++ b/packages/cli/src/ui/App.test.tsx @@ -38,6 +38,7 @@ interface MockServerConfig { vertexai?: boolean; showMemoryUsage?: boolean; accessibility?: AccessibilitySettings; + embeddingModel: string; getApiKey: Mock<() => string>; getModel: Mock<() => string>; @@ -92,6 +93,7 @@ vi.mock('@gemini-code/core', async (importOriginal) => { vertexai: opts.vertexai, showMemoryUsage: opts.showMemoryUsage ?? false, accessibility: opts.accessibility ?? {}, + embeddingModel: opts.embeddingModel || 'test-embedding-model', getApiKey: vi.fn(() => opts.apiKey || 'test-key'), getModel: vi.fn(() => opts.model || 'test-model-in-mock-factory'), @@ -178,7 +180,8 @@ describe('App UI', () => { const ServerConfigMocked = vi.mocked(ServerConfig, true); mockConfig = new ServerConfigMocked({ apiKey: 'test-key', - model: 'test-model-in-options', + model: 'test-model', + embeddingModel: 'test-embedding-model', sandbox: false, targetDir: '/test/dir', debugMode: false, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 411b124d..3800585d 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -48,9 +48,11 @@ describe('Server Config (config.ts)', () => { const USER_AGENT = 'ServerTestAgent/1.0'; const USER_MEMORY = 'Test User Memory'; const TELEMETRY = false; + const EMBEDDING_MODEL = 'gemini-embedding'; const baseParams: ConfigParameters = { apiKey: API_KEY, model: MODEL, + embeddingModel: EMBEDDING_MODEL, sandbox: SANDBOX, targetDir: TARGET_DIR, debugMode: DEBUG_MODE, diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 00b3e35d..75db970b 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -55,6 +55,7 @@ export class MCPServerConfig { export interface ConfigParameters { apiKey: string; model: string; + embeddingModel: string; sandbox: boolean | string; targetDir: string; debugMode: boolean; @@ -84,6 +85,7 @@ export class Config { private toolRegistry: Promise; private readonly apiKey: string; private readonly model: string; + private readonly embeddingModel: string; private readonly sandbox: boolean | string; private readonly targetDir: string; private readonly debugMode: boolean; @@ -113,6 +115,7 @@ export class Config { constructor(params: ConfigParameters) { this.apiKey = params.apiKey; this.model = params.model; + this.embeddingModel = params.embeddingModel; this.sandbox = params.sandbox; this.targetDir = path.resolve(params.targetDir); this.debugMode = params.debugMode; @@ -163,6 +166,10 @@ export class Config { return this.model; } + getEmbeddingModel(): string { + return this.embeddingModel; + } + getSandbox(): boolean | string { return this.sandbox; } diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 228701d8..9c12423c 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -6,29 +6,23 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; -import { Chat, GenerateContentResponse } from '@google/genai'; +import { + Chat, + EmbedContentResponse, + GenerateContentResponse, + GoogleGenAI, +} from '@google/genai'; +import { GeminiClient } from './client.js'; +import { Config } from '../config/config.js'; // --- Mocks --- const mockChatCreateFn = vi.fn(); const mockGenerateContentFn = vi.fn(); +const mockEmbedContentFn = vi.fn(); -vi.mock('@google/genai', async (importOriginal) => { - const actual = await importOriginal(); - const MockedGoogleGenerativeAI = vi - .fn() - .mockImplementation((/*...args*/) => ({ - chats: { create: mockChatCreateFn }, - models: { generateContent: mockGenerateContentFn }, - })); - return { - ...actual, - GoogleGenerativeAI: MockedGoogleGenerativeAI, - Chat: vi.fn(), - Type: actual.Type ?? { OBJECT: 'OBJECT', STRING: 'STRING' }, - }; -}); +vi.mock('@google/genai'); -vi.mock('../config/config'); +vi.mock('../config/config.js'); vi.mock('./prompts'); vi.mock('../utils/getFolderStructure', () => ({ getFolderStructure: vi.fn().mockResolvedValue('Mock Folder Structure'), @@ -44,8 +38,24 @@ vi.mock('../utils/generateContentResponseUtilities', () => ({ })); describe('Gemini Client (client.ts)', () => { + let client: GeminiClient; beforeEach(() => { vi.resetAllMocks(); + + // Set up the mock for GoogleGenAI constructor and its methods + const MockedGoogleGenAI = vi.mocked(GoogleGenAI); + MockedGoogleGenAI.mockImplementation(() => { + const mock = { + chats: { create: mockChatCreateFn }, + models: { + generateContent: mockGenerateContentFn, + embedContent: mockEmbedContentFn, + }, + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return mock as any; + }); + mockChatCreateFn.mockResolvedValue({} as Chat); mockGenerateContentFn.mockResolvedValue({ candidates: [ @@ -56,6 +66,35 @@ describe('Gemini Client (client.ts)', () => { }, ], } as unknown as GenerateContentResponse); + + // Because the GeminiClient constructor kicks off an async process (startChat) + // that depends on a fully-formed Config object, we need to mock the + // entire implementation of Config for these tests. + const mockToolRegistry = { + getFunctionDeclarations: vi.fn().mockReturnValue([]), + getTool: vi.fn().mockReturnValue(null), + }; + const MockedConfig = vi.mocked(Config, true); + MockedConfig.mockImplementation(() => { + const mock = { + getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry), + getModel: vi.fn().mockReturnValue('test-model'), + getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'), + getApiKey: vi.fn().mockReturnValue('test-key'), + getVertexAI: vi.fn().mockReturnValue(false), + getUserAgent: vi.fn().mockReturnValue('test-agent'), + getUserMemory: vi.fn().mockReturnValue(''), + getFullContext: vi.fn().mockReturnValue(false), + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return mock as any; + }); + + // We can instantiate the client here since Config is mocked + // and the constructor will use the mocked GoogleGenAI + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const mockConfig = new Config({} as any); + client = new GeminiClient(mockConfig); }); afterEach(() => { @@ -82,8 +121,97 @@ describe('Gemini Client (client.ts)', () => { // it('generateJson should call getCoreSystemPrompt with userMemory and pass to generateContent', async () => { ... }); // it('generateJson should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... }); - // Add a placeholder test to keep the suite valid - it('should have a placeholder test', () => { - expect(true).toBe(true); + describe('generateEmbedding', () => { + const texts = ['hello world', 'goodbye world']; + const testEmbeddingModel = 'test-embedding-model'; + + it('should call embedContent with correct parameters and return embeddings', async () => { + const mockEmbeddings = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + ]; + const mockResponse: EmbedContentResponse = { + embeddings: [ + { values: mockEmbeddings[0] }, + { values: mockEmbeddings[1] }, + ], + }; + mockEmbedContentFn.mockResolvedValue(mockResponse); + + const result = await client.generateEmbedding(texts); + + expect(mockEmbedContentFn).toHaveBeenCalledTimes(1); + expect(mockEmbedContentFn).toHaveBeenCalledWith({ + model: testEmbeddingModel, + contents: texts, + }); + expect(result).toEqual(mockEmbeddings); + }); + + it('should return an empty array if an empty array is passed', async () => { + const result = await client.generateEmbedding([]); + expect(result).toEqual([]); + expect(mockEmbedContentFn).not.toHaveBeenCalled(); + }); + + it('should throw an error if API response has no embeddings array', async () => { + mockEmbedContentFn.mockResolvedValue({} as EmbedContentResponse); // No `embeddings` key + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'No embeddings found in API response.', + ); + }); + + it('should throw an error if API response has an empty embeddings array', async () => { + const mockResponse: EmbedContentResponse = { + embeddings: [], + }; + mockEmbedContentFn.mockResolvedValue(mockResponse); + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'No embeddings found in API response.', + ); + }); + + it('should throw an error if API returns a mismatched number of embeddings', async () => { + const mockResponse: EmbedContentResponse = { + embeddings: [{ values: [1, 2, 3] }], // Only one for two texts + }; + mockEmbedContentFn.mockResolvedValue(mockResponse); + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'API returned a mismatched number of embeddings. Expected 2, got 1.', + ); + }); + + it('should throw an error if any embedding has nullish values', async () => { + const mockResponse: EmbedContentResponse = { + embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad + }; + mockEmbedContentFn.mockResolvedValue(mockResponse); + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'API returned an empty embedding for input text at index 1: "goodbye world"', + ); + }); + + it('should throw an error if any embedding has an empty values array', async () => { + const mockResponse: EmbedContentResponse = { + embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad + }; + mockEmbedContentFn.mockResolvedValue(mockResponse); + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'API returned an empty embedding for input text at index 0: "hello world"', + ); + }); + + it('should propagate errors from the API call', async () => { + const apiError = new Error('API Failure'); + mockEmbedContentFn.mockRejectedValue(apiError); + + await expect(client.generateEmbedding(texts)).rejects.toThrow( + 'API Failure', + ); + }); }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index d1a59eb1..c4515f93 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -5,6 +5,8 @@ */ import { + EmbedContentResponse, + EmbedContentParameters, GenerateContentConfig, GoogleGenAI, Part, @@ -38,6 +40,7 @@ export class GeminiClient { private chat: Promise; private contentGenerator: ContentGenerator; private model: string; + private embeddingModel: string; private generateContentConfig: GenerateContentConfig = { temperature: 0, topP: 1, @@ -60,6 +63,7 @@ export class GeminiClient { }); this.contentGenerator = googleGenAI.models; this.model = config.getModel(); + this.embeddingModel = config.getEmbeddingModel(); this.chat = this.startChat(); } @@ -450,6 +454,40 @@ export class GeminiClient { } } + async generateEmbedding(texts: string[]): Promise { + if (!texts || texts.length === 0) { + return []; + } + const embedModelParams: EmbedContentParameters = { + model: this.embeddingModel, + contents: texts, + }; + const embedContentResponse: EmbedContentResponse = + await this.contentGenerator.embedContent(embedModelParams); + if ( + !embedContentResponse.embeddings || + embedContentResponse.embeddings.length === 0 + ) { + throw new Error('No embeddings found in API response.'); + } + + if (embedContentResponse.embeddings.length !== texts.length) { + throw new Error( + `API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`, + ); + } + + return embedContentResponse.embeddings.map((embedding, index) => { + const values = embedding.values; + if (!values || values.length === 0) { + throw new Error( + `API returned an empty embedding for input text at index ${index}: "${texts[index]}"`, + ); + } + return values; + }); + } + private async tryCompressChat(): Promise { const chat = await this.chat; const history = chat.getHistory(true); // Get curated history diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 32b48c5c..955cd152 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -9,6 +9,8 @@ import { GenerateContentResponse, GenerateContentParameters, CountTokensParameters, + EmbedContentResponse, + EmbedContentParameters, } from '@google/genai'; /** @@ -24,4 +26,6 @@ export interface ContentGenerator { ): Promise>; countTokens(request: CountTokensParameters): Promise; + + embedContent(request: EmbedContentParameters): Promise; } diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index f57f5bce..0c23a74e 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -126,6 +126,7 @@ class MockTool extends BaseTool<{ param: string }, ToolResult> { const baseConfigParams: ConfigParameters = { apiKey: 'test-api-key', model: 'test-model', + embeddingModel: 'test-embedding-model', sandbox: false, targetDir: '/test/dir', debugMode: false,