diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 7b0be08d..46e5123c 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -254,7 +254,7 @@ export function createServerConfig(params: ConfigParameters): Config { }); } -function createToolRegistry(config: Config): Promise { +export function createToolRegistry(config: Config): Promise { const registry = new ToolRegistry(config); const targetDir = config.getTargetDir(); const tools = config.getCoreTools() @@ -281,12 +281,8 @@ function createToolRegistry(config: Config): Promise { registerCoreTool(ShellTool, config); registerCoreTool(MemoryTool); registerCoreTool(WebSearchTool, config); - - // This is async, but we can't wait for it to finish because when we register - // discovered tools, we need to see if existing tools already exist in order to - // avoid duplicates. - registry.discoverTools(); - - // Maintain an async registry return so it's easy in the future to add async behavior to this instantiation. - return Promise.resolve(registry); + return (async () => { + await registry.discoverTools(); + return registry; + })(); } diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 121cd1d8..abd9c58f 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -135,7 +135,11 @@ describe('discoverMcpTools', () => { }); it('should do nothing if no MCP servers or command are configured', async () => { - await discoverMcpTools(mockConfig); + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1); expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1); expect(Client).not.toHaveBeenCalled(); @@ -161,7 +165,11 @@ describe('discoverMcpTools', () => { // In this case, listTools fails, so no tools are registered. // The default mock `mockReturnValue([])` from beforeEach should apply. - await discoverMcpTools(mockConfig); + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); expect(parse).toHaveBeenCalledWith(commandString, process.env); expect(StdioClientTransport).toHaveBeenCalledWith({ @@ -204,7 +212,11 @@ describe('discoverMcpTools', () => { expect.any(DiscoveredMCPTool), ]); - await discoverMcpTools(mockConfig); + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); expect(StdioClientTransport).toHaveBeenCalledWith({ command: serverConfig.command, @@ -239,7 +251,11 @@ describe('discoverMcpTools', () => { expect.any(DiscoveredMCPTool), ]); - await discoverMcpTools(mockConfig); + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!)); expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( @@ -317,7 +333,11 @@ describe('discoverMcpTools', () => { }, ); - await discoverMcpTools(mockConfig); + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3); const registeredArgs = mockToolRegistry.registerTool.mock.calls.map( @@ -381,7 +401,11 @@ describe('discoverMcpTools', () => { expect.any(DiscoveredMCPTool), ]); - await discoverMcpTools(mockConfig); + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); const registeredTool = mockToolRegistry.registerTool.mock @@ -410,9 +434,13 @@ describe('discoverMcpTools', () => { }); vi.spyOn(console, 'error').mockImplementation(() => {}); - await expect(discoverMcpTools(mockConfig)).rejects.toThrow( - 'Parsing failed', - ); + await expect( + discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ), + ).rejects.toThrow('Parsing failed'); expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); expect(console.error).not.toHaveBeenCalled(); }); @@ -421,7 +449,11 @@ describe('discoverMcpTools', () => { mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any }); vi.spyOn(console, 'error').mockImplementation(() => {}); - await discoverMcpTools(mockConfig); + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); expect(console.error).toHaveBeenCalledWith( expect.stringContaining( @@ -442,7 +474,11 @@ describe('discoverMcpTools', () => { ); vi.spyOn(console, 'error').mockImplementation(() => {}); - await discoverMcpTools(mockConfig); + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); expect(console.error).toHaveBeenCalledWith( expect.stringContaining( @@ -463,7 +499,11 @@ describe('discoverMcpTools', () => { ); vi.spyOn(console, 'error').mockImplementation(() => {}); - await discoverMcpTools(mockConfig); + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); expect(console.error).toHaveBeenCalledWith( expect.stringContaining( @@ -483,7 +523,11 @@ describe('discoverMcpTools', () => { expect.any(DiscoveredMCPTool), ]); - await discoverMcpTools(mockConfig); + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); const clientInstances = vi.mocked(Client).mock.results; expect(clientInstances.length).toBeGreaterThan(0); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 87835219..1b7823c7 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -8,15 +8,18 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import { parse } from 'shell-quote'; -import { Config, MCPServerConfig } from '../config/config.js'; +import { MCPServerConfig } from '../config/config.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; import { CallableTool, FunctionDeclaration, mcpToTool } from '@google/genai'; +import { ToolRegistry } from './tool-registry.js'; -export async function discoverMcpTools(config: Config): Promise { - const mcpServers = config.getMcpServers() || {}; - - if (config.getMcpServerCommand()) { - const cmd = config.getMcpServerCommand()!; +export async function discoverMcpTools( + mcpServers: Record, + mcpServerCommand: string | undefined, + toolRegistry: ToolRegistry, +): Promise { + if (mcpServerCommand) { + const cmd = mcpServerCommand; const args = parse(cmd, process.env) as string[]; if (args.some((arg) => typeof arg !== 'string')) { throw new Error('failed to parse mcpServerCommand: ' + cmd); @@ -30,7 +33,7 @@ export async function discoverMcpTools(config: Config): Promise { const discoveryPromises = Object.entries(mcpServers).map( ([mcpServerName, mcpServerConfig]) => - connectAndDiscover(mcpServerName, mcpServerConfig, config), + connectAndDiscover(mcpServerName, mcpServerConfig, toolRegistry), ); await Promise.all(discoveryPromises); } @@ -38,7 +41,7 @@ export async function discoverMcpTools(config: Config): Promise { async function connectAndDiscover( mcpServerName: string, mcpServerConfig: MCPServerConfig, - config: Config, + toolRegistry: ToolRegistry, ): Promise { let transport; if (mcpServerConfig.url) { @@ -90,7 +93,6 @@ async function connectAndDiscover( }); } - const toolRegistry = await config.getToolRegistry(); try { const mcpCallableTool: CallableTool = mcpToTool(mcpClient); const discoveredToolFunctions = await mcpCallableTool.tool(); diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index 1fb2df4e..f57f5bce 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -277,7 +277,11 @@ describe('ToolRegistry', () => { await toolRegistry.discoverTools(); - expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config); + expect(mockDiscoverMcpTools).toHaveBeenCalledWith( + mcpServerConfigVal, + undefined, + toolRegistry, + ); // We no longer check these as discoverMcpTools is mocked // expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1); // expect(Client).toHaveBeenCalledTimes(1); @@ -302,7 +306,11 @@ describe('ToolRegistry', () => { ); await toolRegistry.discoverTools(); - expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config); + expect(mockDiscoverMcpTools).toHaveBeenCalledWith( + {}, + 'mcp-server-start-command --param', + toolRegistry, + ); }); it('should handle errors during MCP client connection gracefully and close transport', async () => { @@ -314,7 +322,13 @@ describe('ToolRegistry', () => { mockMcpClientConnect.mockRejectedValue(new Error('Connection failed')); await toolRegistry.discoverTools(); - expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config); + expect(mockDiscoverMcpTools).toHaveBeenCalledWith( + { + 'failing-mcp': { command: 'fail-cmd' }, + }, + undefined, + toolRegistry, + ); expect(toolRegistry.getAllTools()).toHaveLength(0); }); }); diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index 12aa1a83..2b27a703 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -161,7 +161,11 @@ export class ToolRegistry { } } // discover tools using MCP servers, if configured - await discoverMcpTools(this.config); + await discoverMcpTools( + this.config.getMcpServers() ?? {}, + this.config.getMcpServerCommand(), + this, + ); } /**