Refactor: Centralize GeminiClient in Config (#693)
Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
parent
1dcf0a4cbd
commit
e428707e07
|
@ -21,6 +21,7 @@ import { WebFetchTool } from '../tools/web-fetch.js';
|
||||||
import { ReadManyFilesTool } from '../tools/read-many-files.js';
|
import { ReadManyFilesTool } from '../tools/read-many-files.js';
|
||||||
import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js';
|
import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js';
|
||||||
import { WebSearchTool } from '../tools/web-search.js';
|
import { WebSearchTool } from '../tools/web-search.js';
|
||||||
|
import { GeminiClient } from '../core/client.js';
|
||||||
import { GEMINI_CONFIG_DIR as GEMINI_DIR } from '../tools/memoryTool.js';
|
import { GEMINI_CONFIG_DIR as GEMINI_DIR } from '../tools/memoryTool.js';
|
||||||
|
|
||||||
export enum ApprovalMode {
|
export enum ApprovalMode {
|
||||||
|
@ -86,6 +87,7 @@ export class Config {
|
||||||
private approvalMode: ApprovalMode;
|
private approvalMode: ApprovalMode;
|
||||||
private readonly vertexai: boolean | undefined;
|
private readonly vertexai: boolean | undefined;
|
||||||
private readonly showMemoryUsage: boolean;
|
private readonly showMemoryUsage: boolean;
|
||||||
|
private readonly geminiClient: GeminiClient;
|
||||||
|
|
||||||
constructor(params: ConfigParameters) {
|
constructor(params: ConfigParameters) {
|
||||||
this.apiKey = params.apiKey;
|
this.apiKey = params.apiKey;
|
||||||
|
@ -112,6 +114,7 @@ export class Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
this.toolRegistry = createToolRegistry(this);
|
this.toolRegistry = createToolRegistry(this);
|
||||||
|
this.geminiClient = new GeminiClient(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
getApiKey(): string {
|
getApiKey(): string {
|
||||||
|
@ -200,6 +203,10 @@ export class Config {
|
||||||
getShowMemoryUsage(): boolean {
|
getShowMemoryUsage(): boolean {
|
||||||
return this.showMemoryUsage;
|
return this.showMemoryUsage;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getGeminiClient(): GeminiClient {
|
||||||
|
return this.geminiClient;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function findEnvFile(startDir: string): string | null {
|
function findEnvFile(startDir: string): string | null {
|
||||||
|
|
|
@ -12,6 +12,7 @@ import {
|
||||||
PartListUnion,
|
PartListUnion,
|
||||||
Content,
|
Content,
|
||||||
Tool,
|
Tool,
|
||||||
|
GenerateContentResponse,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import process from 'node:process';
|
import process from 'node:process';
|
||||||
import { getFolderStructure } from '../utils/getFolderStructure.js';
|
import { getFolderStructure } from '../utils/getFolderStructure.js';
|
||||||
|
@ -262,4 +263,56 @@ export class GeminiClient {
|
||||||
throw new Error(`Failed to generate JSON content: ${message}`);
|
throw new Error(`Failed to generate JSON content: ${message}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async generateContent(
|
||||||
|
contents: Content[],
|
||||||
|
generationConfig: GenerateContentConfig,
|
||||||
|
abortSignal: AbortSignal,
|
||||||
|
): Promise<GenerateContentResponse> {
|
||||||
|
const modelToUse = this.model;
|
||||||
|
const configToUse: GenerateContentConfig = {
|
||||||
|
...this.generateContentConfig,
|
||||||
|
...generationConfig,
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
const userMemory = this.config.getUserMemory();
|
||||||
|
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||||
|
|
||||||
|
const requestConfig = {
|
||||||
|
abortSignal,
|
||||||
|
...configToUse,
|
||||||
|
systemInstruction,
|
||||||
|
};
|
||||||
|
|
||||||
|
const apiCall = () =>
|
||||||
|
this.client.models.generateContent({
|
||||||
|
model: modelToUse,
|
||||||
|
config: requestConfig,
|
||||||
|
contents,
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await retryWithBackoff(apiCall);
|
||||||
|
return result;
|
||||||
|
} catch (error) {
|
||||||
|
if (abortSignal.aborted) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
await reportError(
|
||||||
|
error,
|
||||||
|
`Error generating content via API with model ${modelToUse}.`,
|
||||||
|
{
|
||||||
|
requestContents: contents,
|
||||||
|
requestConfig: configToUse,
|
||||||
|
},
|
||||||
|
'generateContent-api',
|
||||||
|
);
|
||||||
|
const message =
|
||||||
|
error instanceof Error ? error.message : 'Unknown API error.';
|
||||||
|
throw new Error(
|
||||||
|
`Failed to generate content with model ${modelToUse}: ${message}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,13 +4,12 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { GoogleGenAI, GroundingMetadata } from '@google/genai';
|
import { GroundingMetadata } from '@google/genai';
|
||||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||||
import { BaseTool, ToolResult } from './tools.js';
|
import { BaseTool, ToolResult } from './tools.js';
|
||||||
import { getErrorMessage } from '../utils/errors.js';
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
||||||
import { retryWithBackoff } from '../utils/retry.js';
|
|
||||||
|
|
||||||
// Interfaces for grounding metadata (similar to web-search.ts)
|
// Interfaces for grounding metadata (similar to web-search.ts)
|
||||||
interface GroundingChunkWeb {
|
interface GroundingChunkWeb {
|
||||||
|
@ -49,9 +48,6 @@ export interface WebFetchToolParams {
|
||||||
export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
|
export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
|
||||||
static readonly Name: string = 'web_fetch';
|
static readonly Name: string = 'web_fetch';
|
||||||
|
|
||||||
private ai: GoogleGenAI;
|
|
||||||
private modelName: string;
|
|
||||||
|
|
||||||
constructor(private readonly config: Config) {
|
constructor(private readonly config: Config) {
|
||||||
super(
|
super(
|
||||||
WebFetchTool.Name,
|
WebFetchTool.Name,
|
||||||
|
@ -69,12 +65,6 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
|
||||||
type: 'object',
|
type: 'object',
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
const apiKeyFromConfig = this.config.getApiKey();
|
|
||||||
this.ai = new GoogleGenAI({
|
|
||||||
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
|
|
||||||
});
|
|
||||||
this.modelName = this.config.getModel();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
validateParams(params: WebFetchToolParams): string | null {
|
validateParams(params: WebFetchToolParams): string | null {
|
||||||
|
@ -109,7 +99,7 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
|
||||||
|
|
||||||
async execute(
|
async execute(
|
||||||
params: WebFetchToolParams,
|
params: WebFetchToolParams,
|
||||||
_signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
): Promise<ToolResult> {
|
): Promise<ToolResult> {
|
||||||
const validationError = this.validateParams(params);
|
const validationError = this.validateParams(params);
|
||||||
if (validationError) {
|
if (validationError) {
|
||||||
|
@ -120,23 +110,14 @@ export class WebFetchTool extends BaseTool<WebFetchToolParams, ToolResult> {
|
||||||
}
|
}
|
||||||
|
|
||||||
const userPrompt = params.prompt;
|
const userPrompt = params.prompt;
|
||||||
|
const geminiClient = this.config.getGeminiClient();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const apiCall = () =>
|
const response = await geminiClient.generateContent(
|
||||||
this.ai.models.generateContent({
|
[{ role: 'user', parts: [{ text: userPrompt }] }],
|
||||||
model: this.modelName,
|
{ tools: [{ urlContext: {} }] },
|
||||||
contents: [
|
signal, // Pass signal
|
||||||
{
|
);
|
||||||
role: 'user',
|
|
||||||
parts: [{ text: userPrompt }],
|
|
||||||
},
|
|
||||||
],
|
|
||||||
config: {
|
|
||||||
tools: [{ urlContext: {} }],
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const response = await retryWithBackoff(apiCall);
|
|
||||||
|
|
||||||
console.debug(
|
console.debug(
|
||||||
`[WebFetchTool] Full response for prompt "${userPrompt.substring(0, 50)}...":`,
|
`[WebFetchTool] Full response for prompt "${userPrompt.substring(0, 50)}...":`,
|
||||||
|
|
|
@ -4,14 +4,13 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { GoogleGenAI, GroundingMetadata } from '@google/genai';
|
import { GroundingMetadata } from '@google/genai';
|
||||||
import { BaseTool, ToolResult } from './tools.js';
|
import { BaseTool, ToolResult } from './tools.js';
|
||||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||||
|
|
||||||
import { getErrorMessage } from '../utils/errors.js';
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
||||||
import { retryWithBackoff } from '../utils/retry.js';
|
|
||||||
|
|
||||||
interface GroundingChunkWeb {
|
interface GroundingChunkWeb {
|
||||||
uri?: string;
|
uri?: string;
|
||||||
|
@ -64,9 +63,6 @@ export class WebSearchTool extends BaseTool<
|
||||||
> {
|
> {
|
||||||
static readonly Name: string = 'google_web_search';
|
static readonly Name: string = 'google_web_search';
|
||||||
|
|
||||||
private ai: GoogleGenAI;
|
|
||||||
private modelName: string;
|
|
||||||
|
|
||||||
constructor(private readonly config: Config) {
|
constructor(private readonly config: Config) {
|
||||||
super(
|
super(
|
||||||
WebSearchTool.Name,
|
WebSearchTool.Name,
|
||||||
|
@ -83,13 +79,6 @@ export class WebSearchTool extends BaseTool<
|
||||||
required: ['query'],
|
required: ['query'],
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
const apiKeyFromConfig = this.config.getApiKey();
|
|
||||||
// Initialize GoogleGenAI, allowing fallback to environment variables for API key
|
|
||||||
this.ai = new GoogleGenAI({
|
|
||||||
apiKey: apiKeyFromConfig === '' ? undefined : apiKeyFromConfig,
|
|
||||||
});
|
|
||||||
this.modelName = this.config.getModel();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
validateParams(params: WebSearchToolParams): string | null {
|
validateParams(params: WebSearchToolParams): string | null {
|
||||||
|
@ -112,7 +101,10 @@ export class WebSearchTool extends BaseTool<
|
||||||
return `Searching the web for: "${params.query}"`;
|
return `Searching the web for: "${params.query}"`;
|
||||||
}
|
}
|
||||||
|
|
||||||
async execute(params: WebSearchToolParams): Promise<WebSearchToolResult> {
|
async execute(
|
||||||
|
params: WebSearchToolParams,
|
||||||
|
signal: AbortSignal,
|
||||||
|
): Promise<WebSearchToolResult> {
|
||||||
const validationError = this.validateParams(params);
|
const validationError = this.validateParams(params);
|
||||||
if (validationError) {
|
if (validationError) {
|
||||||
return {
|
return {
|
||||||
|
@ -120,18 +112,14 @@ export class WebSearchTool extends BaseTool<
|
||||||
returnDisplay: validationError,
|
returnDisplay: validationError,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
const geminiClient = this.config.getGeminiClient();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const apiCall = () =>
|
const response = await geminiClient.generateContent(
|
||||||
this.ai.models.generateContent({
|
[{ role: 'user', parts: [{ text: params.query }] }],
|
||||||
model: this.modelName,
|
{ tools: [{ googleSearch: {} }] },
|
||||||
contents: [{ role: 'user', parts: [{ text: params.query }] }],
|
signal,
|
||||||
config: {
|
);
|
||||||
tools: [{ googleSearch: {} }],
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const response = await retryWithBackoff(apiCall);
|
|
||||||
|
|
||||||
const responseText = getResponseText(response);
|
const responseText = getResponseText(response);
|
||||||
const groundingMetadata = response.candidates?.[0]?.groundingMetadata;
|
const groundingMetadata = response.candidates?.[0]?.groundingMetadata;
|
||||||
|
|
Loading…
Reference in New Issue