Fix: Ensure MCP tools are discovered from slow-starting servers (#717)
This commit is contained in:
parent
5f6f6a95a2
commit
c71d6ddc3b
|
@ -254,7 +254,7 @@ export function createServerConfig(params: ConfigParameters): Config {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
function createToolRegistry(config: Config): Promise<ToolRegistry> {
|
export function createToolRegistry(config: Config): Promise<ToolRegistry> {
|
||||||
const registry = new ToolRegistry(config);
|
const registry = new ToolRegistry(config);
|
||||||
const targetDir = config.getTargetDir();
|
const targetDir = config.getTargetDir();
|
||||||
const tools = config.getCoreTools()
|
const tools = config.getCoreTools()
|
||||||
|
@ -281,12 +281,8 @@ function createToolRegistry(config: Config): Promise<ToolRegistry> {
|
||||||
registerCoreTool(ShellTool, config);
|
registerCoreTool(ShellTool, config);
|
||||||
registerCoreTool(MemoryTool);
|
registerCoreTool(MemoryTool);
|
||||||
registerCoreTool(WebSearchTool, config);
|
registerCoreTool(WebSearchTool, config);
|
||||||
|
return (async () => {
|
||||||
// This is async, but we can't wait for it to finish because when we register
|
await registry.discoverTools();
|
||||||
// discovered tools, we need to see if existing tools already exist in order to
|
return registry;
|
||||||
// avoid duplicates.
|
})();
|
||||||
registry.discoverTools();
|
|
||||||
|
|
||||||
// Maintain an async registry return so it's easy in the future to add async behavior to this instantiation.
|
|
||||||
return Promise.resolve(registry);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -135,7 +135,11 @@ describe('discoverMcpTools', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should do nothing if no MCP servers or command are configured', async () => {
|
it('should do nothing if no MCP servers or command are configured', async () => {
|
||||||
await discoverMcpTools(mockConfig);
|
await discoverMcpTools(
|
||||||
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
);
|
||||||
expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1);
|
expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1);
|
||||||
expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1);
|
expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1);
|
||||||
expect(Client).not.toHaveBeenCalled();
|
expect(Client).not.toHaveBeenCalled();
|
||||||
|
@ -161,7 +165,11 @@ describe('discoverMcpTools', () => {
|
||||||
// In this case, listTools fails, so no tools are registered.
|
// In this case, listTools fails, so no tools are registered.
|
||||||
// The default mock `mockReturnValue([])` from beforeEach should apply.
|
// The default mock `mockReturnValue([])` from beforeEach should apply.
|
||||||
|
|
||||||
await discoverMcpTools(mockConfig);
|
await discoverMcpTools(
|
||||||
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
);
|
||||||
|
|
||||||
expect(parse).toHaveBeenCalledWith(commandString, process.env);
|
expect(parse).toHaveBeenCalledWith(commandString, process.env);
|
||||||
expect(StdioClientTransport).toHaveBeenCalledWith({
|
expect(StdioClientTransport).toHaveBeenCalledWith({
|
||||||
|
@ -204,7 +212,11 @@ describe('discoverMcpTools', () => {
|
||||||
expect.any(DiscoveredMCPTool),
|
expect.any(DiscoveredMCPTool),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
await discoverMcpTools(mockConfig);
|
await discoverMcpTools(
|
||||||
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
);
|
||||||
|
|
||||||
expect(StdioClientTransport).toHaveBeenCalledWith({
|
expect(StdioClientTransport).toHaveBeenCalledWith({
|
||||||
command: serverConfig.command,
|
command: serverConfig.command,
|
||||||
|
@ -239,7 +251,11 @@ describe('discoverMcpTools', () => {
|
||||||
expect.any(DiscoveredMCPTool),
|
expect.any(DiscoveredMCPTool),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
await discoverMcpTools(mockConfig);
|
await discoverMcpTools(
|
||||||
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
);
|
||||||
|
|
||||||
expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!));
|
expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!));
|
||||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
|
expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
|
||||||
|
@ -317,7 +333,11 @@ describe('discoverMcpTools', () => {
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
await discoverMcpTools(mockConfig);
|
await discoverMcpTools(
|
||||||
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
);
|
||||||
|
|
||||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3);
|
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3);
|
||||||
const registeredArgs = mockToolRegistry.registerTool.mock.calls.map(
|
const registeredArgs = mockToolRegistry.registerTool.mock.calls.map(
|
||||||
|
@ -381,7 +401,11 @@ describe('discoverMcpTools', () => {
|
||||||
expect.any(DiscoveredMCPTool),
|
expect.any(DiscoveredMCPTool),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
await discoverMcpTools(mockConfig);
|
await discoverMcpTools(
|
||||||
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
);
|
||||||
|
|
||||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
|
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
|
||||||
const registeredTool = mockToolRegistry.registerTool.mock
|
const registeredTool = mockToolRegistry.registerTool.mock
|
||||||
|
@ -410,9 +434,13 @@ describe('discoverMcpTools', () => {
|
||||||
});
|
});
|
||||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||||
|
|
||||||
await expect(discoverMcpTools(mockConfig)).rejects.toThrow(
|
await expect(
|
||||||
'Parsing failed',
|
discoverMcpTools(
|
||||||
);
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
),
|
||||||
|
).rejects.toThrow('Parsing failed');
|
||||||
expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
|
expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
|
||||||
expect(console.error).not.toHaveBeenCalled();
|
expect(console.error).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
@ -421,7 +449,11 @@ describe('discoverMcpTools', () => {
|
||||||
mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });
|
mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });
|
||||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||||
|
|
||||||
await discoverMcpTools(mockConfig);
|
await discoverMcpTools(
|
||||||
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
);
|
||||||
|
|
||||||
expect(console.error).toHaveBeenCalledWith(
|
expect(console.error).toHaveBeenCalledWith(
|
||||||
expect.stringContaining(
|
expect.stringContaining(
|
||||||
|
@ -442,7 +474,11 @@ describe('discoverMcpTools', () => {
|
||||||
);
|
);
|
||||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||||
|
|
||||||
await discoverMcpTools(mockConfig);
|
await discoverMcpTools(
|
||||||
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
);
|
||||||
|
|
||||||
expect(console.error).toHaveBeenCalledWith(
|
expect(console.error).toHaveBeenCalledWith(
|
||||||
expect.stringContaining(
|
expect.stringContaining(
|
||||||
|
@ -463,7 +499,11 @@ describe('discoverMcpTools', () => {
|
||||||
);
|
);
|
||||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||||
|
|
||||||
await discoverMcpTools(mockConfig);
|
await discoverMcpTools(
|
||||||
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
);
|
||||||
|
|
||||||
expect(console.error).toHaveBeenCalledWith(
|
expect(console.error).toHaveBeenCalledWith(
|
||||||
expect.stringContaining(
|
expect.stringContaining(
|
||||||
|
@ -483,7 +523,11 @@ describe('discoverMcpTools', () => {
|
||||||
expect.any(DiscoveredMCPTool),
|
expect.any(DiscoveredMCPTool),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
await discoverMcpTools(mockConfig);
|
await discoverMcpTools(
|
||||||
|
mockConfig.getMcpServers() ?? {},
|
||||||
|
mockConfig.getMcpServerCommand(),
|
||||||
|
mockToolRegistry as any,
|
||||||
|
);
|
||||||
|
|
||||||
const clientInstances = vi.mocked(Client).mock.results;
|
const clientInstances = vi.mocked(Client).mock.results;
|
||||||
expect(clientInstances.length).toBeGreaterThan(0);
|
expect(clientInstances.length).toBeGreaterThan(0);
|
||||||
|
|
|
@ -8,15 +8,18 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||||
import { parse } from 'shell-quote';
|
import { parse } from 'shell-quote';
|
||||||
import { Config, MCPServerConfig } from '../config/config.js';
|
import { MCPServerConfig } from '../config/config.js';
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||||
import { CallableTool, FunctionDeclaration, mcpToTool } from '@google/genai';
|
import { CallableTool, FunctionDeclaration, mcpToTool } from '@google/genai';
|
||||||
|
import { ToolRegistry } from './tool-registry.js';
|
||||||
|
|
||||||
export async function discoverMcpTools(config: Config): Promise<void> {
|
export async function discoverMcpTools(
|
||||||
const mcpServers = config.getMcpServers() || {};
|
mcpServers: Record<string, MCPServerConfig>,
|
||||||
|
mcpServerCommand: string | undefined,
|
||||||
if (config.getMcpServerCommand()) {
|
toolRegistry: ToolRegistry,
|
||||||
const cmd = config.getMcpServerCommand()!;
|
): Promise<void> {
|
||||||
|
if (mcpServerCommand) {
|
||||||
|
const cmd = mcpServerCommand;
|
||||||
const args = parse(cmd, process.env) as string[];
|
const args = parse(cmd, process.env) as string[];
|
||||||
if (args.some((arg) => typeof arg !== 'string')) {
|
if (args.some((arg) => typeof arg !== 'string')) {
|
||||||
throw new Error('failed to parse mcpServerCommand: ' + cmd);
|
throw new Error('failed to parse mcpServerCommand: ' + cmd);
|
||||||
|
@ -30,7 +33,7 @@ export async function discoverMcpTools(config: Config): Promise<void> {
|
||||||
|
|
||||||
const discoveryPromises = Object.entries(mcpServers).map(
|
const discoveryPromises = Object.entries(mcpServers).map(
|
||||||
([mcpServerName, mcpServerConfig]) =>
|
([mcpServerName, mcpServerConfig]) =>
|
||||||
connectAndDiscover(mcpServerName, mcpServerConfig, config),
|
connectAndDiscover(mcpServerName, mcpServerConfig, toolRegistry),
|
||||||
);
|
);
|
||||||
await Promise.all(discoveryPromises);
|
await Promise.all(discoveryPromises);
|
||||||
}
|
}
|
||||||
|
@ -38,7 +41,7 @@ export async function discoverMcpTools(config: Config): Promise<void> {
|
||||||
async function connectAndDiscover(
|
async function connectAndDiscover(
|
||||||
mcpServerName: string,
|
mcpServerName: string,
|
||||||
mcpServerConfig: MCPServerConfig,
|
mcpServerConfig: MCPServerConfig,
|
||||||
config: Config,
|
toolRegistry: ToolRegistry,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
let transport;
|
let transport;
|
||||||
if (mcpServerConfig.url) {
|
if (mcpServerConfig.url) {
|
||||||
|
@ -90,7 +93,6 @@ async function connectAndDiscover(
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
const toolRegistry = await config.getToolRegistry();
|
|
||||||
try {
|
try {
|
||||||
const mcpCallableTool: CallableTool = mcpToTool(mcpClient);
|
const mcpCallableTool: CallableTool = mcpToTool(mcpClient);
|
||||||
const discoveredToolFunctions = await mcpCallableTool.tool();
|
const discoveredToolFunctions = await mcpCallableTool.tool();
|
||||||
|
|
|
@ -277,7 +277,11 @@ describe('ToolRegistry', () => {
|
||||||
|
|
||||||
await toolRegistry.discoverTools();
|
await toolRegistry.discoverTools();
|
||||||
|
|
||||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
|
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||||
|
mcpServerConfigVal,
|
||||||
|
undefined,
|
||||||
|
toolRegistry,
|
||||||
|
);
|
||||||
// We no longer check these as discoverMcpTools is mocked
|
// We no longer check these as discoverMcpTools is mocked
|
||||||
// expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1);
|
// expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1);
|
||||||
// expect(Client).toHaveBeenCalledTimes(1);
|
// expect(Client).toHaveBeenCalledTimes(1);
|
||||||
|
@ -302,7 +306,11 @@ describe('ToolRegistry', () => {
|
||||||
);
|
);
|
||||||
|
|
||||||
await toolRegistry.discoverTools();
|
await toolRegistry.discoverTools();
|
||||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
|
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||||
|
{},
|
||||||
|
'mcp-server-start-command --param',
|
||||||
|
toolRegistry,
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle errors during MCP client connection gracefully and close transport', async () => {
|
it('should handle errors during MCP client connection gracefully and close transport', async () => {
|
||||||
|
@ -314,7 +322,13 @@ describe('ToolRegistry', () => {
|
||||||
mockMcpClientConnect.mockRejectedValue(new Error('Connection failed'));
|
mockMcpClientConnect.mockRejectedValue(new Error('Connection failed'));
|
||||||
|
|
||||||
await toolRegistry.discoverTools();
|
await toolRegistry.discoverTools();
|
||||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
|
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
'failing-mcp': { command: 'fail-cmd' },
|
||||||
|
},
|
||||||
|
undefined,
|
||||||
|
toolRegistry,
|
||||||
|
);
|
||||||
expect(toolRegistry.getAllTools()).toHaveLength(0);
|
expect(toolRegistry.getAllTools()).toHaveLength(0);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -161,7 +161,11 @@ export class ToolRegistry {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// discover tools using MCP servers, if configured
|
// discover tools using MCP servers, if configured
|
||||||
await discoverMcpTools(this.config);
|
await discoverMcpTools(
|
||||||
|
this.config.getMcpServers() ?? {},
|
||||||
|
this.config.getMcpServerCommand(),
|
||||||
|
this,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
Loading…
Reference in New Issue