Fix: Ensure MCP tools are discovered from slow-starting servers (#717)
This commit is contained in:
parent
5f6f6a95a2
commit
c71d6ddc3b
|
@ -254,7 +254,7 @@ export function createServerConfig(params: ConfigParameters): Config {
|
|||
});
|
||||
}
|
||||
|
||||
function createToolRegistry(config: Config): Promise<ToolRegistry> {
|
||||
export function createToolRegistry(config: Config): Promise<ToolRegistry> {
|
||||
const registry = new ToolRegistry(config);
|
||||
const targetDir = config.getTargetDir();
|
||||
const tools = config.getCoreTools()
|
||||
|
@ -281,12 +281,8 @@ function createToolRegistry(config: Config): Promise<ToolRegistry> {
|
|||
registerCoreTool(ShellTool, config);
|
||||
registerCoreTool(MemoryTool);
|
||||
registerCoreTool(WebSearchTool, config);
|
||||
|
||||
// This is async, but we can't wait for it to finish because when we register
|
||||
// discovered tools, we need to see if existing tools already exist in order to
|
||||
// avoid duplicates.
|
||||
registry.discoverTools();
|
||||
|
||||
// Maintain an async registry return so it's easy in the future to add async behavior to this instantiation.
|
||||
return Promise.resolve(registry);
|
||||
return (async () => {
|
||||
await registry.discoverTools();
|
||||
return registry;
|
||||
})();
|
||||
}
|
||||
|
|
|
@ -135,7 +135,11 @@ describe('discoverMcpTools', () => {
|
|||
});
|
||||
|
||||
it('should do nothing if no MCP servers or command are configured', async () => {
|
||||
await discoverMcpTools(mockConfig);
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1);
|
||||
expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1);
|
||||
expect(Client).not.toHaveBeenCalled();
|
||||
|
@ -161,7 +165,11 @@ describe('discoverMcpTools', () => {
|
|||
// In this case, listTools fails, so no tools are registered.
|
||||
// The default mock `mockReturnValue([])` from beforeEach should apply.
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
expect(parse).toHaveBeenCalledWith(commandString, process.env);
|
||||
expect(StdioClientTransport).toHaveBeenCalledWith({
|
||||
|
@ -204,7 +212,11 @@ describe('discoverMcpTools', () => {
|
|||
expect.any(DiscoveredMCPTool),
|
||||
]);
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
expect(StdioClientTransport).toHaveBeenCalledWith({
|
||||
command: serverConfig.command,
|
||||
|
@ -239,7 +251,11 @@ describe('discoverMcpTools', () => {
|
|||
expect.any(DiscoveredMCPTool),
|
||||
]);
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!));
|
||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
|
||||
|
@ -317,7 +333,11 @@ describe('discoverMcpTools', () => {
|
|||
},
|
||||
);
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3);
|
||||
const registeredArgs = mockToolRegistry.registerTool.mock.calls.map(
|
||||
|
@ -381,7 +401,11 @@ describe('discoverMcpTools', () => {
|
|||
expect.any(DiscoveredMCPTool),
|
||||
]);
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
|
||||
const registeredTool = mockToolRegistry.registerTool.mock
|
||||
|
@ -410,9 +434,13 @@ describe('discoverMcpTools', () => {
|
|||
});
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
await expect(discoverMcpTools(mockConfig)).rejects.toThrow(
|
||||
'Parsing failed',
|
||||
);
|
||||
await expect(
|
||||
discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
),
|
||||
).rejects.toThrow('Parsing failed');
|
||||
expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
|
||||
expect(console.error).not.toHaveBeenCalled();
|
||||
});
|
||||
|
@ -421,7 +449,11 @@ describe('discoverMcpTools', () => {
|
|||
mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
|
@ -442,7 +474,11 @@ describe('discoverMcpTools', () => {
|
|||
);
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
|
@ -463,7 +499,11 @@ describe('discoverMcpTools', () => {
|
|||
);
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
expect(console.error).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
|
@ -483,7 +523,11 @@ describe('discoverMcpTools', () => {
|
|||
expect.any(DiscoveredMCPTool),
|
||||
]);
|
||||
|
||||
await discoverMcpTools(mockConfig);
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
const clientInstances = vi.mocked(Client).mock.results;
|
||||
expect(clientInstances.length).toBeGreaterThan(0);
|
||||
|
|
|
@ -8,15 +8,18 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
|||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import { Config, MCPServerConfig } from '../config/config.js';
|
||||
import { MCPServerConfig } from '../config/config.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { CallableTool, FunctionDeclaration, mcpToTool } from '@google/genai';
|
||||
import { ToolRegistry } from './tool-registry.js';
|
||||
|
||||
export async function discoverMcpTools(config: Config): Promise<void> {
|
||||
const mcpServers = config.getMcpServers() || {};
|
||||
|
||||
if (config.getMcpServerCommand()) {
|
||||
const cmd = config.getMcpServerCommand()!;
|
||||
export async function discoverMcpTools(
|
||||
mcpServers: Record<string, MCPServerConfig>,
|
||||
mcpServerCommand: string | undefined,
|
||||
toolRegistry: ToolRegistry,
|
||||
): Promise<void> {
|
||||
if (mcpServerCommand) {
|
||||
const cmd = mcpServerCommand;
|
||||
const args = parse(cmd, process.env) as string[];
|
||||
if (args.some((arg) => typeof arg !== 'string')) {
|
||||
throw new Error('failed to parse mcpServerCommand: ' + cmd);
|
||||
|
@ -30,7 +33,7 @@ export async function discoverMcpTools(config: Config): Promise<void> {
|
|||
|
||||
const discoveryPromises = Object.entries(mcpServers).map(
|
||||
([mcpServerName, mcpServerConfig]) =>
|
||||
connectAndDiscover(mcpServerName, mcpServerConfig, config),
|
||||
connectAndDiscover(mcpServerName, mcpServerConfig, toolRegistry),
|
||||
);
|
||||
await Promise.all(discoveryPromises);
|
||||
}
|
||||
|
@ -38,7 +41,7 @@ export async function discoverMcpTools(config: Config): Promise<void> {
|
|||
async function connectAndDiscover(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
config: Config,
|
||||
toolRegistry: ToolRegistry,
|
||||
): Promise<void> {
|
||||
let transport;
|
||||
if (mcpServerConfig.url) {
|
||||
|
@ -90,7 +93,6 @@ async function connectAndDiscover(
|
|||
});
|
||||
}
|
||||
|
||||
const toolRegistry = await config.getToolRegistry();
|
||||
try {
|
||||
const mcpCallableTool: CallableTool = mcpToTool(mcpClient);
|
||||
const discoveredToolFunctions = await mcpCallableTool.tool();
|
||||
|
|
|
@ -277,7 +277,11 @@ describe('ToolRegistry', () => {
|
|||
|
||||
await toolRegistry.discoverTools();
|
||||
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
mcpServerConfigVal,
|
||||
undefined,
|
||||
toolRegistry,
|
||||
);
|
||||
// We no longer check these as discoverMcpTools is mocked
|
||||
// expect(vi.mocked(mcpToTool)).toHaveBeenCalledTimes(1);
|
||||
// expect(Client).toHaveBeenCalledTimes(1);
|
||||
|
@ -302,7 +306,11 @@ describe('ToolRegistry', () => {
|
|||
);
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
{},
|
||||
'mcp-server-start-command --param',
|
||||
toolRegistry,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle errors during MCP client connection gracefully and close transport', async () => {
|
||||
|
@ -314,7 +322,13 @@ describe('ToolRegistry', () => {
|
|||
mockMcpClientConnect.mockRejectedValue(new Error('Connection failed'));
|
||||
|
||||
await toolRegistry.discoverTools();
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(config);
|
||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
||||
{
|
||||
'failing-mcp': { command: 'fail-cmd' },
|
||||
},
|
||||
undefined,
|
||||
toolRegistry,
|
||||
);
|
||||
expect(toolRegistry.getAllTools()).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -161,7 +161,11 @@ export class ToolRegistry {
|
|||
}
|
||||
}
|
||||
// discover tools using MCP servers, if configured
|
||||
await discoverMcpTools(this.config);
|
||||
await discoverMcpTools(
|
||||
this.config.getMcpServers() ?? {},
|
||||
this.config.getMcpServerCommand(),
|
||||
this,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
Loading…
Reference in New Issue