diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index dc77208c..46e5123c 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -21,6 +21,7 @@ import { WebFetchTool } from '../tools/web-fetch.js'; import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js'; import { WebSearchTool } from '../tools/web-search.js'; +import { GeminiClient } from '../core/client.js'; import { GEMINI_CONFIG_DIR as GEMINI_DIR } from '../tools/memoryTool.js'; export enum ApprovalMode { @@ -86,6 +87,7 @@ export class Config { private approvalMode: ApprovalMode; private readonly vertexai: boolean | undefined; private readonly showMemoryUsage: boolean; + private readonly geminiClient: GeminiClient; constructor(params: ConfigParameters) { this.apiKey = params.apiKey; @@ -112,6 +114,7 @@ export class Config { } this.toolRegistry = createToolRegistry(this); + this.geminiClient = new GeminiClient(this); } getApiKey(): string { @@ -200,6 +203,10 @@ export class Config { getShowMemoryUsage(): boolean { return this.showMemoryUsage; } + + getGeminiClient(): GeminiClient { + return this.geminiClient; + } } function findEnvFile(startDir: string): string | null { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index db30ac16..732126cb 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -12,6 +12,7 @@ import { PartListUnion, Content, Tool, + GenerateContentResponse, } from '@google/genai'; import process from 'node:process'; import { getFolderStructure } from '../utils/getFolderStructure.js'; @@ -262,4 +263,56 @@ export class GeminiClient { throw new Error(`Failed to generate JSON content: ${message}`); } } + + async generateContent( + contents: Content[], + generationConfig: GenerateContentConfig, + abortSignal: AbortSignal, + ): Promise { + const modelToUse = this.model; + const configToUse: GenerateContentConfig = { + ...this.generateContentConfig, + ...generationConfig, + }; + + try { + const userMemory = this.config.getUserMemory(); + const systemInstruction = getCoreSystemPrompt(userMemory); + + const requestConfig = { + abortSignal, + ...configToUse, + systemInstruction, + }; + + const apiCall = () => + this.client.models.generateContent({ + model: modelToUse, + config: requestConfig, + contents, + }); + + const result = await retryWithBackoff(apiCall); + return result; + } catch (error) { + if (abortSignal.aborted) { + throw error; + } + + await reportError( + error, + `Error generating content via API with model ${modelToUse}.`, + { + requestContents: contents, + requestConfig: configToUse, + }, + 'generateContent-api', + ); + const message = + error instanceof Error ? error.message : 'Unknown API error.'; + throw new Error( + `Failed to generate content with model ${modelToUse}: ${message}`, + ); + } + } } diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index 24617902..6a6048fc 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -4,13 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { GoogleGenAI, GroundingMetadata } from '@google/genai'; +import { GroundingMetadata } from '@google/genai'; import { SchemaValidator } from '../utils/schemaValidator.js'; import { BaseTool, ToolResult } from './tools.js'; import { getErrorMessage } from '../utils/errors.js'; import { Config } from '../config/config.js'; import { getResponseText } from '../utils/generateContentResponseUtilities.js'; -import { retryWithBackoff } from '../utils/retry.js'; // Interfaces for grounding metadata (similar to web-search.ts) interface GroundingChunkWeb { @@ -49,9 +48,6 @@ export interface WebFetchToolParams { export class WebFetchTool extends BaseTool { static readonly Name: string = 'web_fetch'; - private ai: GoogleGenAI; - private modelName: string; - constructor(private readonly config: Config) { super( WebFetchTool.Name, @@ -69,12 +65,6 @@ export class WebFetchTool extends BaseTool { type: 'object', }, ); - - const apiKeyFromConfig = this.config.getApiKey(); - this.ai = new GoogleGenAI({ - apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig, - }); - this.modelName = this.config.getModel(); } validateParams(params: WebFetchToolParams): string | null { @@ -109,7 +99,7 @@ export class WebFetchTool extends BaseTool { async execute( params: WebFetchToolParams, - _signal: AbortSignal, + signal: AbortSignal, ): Promise { const validationError = this.validateParams(params); if (validationError) { @@ -120,23 +110,14 @@ export class WebFetchTool extends BaseTool { } const userPrompt = params.prompt; + const geminiClient = this.config.getGeminiClient(); try { - const apiCall = () => - this.ai.models.generateContent({ - model: this.modelName, - contents: [ - { - role: 'user', - parts: [{ text: userPrompt }], - }, - ], - config: { - tools: [{ urlContext: {} }], - }, - }); - - const response = await retryWithBackoff(apiCall); + const response = await geminiClient.generateContent( + [{ role: 'user', parts: [{ text: userPrompt }] }], + { tools: [{ urlContext: {} }] }, + signal, // Pass signal + ); console.debug( `[WebFetchTool] Full response for prompt "${userPrompt.substring(0, 50)}...":`, diff --git a/packages/core/src/tools/web-search.ts b/packages/core/src/tools/web-search.ts index ed2f341f..c4dcc54a 100644 --- a/packages/core/src/tools/web-search.ts +++ b/packages/core/src/tools/web-search.ts @@ -4,14 +4,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { GoogleGenAI, GroundingMetadata } from '@google/genai'; +import { GroundingMetadata } from '@google/genai'; import { BaseTool, ToolResult } from './tools.js'; import { SchemaValidator } from '../utils/schemaValidator.js'; import { getErrorMessage } from '../utils/errors.js'; import { Config } from '../config/config.js'; import { getResponseText } from '../utils/generateContentResponseUtilities.js'; -import { retryWithBackoff } from '../utils/retry.js'; interface GroundingChunkWeb { uri?: string; @@ -64,9 +63,6 @@ export class WebSearchTool extends BaseTool< > { static readonly Name: string = 'google_web_search'; - private ai: GoogleGenAI; - private modelName: string; - constructor(private readonly config: Config) { super( WebSearchTool.Name, @@ -83,13 +79,6 @@ export class WebSearchTool extends BaseTool< required: ['query'], }, ); - - const apiKeyFromConfig = this.config.getApiKey(); - // Initialize GoogleGenAI, allowing fallback to environment variables for API key - this.ai = new GoogleGenAI({ - apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig, - }); - this.modelName = this.config.getModel(); } validateParams(params: WebSearchToolParams): string | null { @@ -112,7 +101,10 @@ export class WebSearchTool extends BaseTool< return `Searching the web for: "${params.query}"`; } - async execute(params: WebSearchToolParams): Promise { + async execute( + params: WebSearchToolParams, + signal: AbortSignal, + ): Promise { const validationError = this.validateParams(params); if (validationError) { return { @@ -120,18 +112,14 @@ export class WebSearchTool extends BaseTool< returnDisplay: validationError, }; } + const geminiClient = this.config.getGeminiClient(); try { - const apiCall = () => - this.ai.models.generateContent({ - model: this.modelName, - contents: [{ role: 'user', parts: [{ text: params.query }] }], - config: { - tools: [{ googleSearch: {} }], - }, - }); - - const response = await retryWithBackoff(apiCall); + const response = await geminiClient.generateContent( + [{ role: 'user', parts: [{ text: params.query }] }], + { tools: [{ googleSearch: {} }] }, + signal, + ); const responseText = getResponseText(response); const groundingMetadata = response.candidates?.[0]?.groundingMetadata;