Add embedder (#818)
This commit is contained in:
parent
51cd5ffd91
commit
27fdd1b6e6
|
@ -33,6 +33,7 @@ const logger = {
|
||||||
|
|
||||||
export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-06-05';
|
export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-06-05';
|
||||||
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash-preview-05-20';
|
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash-preview-05-20';
|
||||||
|
export const DEFAULT_GEMINI_EMBEDDING_MODEL = 'gemini-embedding-001';
|
||||||
|
|
||||||
interface CliArgs {
|
interface CliArgs {
|
||||||
model: string | undefined;
|
model: string | undefined;
|
||||||
|
@ -177,6 +178,7 @@ export async function loadCliConfig(
|
||||||
const configParams: ConfigParameters = {
|
const configParams: ConfigParameters = {
|
||||||
apiKey: apiKeyForServer,
|
apiKey: apiKeyForServer,
|
||||||
model: modelToUse,
|
model: modelToUse,
|
||||||
|
embeddingModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||||
sandbox: argv.sandbox ?? settings.sandbox ?? argv.yolo ?? false,
|
sandbox: argv.sandbox ?? settings.sandbox ?? argv.yolo ?? false,
|
||||||
targetDir: process.cwd(),
|
targetDir: process.cwd(),
|
||||||
debugMode,
|
debugMode,
|
||||||
|
|
|
@ -38,6 +38,7 @@ interface MockServerConfig {
|
||||||
vertexai?: boolean;
|
vertexai?: boolean;
|
||||||
showMemoryUsage?: boolean;
|
showMemoryUsage?: boolean;
|
||||||
accessibility?: AccessibilitySettings;
|
accessibility?: AccessibilitySettings;
|
||||||
|
embeddingModel: string;
|
||||||
|
|
||||||
getApiKey: Mock<() => string>;
|
getApiKey: Mock<() => string>;
|
||||||
getModel: Mock<() => string>;
|
getModel: Mock<() => string>;
|
||||||
|
@ -92,6 +93,7 @@ vi.mock('@gemini-code/core', async (importOriginal) => {
|
||||||
vertexai: opts.vertexai,
|
vertexai: opts.vertexai,
|
||||||
showMemoryUsage: opts.showMemoryUsage ?? false,
|
showMemoryUsage: opts.showMemoryUsage ?? false,
|
||||||
accessibility: opts.accessibility ?? {},
|
accessibility: opts.accessibility ?? {},
|
||||||
|
embeddingModel: opts.embeddingModel || 'test-embedding-model',
|
||||||
|
|
||||||
getApiKey: vi.fn(() => opts.apiKey || 'test-key'),
|
getApiKey: vi.fn(() => opts.apiKey || 'test-key'),
|
||||||
getModel: vi.fn(() => opts.model || 'test-model-in-mock-factory'),
|
getModel: vi.fn(() => opts.model || 'test-model-in-mock-factory'),
|
||||||
|
@ -178,7 +180,8 @@ describe('App UI', () => {
|
||||||
const ServerConfigMocked = vi.mocked(ServerConfig, true);
|
const ServerConfigMocked = vi.mocked(ServerConfig, true);
|
||||||
mockConfig = new ServerConfigMocked({
|
mockConfig = new ServerConfigMocked({
|
||||||
apiKey: 'test-key',
|
apiKey: 'test-key',
|
||||||
model: 'test-model-in-options',
|
model: 'test-model',
|
||||||
|
embeddingModel: 'test-embedding-model',
|
||||||
sandbox: false,
|
sandbox: false,
|
||||||
targetDir: '/test/dir',
|
targetDir: '/test/dir',
|
||||||
debugMode: false,
|
debugMode: false,
|
||||||
|
|
|
@ -48,9 +48,11 @@ describe('Server Config (config.ts)', () => {
|
||||||
const USER_AGENT = 'ServerTestAgent/1.0';
|
const USER_AGENT = 'ServerTestAgent/1.0';
|
||||||
const USER_MEMORY = 'Test User Memory';
|
const USER_MEMORY = 'Test User Memory';
|
||||||
const TELEMETRY = false;
|
const TELEMETRY = false;
|
||||||
|
const EMBEDDING_MODEL = 'gemini-embedding';
|
||||||
const baseParams: ConfigParameters = {
|
const baseParams: ConfigParameters = {
|
||||||
apiKey: API_KEY,
|
apiKey: API_KEY,
|
||||||
model: MODEL,
|
model: MODEL,
|
||||||
|
embeddingModel: EMBEDDING_MODEL,
|
||||||
sandbox: SANDBOX,
|
sandbox: SANDBOX,
|
||||||
targetDir: TARGET_DIR,
|
targetDir: TARGET_DIR,
|
||||||
debugMode: DEBUG_MODE,
|
debugMode: DEBUG_MODE,
|
||||||
|
|
|
@ -55,6 +55,7 @@ export class MCPServerConfig {
|
||||||
export interface ConfigParameters {
|
export interface ConfigParameters {
|
||||||
apiKey: string;
|
apiKey: string;
|
||||||
model: string;
|
model: string;
|
||||||
|
embeddingModel: string;
|
||||||
sandbox: boolean | string;
|
sandbox: boolean | string;
|
||||||
targetDir: string;
|
targetDir: string;
|
||||||
debugMode: boolean;
|
debugMode: boolean;
|
||||||
|
@ -84,6 +85,7 @@ export class Config {
|
||||||
private toolRegistry: Promise<ToolRegistry>;
|
private toolRegistry: Promise<ToolRegistry>;
|
||||||
private readonly apiKey: string;
|
private readonly apiKey: string;
|
||||||
private readonly model: string;
|
private readonly model: string;
|
||||||
|
private readonly embeddingModel: string;
|
||||||
private readonly sandbox: boolean | string;
|
private readonly sandbox: boolean | string;
|
||||||
private readonly targetDir: string;
|
private readonly targetDir: string;
|
||||||
private readonly debugMode: boolean;
|
private readonly debugMode: boolean;
|
||||||
|
@ -113,6 +115,7 @@ export class Config {
|
||||||
constructor(params: ConfigParameters) {
|
constructor(params: ConfigParameters) {
|
||||||
this.apiKey = params.apiKey;
|
this.apiKey = params.apiKey;
|
||||||
this.model = params.model;
|
this.model = params.model;
|
||||||
|
this.embeddingModel = params.embeddingModel;
|
||||||
this.sandbox = params.sandbox;
|
this.sandbox = params.sandbox;
|
||||||
this.targetDir = path.resolve(params.targetDir);
|
this.targetDir = path.resolve(params.targetDir);
|
||||||
this.debugMode = params.debugMode;
|
this.debugMode = params.debugMode;
|
||||||
|
@ -163,6 +166,10 @@ export class Config {
|
||||||
return this.model;
|
return this.model;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getEmbeddingModel(): string {
|
||||||
|
return this.embeddingModel;
|
||||||
|
}
|
||||||
|
|
||||||
getSandbox(): boolean | string {
|
getSandbox(): boolean | string {
|
||||||
return this.sandbox;
|
return this.sandbox;
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,29 +6,23 @@
|
||||||
|
|
||||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||||
|
|
||||||
import { Chat, GenerateContentResponse } from '@google/genai';
|
import {
|
||||||
|
Chat,
|
||||||
|
EmbedContentResponse,
|
||||||
|
GenerateContentResponse,
|
||||||
|
GoogleGenAI,
|
||||||
|
} from '@google/genai';
|
||||||
|
import { GeminiClient } from './client.js';
|
||||||
|
import { Config } from '../config/config.js';
|
||||||
|
|
||||||
// --- Mocks ---
|
// --- Mocks ---
|
||||||
const mockChatCreateFn = vi.fn();
|
const mockChatCreateFn = vi.fn();
|
||||||
const mockGenerateContentFn = vi.fn();
|
const mockGenerateContentFn = vi.fn();
|
||||||
|
const mockEmbedContentFn = vi.fn();
|
||||||
|
|
||||||
vi.mock('@google/genai', async (importOriginal) => {
|
vi.mock('@google/genai');
|
||||||
const actual = await importOriginal<typeof import('@google/genai')>();
|
|
||||||
const MockedGoogleGenerativeAI = vi
|
|
||||||
.fn()
|
|
||||||
.mockImplementation((/*...args*/) => ({
|
|
||||||
chats: { create: mockChatCreateFn },
|
|
||||||
models: { generateContent: mockGenerateContentFn },
|
|
||||||
}));
|
|
||||||
return {
|
|
||||||
...actual,
|
|
||||||
GoogleGenerativeAI: MockedGoogleGenerativeAI,
|
|
||||||
Chat: vi.fn(),
|
|
||||||
Type: actual.Type ?? { OBJECT: 'OBJECT', STRING: 'STRING' },
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
vi.mock('../config/config');
|
vi.mock('../config/config.js');
|
||||||
vi.mock('./prompts');
|
vi.mock('./prompts');
|
||||||
vi.mock('../utils/getFolderStructure', () => ({
|
vi.mock('../utils/getFolderStructure', () => ({
|
||||||
getFolderStructure: vi.fn().mockResolvedValue('Mock Folder Structure'),
|
getFolderStructure: vi.fn().mockResolvedValue('Mock Folder Structure'),
|
||||||
|
@ -44,8 +38,24 @@ vi.mock('../utils/generateContentResponseUtilities', () => ({
|
||||||
}));
|
}));
|
||||||
|
|
||||||
describe('Gemini Client (client.ts)', () => {
|
describe('Gemini Client (client.ts)', () => {
|
||||||
|
let client: GeminiClient;
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.resetAllMocks();
|
vi.resetAllMocks();
|
||||||
|
|
||||||
|
// Set up the mock for GoogleGenAI constructor and its methods
|
||||||
|
const MockedGoogleGenAI = vi.mocked(GoogleGenAI);
|
||||||
|
MockedGoogleGenAI.mockImplementation(() => {
|
||||||
|
const mock = {
|
||||||
|
chats: { create: mockChatCreateFn },
|
||||||
|
models: {
|
||||||
|
generateContent: mockGenerateContentFn,
|
||||||
|
embedContent: mockEmbedContentFn,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
return mock as any;
|
||||||
|
});
|
||||||
|
|
||||||
mockChatCreateFn.mockResolvedValue({} as Chat);
|
mockChatCreateFn.mockResolvedValue({} as Chat);
|
||||||
mockGenerateContentFn.mockResolvedValue({
|
mockGenerateContentFn.mockResolvedValue({
|
||||||
candidates: [
|
candidates: [
|
||||||
|
@ -56,6 +66,35 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
} as unknown as GenerateContentResponse);
|
} as unknown as GenerateContentResponse);
|
||||||
|
|
||||||
|
// Because the GeminiClient constructor kicks off an async process (startChat)
|
||||||
|
// that depends on a fully-formed Config object, we need to mock the
|
||||||
|
// entire implementation of Config for these tests.
|
||||||
|
const mockToolRegistry = {
|
||||||
|
getFunctionDeclarations: vi.fn().mockReturnValue([]),
|
||||||
|
getTool: vi.fn().mockReturnValue(null),
|
||||||
|
};
|
||||||
|
const MockedConfig = vi.mocked(Config, true);
|
||||||
|
MockedConfig.mockImplementation(() => {
|
||||||
|
const mock = {
|
||||||
|
getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
|
||||||
|
getModel: vi.fn().mockReturnValue('test-model'),
|
||||||
|
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
|
||||||
|
getApiKey: vi.fn().mockReturnValue('test-key'),
|
||||||
|
getVertexAI: vi.fn().mockReturnValue(false),
|
||||||
|
getUserAgent: vi.fn().mockReturnValue('test-agent'),
|
||||||
|
getUserMemory: vi.fn().mockReturnValue(''),
|
||||||
|
getFullContext: vi.fn().mockReturnValue(false),
|
||||||
|
};
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
return mock as any;
|
||||||
|
});
|
||||||
|
|
||||||
|
// We can instantiate the client here since Config is mocked
|
||||||
|
// and the constructor will use the mocked GoogleGenAI
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
const mockConfig = new Config({} as any);
|
||||||
|
client = new GeminiClient(mockConfig);
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
@ -82,8 +121,97 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
// it('generateJson should call getCoreSystemPrompt with userMemory and pass to generateContent', async () => { ... });
|
// it('generateJson should call getCoreSystemPrompt with userMemory and pass to generateContent', async () => { ... });
|
||||||
// it('generateJson should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... });
|
// it('generateJson should call getCoreSystemPrompt with empty string if userMemory is empty', async () => { ... });
|
||||||
|
|
||||||
// Add a placeholder test to keep the suite valid
|
describe('generateEmbedding', () => {
|
||||||
it('should have a placeholder test', () => {
|
const texts = ['hello world', 'goodbye world'];
|
||||||
expect(true).toBe(true);
|
const testEmbeddingModel = 'test-embedding-model';
|
||||||
|
|
||||||
|
it('should call embedContent with correct parameters and return embeddings', async () => {
|
||||||
|
const mockEmbeddings = [
|
||||||
|
[0.1, 0.2, 0.3],
|
||||||
|
[0.4, 0.5, 0.6],
|
||||||
|
];
|
||||||
|
const mockResponse: EmbedContentResponse = {
|
||||||
|
embeddings: [
|
||||||
|
{ values: mockEmbeddings[0] },
|
||||||
|
{ values: mockEmbeddings[1] },
|
||||||
|
],
|
||||||
|
};
|
||||||
|
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||||
|
|
||||||
|
const result = await client.generateEmbedding(texts);
|
||||||
|
|
||||||
|
expect(mockEmbedContentFn).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mockEmbedContentFn).toHaveBeenCalledWith({
|
||||||
|
model: testEmbeddingModel,
|
||||||
|
contents: texts,
|
||||||
|
});
|
||||||
|
expect(result).toEqual(mockEmbeddings);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return an empty array if an empty array is passed', async () => {
|
||||||
|
const result = await client.generateEmbedding([]);
|
||||||
|
expect(result).toEqual([]);
|
||||||
|
expect(mockEmbedContentFn).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if API response has no embeddings array', async () => {
|
||||||
|
mockEmbedContentFn.mockResolvedValue({} as EmbedContentResponse); // No `embeddings` key
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'No embeddings found in API response.',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if API response has an empty embeddings array', async () => {
|
||||||
|
const mockResponse: EmbedContentResponse = {
|
||||||
|
embeddings: [],
|
||||||
|
};
|
||||||
|
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'No embeddings found in API response.',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if API returns a mismatched number of embeddings', async () => {
|
||||||
|
const mockResponse: EmbedContentResponse = {
|
||||||
|
embeddings: [{ values: [1, 2, 3] }], // Only one for two texts
|
||||||
|
};
|
||||||
|
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'API returned a mismatched number of embeddings. Expected 2, got 1.',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if any embedding has nullish values', async () => {
|
||||||
|
const mockResponse: EmbedContentResponse = {
|
||||||
|
embeddings: [{ values: [1, 2, 3] }, { values: undefined }], // Second one is bad
|
||||||
|
};
|
||||||
|
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'API returned an empty embedding for input text at index 1: "goodbye world"',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if any embedding has an empty values array', async () => {
|
||||||
|
const mockResponse: EmbedContentResponse = {
|
||||||
|
embeddings: [{ values: [] }, { values: [1, 2, 3] }], // First one is bad
|
||||||
|
};
|
||||||
|
mockEmbedContentFn.mockResolvedValue(mockResponse);
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'API returned an empty embedding for input text at index 0: "hello world"',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should propagate errors from the API call', async () => {
|
||||||
|
const apiError = new Error('API Failure');
|
||||||
|
mockEmbedContentFn.mockRejectedValue(apiError);
|
||||||
|
|
||||||
|
await expect(client.generateEmbedding(texts)).rejects.toThrow(
|
||||||
|
'API Failure',
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -5,6 +5,8 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import {
|
import {
|
||||||
|
EmbedContentResponse,
|
||||||
|
EmbedContentParameters,
|
||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
GoogleGenAI,
|
GoogleGenAI,
|
||||||
Part,
|
Part,
|
||||||
|
@ -38,6 +40,7 @@ export class GeminiClient {
|
||||||
private chat: Promise<GeminiChat>;
|
private chat: Promise<GeminiChat>;
|
||||||
private contentGenerator: ContentGenerator;
|
private contentGenerator: ContentGenerator;
|
||||||
private model: string;
|
private model: string;
|
||||||
|
private embeddingModel: string;
|
||||||
private generateContentConfig: GenerateContentConfig = {
|
private generateContentConfig: GenerateContentConfig = {
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
topP: 1,
|
topP: 1,
|
||||||
|
@ -60,6 +63,7 @@ export class GeminiClient {
|
||||||
});
|
});
|
||||||
this.contentGenerator = googleGenAI.models;
|
this.contentGenerator = googleGenAI.models;
|
||||||
this.model = config.getModel();
|
this.model = config.getModel();
|
||||||
|
this.embeddingModel = config.getEmbeddingModel();
|
||||||
this.chat = this.startChat();
|
this.chat = this.startChat();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -450,6 +454,40 @@ export class GeminiClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async generateEmbedding(texts: string[]): Promise<number[][]> {
|
||||||
|
if (!texts || texts.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const embedModelParams: EmbedContentParameters = {
|
||||||
|
model: this.embeddingModel,
|
||||||
|
contents: texts,
|
||||||
|
};
|
||||||
|
const embedContentResponse: EmbedContentResponse =
|
||||||
|
await this.contentGenerator.embedContent(embedModelParams);
|
||||||
|
if (
|
||||||
|
!embedContentResponse.embeddings ||
|
||||||
|
embedContentResponse.embeddings.length === 0
|
||||||
|
) {
|
||||||
|
throw new Error('No embeddings found in API response.');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (embedContentResponse.embeddings.length !== texts.length) {
|
||||||
|
throw new Error(
|
||||||
|
`API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return embedContentResponse.embeddings.map((embedding, index) => {
|
||||||
|
const values = embedding.values;
|
||||||
|
if (!values || values.length === 0) {
|
||||||
|
throw new Error(
|
||||||
|
`API returned an empty embedding for input text at index ${index}: "${texts[index]}"`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return values;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
private async tryCompressChat(): Promise<boolean> {
|
private async tryCompressChat(): Promise<boolean> {
|
||||||
const chat = await this.chat;
|
const chat = await this.chat;
|
||||||
const history = chat.getHistory(true); // Get curated history
|
const history = chat.getHistory(true); // Get curated history
|
||||||
|
|
|
@ -9,6 +9,8 @@ import {
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
GenerateContentParameters,
|
GenerateContentParameters,
|
||||||
CountTokensParameters,
|
CountTokensParameters,
|
||||||
|
EmbedContentResponse,
|
||||||
|
EmbedContentParameters,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -24,4 +26,6 @@ export interface ContentGenerator {
|
||||||
): Promise<AsyncGenerator<GenerateContentResponse>>;
|
): Promise<AsyncGenerator<GenerateContentResponse>>;
|
||||||
|
|
||||||
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
|
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
|
||||||
|
|
||||||
|
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -126,6 +126,7 @@ class MockTool extends BaseTool<{ param: string }, ToolResult> {
|
||||||
const baseConfigParams: ConfigParameters = {
|
const baseConfigParams: ConfigParameters = {
|
||||||
apiKey: 'test-api-key',
|
apiKey: 'test-api-key',
|
||||||
model: 'test-model',
|
model: 'test-model',
|
||||||
|
embeddingModel: 'test-embedding-model',
|
||||||
sandbox: false,
|
sandbox: false,
|
||||||
targetDir: '/test/dir',
|
targetDir: '/test/dir',
|
||||||
debugMode: false,
|
debugMode: false,
|
||||||
|
|
Loading…
Reference in New Issue