From c5869db0806d04bc0d1f4da6823f9e13d22e476b Mon Sep 17 00:00:00 2001 From: Olcan Date: Mon, 2 Jun 2025 09:56:32 -0700 Subject: [PATCH] enable async tool discovery by making the registry accessor async; remove call to discoverTools that caused duplicate discovery (#691) --- packages/cli/src/nonInteractiveCli.test.ts | 2 -- packages/cli/src/nonInteractiveCli.ts | 3 +-- packages/cli/src/ui/hooks/atCommandProcessor.ts | 2 +- packages/core/src/config/config.ts | 12 +++++++----- packages/core/src/core/client.ts | 12 ++++++------ packages/core/src/core/coreToolScheduler.ts | 7 ++++--- packages/core/src/tools/tool-registry.ts | 3 ++- 7 files changed, 21 insertions(+), 20 deletions(-) diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index dca3b855..389d35f2 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -42,7 +42,6 @@ describe('runNonInteractive', () => { startChat: vi.fn().mockResolvedValue(mockChat), } as unknown as GeminiClient; mockToolRegistry = { - discoverTools: vi.fn().mockResolvedValue(undefined), getFunctionDeclarations: vi.fn().mockReturnValue([]), getTool: vi.fn(), } as unknown as ToolRegistry; @@ -82,7 +81,6 @@ describe('runNonInteractive', () => { await runNonInteractive(mockConfig, 'Test input'); expect(mockGeminiClient.startChat).toHaveBeenCalled(); - expect(mockToolRegistry.discoverTools).toHaveBeenCalled(); expect(mockChat.sendMessageStream).toHaveBeenCalledWith({ message: [{ text: 'Test input' }], config: { diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 9077ecbf..f7b4108b 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -40,8 +40,7 @@ export async function runNonInteractive( input: string, ): Promise { const geminiClient = new GeminiClient(config); - const toolRegistry: ToolRegistry = config.getToolRegistry(); - await toolRegistry.discoverTools(); + const toolRegistry: ToolRegistry = await config.getToolRegistry(); const chat = await geminiClient.startChat(); const abortController = new AbortController(); diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.ts b/packages/cli/src/ui/hooks/atCommandProcessor.ts index 54b10d51..ac56ab75 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.ts @@ -138,7 +138,7 @@ export async function handleAtCommand({ const atPathToResolvedSpecMap = new Map(); const contentLabelsForDisplay: string[] = []; - const toolRegistry = config.getToolRegistry(); + const toolRegistry = await config.getToolRegistry(); const readManyFilesTool = toolRegistry.getTool('read_many_files'); const globTool = toolRegistry.getTool('glob'); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index d918de04..a6279e2e 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -60,7 +60,7 @@ export interface ConfigParameters { } export class Config { - private toolRegistry: ToolRegistry; + private toolRegistry: Promise; private readonly apiKey: string; private readonly model: string; private readonly sandbox: boolean | string; @@ -124,7 +124,7 @@ export class Config { return this.targetDir; } - getToolRegistry(): ToolRegistry { + async getToolRegistry(): Promise { return this.toolRegistry; } @@ -232,7 +232,7 @@ export function createServerConfig(params: ConfigParameters): Config { }); } -export function createToolRegistry(config: Config): ToolRegistry { +export function createToolRegistry(config: Config): Promise { const registry = new ToolRegistry(config); const targetDir = config.getTargetDir(); const tools = config.getCoreTools() @@ -259,6 +259,8 @@ export function createToolRegistry(config: Config): ToolRegistry { registerCoreTool(ShellTool, config); registerCoreTool(MemoryTool); registerCoreTool(WebSearchTool, config); - registry.discoverTools(); - return registry; + return (async () => { + await registry.discoverTools(); + return registry; + })(); } diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 9006c675..db30ac16 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -70,13 +70,14 @@ export class GeminiClient { `.trim(); const initialParts: Part[] = [{ text: context }]; + const toolRegistry = await this.config.getToolRegistry(); // Add full file context if the flag is set if (this.config.getFullContext()) { try { - const readManyFilesTool = this.config - .getToolRegistry() - .getTool('read_many_files') as ReadManyFilesTool; + const readManyFilesTool = toolRegistry.getTool( + 'read_many_files', + ) as ReadManyFilesTool; if (readManyFilesTool) { // Read all files in the target directory const result = await readManyFilesTool.execute( @@ -114,9 +115,8 @@ export class GeminiClient { async startChat(): Promise { const envParts = await this.getEnvironment(); - const toolDeclarations = this.config - .getToolRegistry() - .getFunctionDeclarations(); + const toolRegistry = await this.config.getToolRegistry(); + const toolDeclarations = toolRegistry.getFunctionDeclarations(); const tools: Tool[] = [{ functionDeclarations: toolDeclarations }]; const history: Content[] = [ { diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 1278d468..58f821c5 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -155,14 +155,14 @@ const createErrorResponse = ( }); interface CoreToolSchedulerOptions { - toolRegistry: ToolRegistry; + toolRegistry: Promise; outputUpdateHandler?: OutputUpdateHandler; onAllToolCallsComplete?: AllToolCallsCompleteHandler; onToolCallsUpdate?: ToolCallsUpdateHandler; } export class CoreToolScheduler { - private toolRegistry: ToolRegistry; + private toolRegistry: Promise; private toolCalls: ToolCall[] = []; private abortController: AbortController; private outputUpdateHandler?: OutputUpdateHandler; @@ -295,10 +295,11 @@ export class CoreToolScheduler { ); } const requestsToProcess = Array.isArray(request) ? request : [request]; + const toolRegistry = await this.toolRegistry; const newToolCalls: ToolCall[] = requestsToProcess.map( (reqInfo): ToolCall => { - const toolInstance = this.toolRegistry.getTool(reqInfo.name); + const toolInstance = toolRegistry.getTool(reqInfo.name); if (!toolInstance) { return { status: 'error', diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index e241ada5..384552ca 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -100,6 +100,7 @@ Signal: Signal number or \`(none)\` if no signal was received. export class ToolRegistry { private tools: Map = new Map(); + private discovery: Promise | null = null; private config: Config; constructor(config: Config) { @@ -121,7 +122,7 @@ export class ToolRegistry { } /** - * Discovers tools from project, if a discovery command is configured. + * Discovers tools from project (if available and configured). * Can be called multiple times to update discovered tools. */ async discoverTools(): Promise {