diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 3bb5b85e..f2404bb0 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -84,6 +84,8 @@ export class MCPServerConfig { readonly trust?: boolean, // Metadata readonly description?: string, + readonly includeTools?: string[], + readonly excludeTools?: string[], ) {} } diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 09f7951b..1413a4f8 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -668,6 +668,115 @@ describe('discoverMcpTools', () => { clientInstances[clientInstances.length - 1]?.value; expect(lastClientInstance?.onerror).toEqual(expect.any(Function)); }); + + describe('Tool Filtering', () => { + const mockTools = [ + { + name: 'toolA', + description: 'descA', + inputSchema: { type: 'object' as const, properties: {} }, + }, + { + name: 'toolB', + description: 'descB', + inputSchema: { type: 'object' as const, properties: {} }, + }, + { + name: 'toolC', + description: 'descC', + inputSchema: { type: 'object' as const, properties: {} }, + }, + ]; + + beforeEach(() => { + vi.mocked(Client.prototype.listTools).mockResolvedValue({ + tools: mockTools, + }); + mockToolRegistry.getToolsByServer.mockReturnValue([ + expect.any(DiscoveredMCPTool), + ]); + }); + + it('should only include specified tools with includeTools', async () => { + const serverConfig: MCPServerConfig = { + command: './mcp-include', + includeTools: ['toolA', 'toolC'], + }; + mockConfig.getMcpServers.mockReturnValue({ + 'include-server': serverConfig, + }); + + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); + + expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2); + expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( + expect.objectContaining({ serverToolName: 'toolA' }), + ); + expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( + expect.objectContaining({ serverToolName: 'toolC' }), + ); + expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( + expect.objectContaining({ serverToolName: 'toolB' }), + ); + }); + + it('should exclude specified tools with excludeTools', async () => { + const serverConfig: MCPServerConfig = { + command: './mcp-exclude', + excludeTools: ['toolB'], + }; + mockConfig.getMcpServers.mockReturnValue({ + 'exclude-server': serverConfig, + }); + + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); + + expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2); + expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( + expect.objectContaining({ serverToolName: 'toolA' }), + ); + expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( + expect.objectContaining({ serverToolName: 'toolC' }), + ); + expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( + expect.objectContaining({ serverToolName: 'toolB' }), + ); + }); + + it('should handle both includeTools and excludeTools', async () => { + const serverConfig: MCPServerConfig = { + command: './mcp-both', + includeTools: ['toolA', 'toolB'], + excludeTools: ['toolB'], + }; + mockConfig.getMcpServers.mockReturnValue({ 'both-server': serverConfig }); + + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); + + expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); + expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( + expect.objectContaining({ serverToolName: 'toolA' }), + ); + expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( + expect.objectContaining({ serverToolName: 'toolB' }), + ); + expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( + expect.objectContaining({ serverToolName: 'toolC' }), + ); + }); + }); }); describe('sanitizeParameters', () => { diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index bb92ab05..6dca8cea 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -305,6 +305,26 @@ async function connectAndDiscover( continue; } + const { includeTools, excludeTools } = mcpServerConfig; + const toolName = funcDecl.name; + + let isEnabled = false; + if (includeTools === undefined) { + isEnabled = true; + } else { + isEnabled = includeTools.some( + (tool) => tool === toolName || tool.startsWith(`${toolName}(`), + ); + } + + if (excludeTools?.includes(toolName)) { + isEnabled = false; + } + + if (!isEnabled) { + continue; + } + let toolNameForModel = funcDecl.name; // Replace invalid characters (based on 400 error message from Gemini API) with underscores