Refactor: Centralize GeminiClient in Config (#693)

Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
Scott Densmore 2025-06-02 14:55:51 -07:00 committed by GitHub
parent 1dcf0a4cbd
commit e428707e07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 79 additions and 50 deletions

View File

@ -21,6 +21,7 @@ import { WebFetchTool } from '../tools/web-fetch.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js'; import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js';
import { WebSearchTool } from '../tools/web-search.js'; import { WebSearchTool } from '../tools/web-search.js';
import { GeminiClient } from '../core/client.js';
import { GEMINI_CONFIG_DIR as GEMINI_DIR } from '../tools/memoryTool.js'; import { GEMINI_CONFIG_DIR as GEMINI_DIR } from '../tools/memoryTool.js';
export enum ApprovalMode { export enum ApprovalMode {
@ -86,6 +87,7 @@ export class Config {
private approvalMode: ApprovalMode; private approvalMode: ApprovalMode;
private readonly vertexai: boolean | undefined; private readonly vertexai: boolean | undefined;
private readonly showMemoryUsage: boolean; private readonly showMemoryUsage: boolean;
private readonly geminiClient: GeminiClient;
constructor(params: ConfigParameters) { constructor(params: ConfigParameters) {
this.apiKey = params.apiKey; this.apiKey = params.apiKey;
@ -112,6 +114,7 @@ export class Config {
} }
this.toolRegistry = createToolRegistry(this); this.toolRegistry = createToolRegistry(this);
this.geminiClient = new GeminiClient(this);
} }
getApiKey(): string { getApiKey(): string {
@ -200,6 +203,10 @@ export class Config {
getShowMemoryUsage(): boolean { getShowMemoryUsage(): boolean {
return this.showMemoryUsage; return this.showMemoryUsage;
} }
getGeminiClient(): GeminiClient {
return this.geminiClient;
}
} }
function findEnvFile(startDir: string): string | null { function findEnvFile(startDir: string): string | null {

View File

@ -12,6 +12,7 @@ import {
PartListUnion, PartListUnion,
Content, Content,
Tool, Tool,
GenerateContentResponse,
} from '@google/genai'; } from '@google/genai';
import process from 'node:process'; import process from 'node:process';
import { getFolderStructure } from '../utils/getFolderStructure.js'; import { getFolderStructure } from '../utils/getFolderStructure.js';
@ -262,4 +263,56 @@ export class GeminiClient {
throw new Error(`Failed to generate JSON content: ${message}`); throw new Error(`Failed to generate JSON content: ${message}`);
} }
} }
async generateContent(
contents: Content[],
generationConfig: GenerateContentConfig,
abortSignal: AbortSignal,
): Promise<GenerateContentResponse> {
const modelToUse = this.model;
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
...generationConfig,
};
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
const requestConfig = {
abortSignal,
...configToUse,
systemInstruction,
};
const apiCall = () =>
this.client.models.generateContent({
model: modelToUse,
config: requestConfig,
contents,
});
const result = await retryWithBackoff(apiCall);
return result;
} catch (error) {
if (abortSignal.aborted) {
throw error;
}
await reportError(
error,
`Error generating content via API with model ${modelToUse}.`,
{
requestContents: contents,
requestConfig: configToUse,
},
'generateContent-api',
);
const message =
error instanceof Error ? error.message : 'Unknown API error.';
throw new Error(
`Failed to generate content with model ${modelToUse}: ${message}`,
);
}
}
} }

View File

