From 357546a2aac918702f6ebfa4a97bd95ccd614e5d Mon Sep 17 00:00:00 2001 From: Tommaso Sciortino Date: Mon, 7 Jul 2025 15:01:59 -0700 Subject: [PATCH] Initialize MCP tools once at start up instead of every time we auth. (#3483) --- packages/cli/src/gemini.tsx | 17 +-- packages/core/src/code_assist/converter.ts | 1 + packages/core/src/config/config.ts | 125 ++++++++++----------- packages/core/src/core/contentGenerator.ts | 3 +- packages/core/src/tools/edit.test.ts | 10 +- packages/core/src/tools/edit.ts | 9 +- packages/core/src/tools/tool-registry.ts | 2 - packages/core/src/tools/write-file.ts | 8 +- 8 files changed, 76 insertions(+), 99 deletions(-) diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 11ae1505..39d3bbe3 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -115,15 +115,7 @@ export async function main() { setMaxSizedBoxDebugging(config.getDebugMode()); - // Initialize centralized FileDiscoveryService - config.getFileService(); - if (config.getCheckpointingEnabled()) { - try { - await config.getGitService(); - } catch { - // For now swallow the error, later log it. - } - } + await config.initialize(); if (settings.merged.theme) { if (!themeManager.setActiveTheme(settings.merged.theme)) { @@ -133,12 +125,11 @@ export async function main() { } } - const memoryArgs = settings.merged.autoConfigureMaxOldSpaceSize - ? getNodeMemoryArgs(config) - : []; - // hop into sandbox if we are outside and sandboxing is enabled if (!process.env.SANDBOX) { + const memoryArgs = settings.merged.autoConfigureMaxOldSpaceSize + ? getNodeMemoryArgs(config) + : []; const sandboxConfig = config.getSandbox(); if (sandboxConfig) { if (settings.merged.selectedAuthType) { diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts index b27617c4..8340cfc1 100644 --- a/packages/core/src/code_assist/converter.ts +++ b/packages/core/src/code_assist/converter.ts @@ -80,6 +80,7 @@ interface VertexGenerateContentResponse { promptFeedback?: GenerateContentResponsePromptFeedback; usageMetadata?: GenerateContentResponseUsageMetadata; } + export interface CaCountTokenRequest { request: VertexCountTokenRequest; } diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index fd96af91..ca0714f0 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -232,32 +232,30 @@ export class Config { } } + async initialize(): Promise { + // Initialize centralized FileDiscoveryService + this.getFileService(); + if (this.getCheckpointingEnabled()) { + try { + await this.getGitService(); + } catch { + // For now swallow the error, later log it. + } + } + this.toolRegistry = await this.createToolRegistry(); + } + async refreshAuth(authMethod: AuthType) { - // Always use the original default model when switching auth methods - // This ensures users don't stay on Flash after switching between auth types - // and allows API key users to get proper fallback behavior from getEffectiveModel - const modelToUse = this.model; // Use the original default model - - // Temporarily clear contentGeneratorConfig to prevent getModel() from returning - // the previous session's model (which might be Flash) - this.contentGeneratorConfig = undefined!; - - const contentConfig = await createContentGeneratorConfig( - modelToUse, + this.contentGeneratorConfig = await createContentGeneratorConfig( + this.model, authMethod, - this, ); - const gc = new GeminiClient(this); - this.geminiClient = gc; - this.toolRegistry = await createToolRegistry(this); - await gc.initialize(contentConfig); - this.contentGeneratorConfig = contentConfig; + this.geminiClient = new GeminiClient(this); + await this.geminiClient.initialize(this.contentGeneratorConfig); // Reset the session flag since we're explicitly changing auth and using default model this.modelSwitchedDuringSession = false; - - // Note: In the future, we may want to reset any cached state when switching auth methods } getSessionId(): string { @@ -469,58 +467,59 @@ export class Config { return { memoryContent, fileCount }; } -} -export function createToolRegistry(config: Config): Promise { - const registry = new ToolRegistry(config); - const targetDir = config.getTargetDir(); + async createToolRegistry(): Promise { + const registry = new ToolRegistry(this); + const targetDir = this.getTargetDir(); - // helper to create & register core tools that are enabled - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const registerCoreTool = (ToolClass: any, ...args: unknown[]) => { - const className = ToolClass.name; - const toolName = ToolClass.Name || className; - const coreTools = config.getCoreTools(); - const excludeTools = config.getExcludeTools(); + // helper to create & register core tools that are enabled + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const registerCoreTool = (ToolClass: any, ...args: unknown[]) => { + const className = ToolClass.name; + const toolName = ToolClass.Name || className; + const coreTools = this.getCoreTools(); + const excludeTools = this.getExcludeTools(); - let isEnabled = false; - if (coreTools === undefined) { - isEnabled = true; - } else { - isEnabled = coreTools.some( - (tool) => - tool === className || - tool === toolName || - tool.startsWith(`${className}(`) || - tool.startsWith(`${toolName}(`), - ); - } + let isEnabled = false; + if (coreTools === undefined) { + isEnabled = true; + } else { + isEnabled = coreTools.some( + (tool) => + tool === className || + tool === toolName || + tool.startsWith(`${className}(`) || + tool.startsWith(`${toolName}(`), + ); + } - if (excludeTools?.includes(className) || excludeTools?.includes(toolName)) { - isEnabled = false; - } + if ( + excludeTools?.includes(className) || + excludeTools?.includes(toolName) + ) { + isEnabled = false; + } - if (isEnabled) { - registry.registerTool(new ToolClass(...args)); - } - }; + if (isEnabled) { + registry.registerTool(new ToolClass(...args)); + } + }; + + registerCoreTool(LSTool, targetDir, this); + registerCoreTool(ReadFileTool, targetDir, this); + registerCoreTool(GrepTool, targetDir); + registerCoreTool(GlobTool, targetDir, this); + registerCoreTool(EditTool, this); + registerCoreTool(WriteFileTool, this); + registerCoreTool(WebFetchTool, this); + registerCoreTool(ReadManyFilesTool, targetDir, this); + registerCoreTool(ShellTool, this); + registerCoreTool(MemoryTool); + registerCoreTool(WebSearchTool, this); - registerCoreTool(LSTool, targetDir, config); - registerCoreTool(ReadFileTool, targetDir, config); - registerCoreTool(GrepTool, targetDir); - registerCoreTool(GlobTool, targetDir, config); - registerCoreTool(EditTool, config); - registerCoreTool(WriteFileTool, config); - registerCoreTool(WebFetchTool, config); - registerCoreTool(ReadManyFilesTool, targetDir, config); - registerCoreTool(ShellTool, config); - registerCoreTool(MemoryTool); - registerCoreTool(WebSearchTool, config); - return (async () => { await registry.discoverTools(); return registry; - })(); + } } - // Export model constants for use in CLI export { DEFAULT_GEMINI_FLASH_MODEL }; diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index f0c163d2..1b22333a 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -50,7 +50,6 @@ export type ContentGeneratorConfig = { export async function createContentGeneratorConfig( model: string | undefined, authType: AuthType | undefined, - config?: { getModel?: () => string }, ): Promise { const geminiApiKey = process.env.GEMINI_API_KEY; const googleApiKey = process.env.GOOGLE_API_KEY; @@ -58,7 +57,7 @@ export async function createContentGeneratorConfig( const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION; // Use runtime model from config if available, otherwise fallback to parameter or default - const effectiveModel = config?.getModel?.() || model || DEFAULT_GEMINI_MODEL; + const effectiveModel = model || DEFAULT_GEMINI_MODEL; const contentGeneratorConfig: ContentGeneratorConfig = { model: effectiveModel, diff --git a/packages/core/src/tools/edit.test.ts b/packages/core/src/tools/edit.test.ts index ab42450a..84ad1daf 100644 --- a/packages/core/src/tools/edit.test.ts +++ b/packages/core/src/tools/edit.test.ts @@ -38,21 +38,19 @@ describe('EditTool', () => { let tempDir: string; let rootDir: string; let mockConfig: Config; + let geminiClient: any; beforeEach(() => { tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'edit-tool-test-')); rootDir = path.join(tempDir, 'root'); fs.mkdirSync(rootDir); - // The client instance that EditTool will use - const mockClientInstanceWithGenerateJson = { + geminiClient = { generateJson: mockGenerateJson, // mockGenerateJson is already defined and hoisted }; mockConfig = { - getGeminiClient: vi - .fn() - .mockReturnValue(mockClientInstanceWithGenerateJson), + getGeminiClient: vi.fn().mockReturnValue(geminiClient), getTargetDir: () => rootDir, getApprovalMode: vi.fn(), setApprovalMode: vi.fn(), @@ -339,7 +337,7 @@ describe('EditTool', () => { mockCalled = true; expect(content).toBe(originalContent); expect(p).toBe(params); - expect(client).toBe((tool as any).client); + expect(client).toBe(geminiClient); return { params: { file_path: filePath, diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index 2df01a22..f388b9f5 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -18,7 +18,6 @@ import { import { SchemaValidator } from '../utils/schemaValidator.js'; import { makeRelative, shortenPath } from '../utils/paths.js'; import { isNodeError } from '../utils/errors.js'; -import { GeminiClient } from '../core/client.js'; import { Config, ApprovalMode } from '../config/config.js'; import { ensureCorrectEdit } from '../utils/editCorrector.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; @@ -72,15 +71,13 @@ export class EditTool implements ModifiableTool { static readonly Name = 'replace'; - private readonly config: Config; private readonly rootDirectory: string; - private readonly client: GeminiClient; /** * Creates a new instance of the EditLogic * @param rootDirectory Root directory to ground this tool in. */ - constructor(config: Config) { + constructor(private readonly config: Config) { super( EditTool.Name, 'Edit', @@ -123,9 +120,7 @@ Expectation for required parameters: type: 'object', }, ); - this.config = config; this.rootDirectory = path.resolve(this.config.getTargetDir()); - this.client = config.getGeminiClient(); } /** @@ -239,7 +234,7 @@ Expectation for required parameters: params.file_path, currentContent, params, - this.client, + this.config.getGeminiClient(), abortSignal, ); finalOldString = correctedEdit.params.old_string; diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index bc628f03..1778c6d6 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -154,8 +154,6 @@ export class ToolRegistry { for (const tool of this.tools.values()) { if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) { this.tools.delete(tool.name); - } else { - // Keep manually registered tools } } diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index 37a1ba78..ab30891b 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -23,7 +23,6 @@ import { ensureCorrectEdit, ensureCorrectFileContent, } from '../utils/editCorrector.js'; -import { GeminiClient } from '../core/client.js'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; import { getSpecificMimeType } from '../utils/fileUtils.js'; @@ -67,7 +66,6 @@ export class WriteFileTool implements ModifiableTool { static readonly Name: string = 'write_file'; - private readonly client: GeminiClient; constructor(private readonly config: Config) { super( @@ -92,8 +90,6 @@ export class WriteFileTool type: 'object', }, ); - - this.client = this.config.getGeminiClient(); } /** @@ -374,7 +370,7 @@ export class WriteFileTool new_string: proposedContent, file_path: filePath, }, - this.client, + this.config.getGeminiClient(), abortSignal, ); correctedContent = correctedParams.new_string; @@ -382,7 +378,7 @@ export class WriteFileTool // This implies new file (ENOENT) correctedContent = await ensureCorrectFileContent( proposedContent, - this.client, + this.config.getGeminiClient(), abortSignal, ); }