Fix: Ensure MCP tools are discovered from slow-starting servers (#717)

This commit is contained in:
N. Taylor Mullen 2025-06-03 00:40:51 -07:00 committed by GitHub
parent 5f6f6a95a2
commit c71d6ddc3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 95 additions and 35 deletions

View File

@ -254,7 +254,7 @@ export function createServerConfig(params: ConfigParameters): Config {
});
}
function createToolRegistry(config: Config): Promise<ToolRegistry> {
export function createToolRegistry(config: Config): Promise<ToolRegistry> {
const registry = new ToolRegistry(config);
const targetDir = config.getTargetDir();
const tools = config.getCoreTools()
@ -281,12 +281,8 @@ function createToolRegistry(config: Config): Promise<ToolRegistry> {
registerCoreTool(ShellTool, config);
registerCoreTool(MemoryTool);
registerCoreTool(WebSearchTool, config);
// This is async, but we can't wait for it to finish because when we register
// discovered tools, we need to see if existing tools already exist in order to
// avoid duplicates.
registry.discoverTools();
// Maintain an async registry return so it's easy in the future to add async behavior to this instantiation.
return Promise.resolve(registry);
return (async () => {
await registry.discoverTools();
return registry;
})();
}

View File

@ -135,7 +135,11 @@ describe('discoverMcpTools', () => {
});
it('should do nothing if no MCP servers or command are configured', async () => {
await discoverMcpTools(mockConfig);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1);
expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1);
expect(Client).not.toHaveBeenCalled();
@ -161,7 +165,11 @@ describe('discoverMcpTools', () => {
// In this case, listTools fails, so no tools are registered.
// The default mock `mockReturnValue([])` from beforeEach should apply.
await discoverMcpTools(mockConfig);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
expect(parse).toHaveBeenCalledWith(commandString, process.env);
expect(StdioClientTransport).toHaveBeenCalledWith({
@ -204,7 +212,11 @@ describe('discoverMcpTools', () => {
expect.any(DiscoveredMCPTool),
]);
await discoverMcpTools(mockConfig);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
expect(StdioClientTransport).toHaveBeenCalledWith({
command: serverConfig.command,
@ -239,7 +251,11 @@ describe('discoverMcpTools', () => {
expect.any(DiscoveredMCPTool),
]);
await discoverMcpTools(mockConfig);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!));
expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
@ -317,7 +333,11 @@ describe('discoverMcpTools', () => {
},
);
await discoverMcpTools(mockConfig);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3);
const registeredArgs = mockToolRegistry.registerTool.mock.calls.map(
@ -381,7 +401,11 @@ describe('discoverMcpTools', () => {
expect.any(DiscoveredMCPTool),
]);
await discoverMcpTools(mockConfig);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
const registeredTool = mockToolRegistry.registerTool.mock
@ -410,9 +434,13 @@ describe('discoverMcpTools', () => {
});
vi.spyOn(console, 'error').mockImplementation(() => {});
await expect(discoverMcpTools(mockConfig)).rejects.toThrow(
'Parsing failed',
);
await expect(
discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
),
).rejects.toThrow('Parsing failed');
expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
expect(console.error).not.toHaveBeenCalled();
});
@ -421,7 +449,11 @@ describe('discoverMcpTools', () => {
mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });
vi.spyOn(console, 'error').mockImplementation(() => {});
await discoverMcpTools(mockConfig);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@ -442,7 +474,11 @@ describe('discoverMcpTools', () => {
);
vi.spyOn(console, 'error').mockImplementation(() => {});
await discoverMcpTools(mockConfig);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@ -463,7 +499,11 @@ describe('discoverMcpTools', () => {
);
vi.spyOn(console, 'error').mockImplementation(() => {});
await discoverMcpTools(mockConfig);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
expect(console.error).toHaveBeenCalledWith(
expect.stringContaining(
@ -483,7 +523,11 @@ describe('discoverMcpTools', () => {
expect.any(DiscoveredMCPTool),
]);
await discoverMcpTools(mockConfig);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
const clientInstances = vi.mocked(Client).mock.results;
expect(clientInstances.length).toBeGreaterThan(0);

View File

@ -8,15 +8,18 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { parse } from 'shell-quote';
import { Config, MCPServerConfig } from '../config/config.js';
import { MCPServerConfig } from '../config/config.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { CallableTool, FunctionDeclaration, mcpToTool } from '@google/genai';
import { ToolRegistry } from './tool-registry.js';
export async function discoverMcpTools(config: Config): Promise<void> {
const mcpServers = config.getMcpServers() || {};
if (config.getMcpServerCommand()) {
const cmd = config.getMcpServerCommand()!;
export async function discoverMcpTools(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
toolRegistry: ToolRegistry,
): Promise<void> {
if (mcpServerCommand) {
const cmd = mcpServerCommand;
const args = parse(cmd, process.env) as string[];
if (args.some((arg) => typeof arg !== 'string')) {
throw new Error('failed to parse mcpServerCommand: ' + cmd);
@ -30,7 +33,7 @@ export async function discoverMcpTools(config: Config): Promise<void> {
const discoveryPromises = Object.entries(mcpServers).map(
([mcpServerName, mcpServerConfig]) =>
connectAndDiscover(mcpServerName, mcpServerConfig, config),
connectAndDiscover(mcpServerName, mcpServerConfig, toolRegistry),
);
await Promise.all(discoveryPromises);
}
@ -38,7 +41,7 @@ export async function discoverMcpTools(config: Config): Promise<void> {
async function connectAndDiscover(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
config: Config,
toolRegistry: ToolRegistry,
): Promise<void> {
let transport;
if (mcpServerConfig.url) {
@ -90,7 +93,6 @@ async function connectAndDiscover(
});
}
const toolRegistry = await config.getToolRegistry();
try {
const mcpCallableTool: CallableTool = mcpToTool(mcpClient);
const discoveredToolFunctions = await mcpCallableTool.tool();

View File

@ -277,7 +277,11 @@ describe('ToolRegistry', () => {
await toolRegistry.discoverTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
mcpServerConfigVal,
undefined,
toolRegistry,
);
// We no longer check these as discoverMcpTools is mocked
// expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1);
// expect(Client).toHaveBeenCalledTimes(1);
@ -302,7 +306,11 @@ describe('ToolRegistry', () => {
);
await toolRegistry.discoverTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
{},
'mcp-server-start-command --param',
toolRegistry,
);
});
it('should handle errors during MCP client connection gracefully and close transport', async () => {
@ -314,7 +322,13 @@ describe('ToolRegistry', () => {
mockMcpClientConnect.mockRejectedValue(new Error('Connection failed'));
await toolRegistry.discoverTools();
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
{
'failing-mcp': { command: 'fail-cmd' },
},
undefined,
toolRegistry,
);
expect(toolRegistry.getAllTools()).toHaveLength(0);
});
});

View File

@ -161,7 +161,11 @@ export class ToolRegistry {
}
}
// discover tools using MCP servers, if configured
await discoverMcpTools(this.config);
await discoverMcpTools(
this.config.getMcpServers() ?? {},
this.config.getMcpServerCommand(),
this,
);
}
/**