@ -4,13 +4,12 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { GoogleGenAI, GroundingMetadata } from '@google/genai'; import { GroundingMetadata } from '@google/genai';
import { SchemaValidator } from '../utils/schemaValidator.js'; import { SchemaValidator } from '../utils/schemaValidator.js';
import { BaseTool, ToolResult } from './tools.js'; import { BaseTool, ToolResult } from './tools.js';
import { getErrorMessage } from '../utils/errors.js'; import { getErrorMessage } from '../utils/errors.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js'; import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { retryWithBackoff } from '../utils/retry.js';
// Interfaces for grounding metadata (similar to web-search.ts) // Interfaces for grounding metadata (similar to web-search.ts)
interface GroundingChunkWeb { interface GroundingChunkWeb {
@ -49,9 +48,6 @@ export interface WebFetchToolParams {
export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> { export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
static readonly Name: string = 'web_fetch'; static readonly Name: string = 'web_fetch';
private ai: GoogleGenAI;
private modelName: string;
constructor(private readonly config: Config) { constructor(private readonly config: Config) {
super( super(
WebFetchTool.Name, WebFetchTool.Name,
@ -69,12 +65,6 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
type: 'object', type: 'object',
}, },
); );
const apiKeyFromConfig = this.config.getApiKey();
this.ai = new GoogleGenAI({
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
});
this.modelName = this.config.getModel();
} }
validateParams(params: WebFetchToolParams): string | null { validateParams(params: WebFetchToolParams): string | null {
@ -109,7 +99,7 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
async execute( async execute(
params: WebFetchToolParams, params: WebFetchToolParams,
_signal: AbortSignal, signal: AbortSignal,
): Promise<ToolResult> { ): Promise<ToolResult> {
const validationError = this.validateParams(params); const validationError = this.validateParams(params);
if (validationError) { if (validationError) {
@ -120,23 +110,14 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
} }
const userPrompt = params.prompt; const userPrompt = params.prompt;
const geminiClient = this.config.getGeminiClient();
try { try {
const apiCall = () => const response = await geminiClient.generateContent(
this.ai.models.generateContent({ [{ role: 'user', parts: [{ text: userPrompt }] }],
model: this.modelName, { tools: [{ urlContext: {} }] },
contents: [ signal, // Pass signal
{ );
role: 'user',
parts: [{ text: userPrompt }],
},
],
config: {
tools: [{ urlContext: {} }],
},
});
const response = await retryWithBackoff(apiCall);
console.debug( console.debug(
`[WebFetchTool] Full response for prompt "${userPrompt.substring(0, 50)}...":`, `[WebFetchTool] Full response for prompt "${userPrompt.substring(0, 50)}...":`,

View File

@ -4,14 +4,13 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { GoogleGenAI, GroundingMetadata } from '@google/genai'; import { GroundingMetadata } from '@google/genai';
import { BaseTool, ToolResult } from './tools.js'; import { BaseTool, ToolResult } from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js'; import { SchemaValidator } from '../utils/schemaValidator.js';
import { getErrorMessage } from '../utils/errors.js'; import { getErrorMessage } from '../utils/errors.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js'; import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { retryWithBackoff } from '../utils/retry.js';
interface GroundingChunkWeb { interface GroundingChunkWeb {
uri?: string; uri?: string;
@ -64,9 +63,6 @@ export class WebSearchTool extends BaseTool<
> { > {
static readonly Name: string = 'google_web_search'; static readonly Name: string = 'google_web_search';
private ai: GoogleGenAI;
private modelName: string;
constructor(private readonly config: Config) { constructor(private readonly config: Config) {
super( super(
WebSearchTool.Name, WebSearchTool.Name,
@ -83,13 +79,6 @@ export class WebSearchTool extends BaseTool<
required: ['query'], required: ['query'],
}, },
); );
const apiKeyFromConfig = this.config.getApiKey();
// Initialize GoogleGenAI, allowing fallback to environment variables for API key
this.ai = new GoogleGenAI({
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
});
this.modelName = this.config.getModel();
} }
validateParams(params: WebSearchToolParams): string | null { validateParams(params: WebSearchToolParams): string | null {
@ -112,7 +101,10 @@ export class WebSearchTool extends BaseTool<
return `Searching the web for: "${params.query}"`; return `Searching the web for: "${params.query}"`;
} }
async execute(params: WebSearchToolParams): Promise<WebSearchToolResult> { async execute(
params: WebSearchToolParams,
signal: AbortSignal,
): Promise<WebSearchToolResult> {
const validationError = this.validateParams(params); const validationError = this.validateParams(params);
if (validationError) { if (validationError) {
return { return {
@ -120,18 +112,14 @@ export class WebSearchTool extends BaseTool<
returnDisplay: validationError, returnDisplay: validationError,
}; };
} }
const geminiClient = this.config.getGeminiClient();
try { try {
const apiCall = () => const response = await geminiClient.generateContent(
this.ai.models.generateContent({ [{ role: 'user', parts: [{ text: params.query }] }],
model: this.modelName, { tools: [{ googleSearch: {} }] },
contents: [{ role: 'user', parts: [{ text: params.query }] }], signal,
config: { );
tools: [{ googleSearch: {} }],
},
});
const response = await retryWithBackoff(apiCall);
const responseText = getResponseText(response); const responseText = getResponseText(response);
const groundingMetadata = response.candidates?.[0]?.groundingMetadata; const groundingMetadata = response.candidates?.[0]?.groundingMetadata;