From 5008aea90d4ea7ac6bb5872f3702f3c7a7878ed0 Mon Sep 17 00:00:00 2001 From: Tommaso Sciortino Date: Mon, 14 Jul 2025 11:19:33 -0700 Subject: [PATCH] Refactor MCP code for reuse and testing (#3880) --- packages/core/src/tools/mcp-client.test.ts | 1222 ++++------------- packages/core/src/tools/mcp-client.ts | 449 +++--- packages/core/src/tools/tool-registry.test.ts | 109 +- 3 files changed, 613 insertions(+), 1167 deletions(-) diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index df4d71ef..353b4f05 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -4,950 +4,304 @@ * SPDX-License-Identifier: Apache-2.0 */ -/* eslint-disable @typescript-eslint/no-explicit-any */ -import { - describe, - it, - expect, - vi, - beforeEach, - afterEach, - Mocked, -} from 'vitest'; -import { discoverMcpTools } from './mcp-client.js'; -import { sanitizeParameters } from './tool-registry.js'; -import { Schema, Type } from '@google/genai'; -import { Config, MCPServerConfig } from '../config/config.js'; -import { DiscoveredMCPTool } from './mcp-tool.js'; -import { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; -import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; +import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; -import { parse, ParseEntry } from 'shell-quote'; +import { + populateMcpServerCommand, + createTransport, + generateValidName, + isEnabled, + discoverTools, +} from './mcp-client.js'; +import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; +import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; +import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; +import * as GenAiLib from '@google/genai'; -// Mock dependencies -vi.mock('shell-quote'); - -vi.mock('@modelcontextprotocol/sdk/client/index.js', () => { - const MockedClient = vi.fn(); - MockedClient.prototype.connect = vi.fn(); - MockedClient.prototype.listTools = vi.fn(); - // Ensure instances have an onerror property that can be spied on or assigned to - MockedClient.mockImplementation(() => ({ - connect: MockedClient.prototype.connect, - listTools: MockedClient.prototype.listTools, - onerror: vi.fn(), // Each instance gets its own onerror mock - })); - return { Client: MockedClient }; -}); - -// Define a global mock for stderr.on that can be cleared and checked -const mockGlobalStdioStderrOn = vi.fn(); - -vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => { - // This is the constructor for StdioClientTransport - const MockedStdioTransport = vi.fn().mockImplementation(function ( - this: any, - options: any, - ) { - // Always return a new object with a fresh reference to the global mock for .on - this.options = options; - this.stderr = { on: mockGlobalStdioStderrOn }; - this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method - return this; - }); - return { StdioClientTransport: MockedStdioTransport }; -}); - -vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => { - const MockedSSETransport = vi.fn().mockImplementation(function (this: any) { - this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method - return this; - }); - return { SSEClientTransport: MockedSSETransport }; -}); - -vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => { - const MockedStreamableHTTPTransport = vi.fn().mockImplementation(function ( - this: any, - ) { - this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method - return this; - }); - return { StreamableHTTPClientTransport: MockedStreamableHTTPTransport }; -}); - -const mockToolRegistryInstance = { - registerTool: vi.fn(), - getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array - // Add other methods if they are called by the code under test, with default mocks - getTool: vi.fn(), - getAllTools: vi.fn().mockReturnValue([]), - getFunctionDeclarations: vi.fn().mockReturnValue([]), - discoverTools: vi.fn().mockResolvedValue(undefined), -}; -vi.mock('./tool-registry.js', async (importOriginal) => { - const actual = await importOriginal(); - return { - ...(actual as any), - ToolRegistry: vi.fn(() => mockToolRegistryInstance), - sanitizeParameters: (actual as any).sanitizeParameters, - }; -}); - -describe('discoverMcpTools', () => { - let mockConfig: Mocked; - // Use the instance from the module mock - let mockToolRegistry: typeof mockToolRegistryInstance; - - beforeEach(() => { - // Assign the shared mock instance to the test-scoped variable - mockToolRegistry = mockToolRegistryInstance; - // Reset individual spies on the shared instance before each test - mockToolRegistry.registerTool.mockClear(); - mockToolRegistry.getToolsByServer.mockClear().mockReturnValue([]); // Reset to default - mockToolRegistry.getTool.mockClear().mockReturnValue(undefined); // Default to no existing tool - mockToolRegistry.getAllTools.mockClear().mockReturnValue([]); - mockToolRegistry.getFunctionDeclarations.mockClear().mockReturnValue([]); - mockToolRegistry.discoverTools.mockClear().mockResolvedValue(undefined); - - mockConfig = { - getMcpServers: vi.fn().mockReturnValue({}), - getMcpServerCommand: vi.fn().mockReturnValue(undefined), - // getToolRegistry should now return the same shared mock instance - getToolRegistry: vi.fn(() => mockToolRegistry), - } as any; - - vi.mocked(parse).mockClear(); - vi.mocked(Client).mockClear(); - vi.mocked(Client.prototype.connect) - .mockClear() - .mockResolvedValue(undefined); - vi.mocked(Client.prototype.listTools) - .mockClear() - .mockResolvedValue({ tools: [] }); - - vi.mocked(StdioClientTransport).mockClear(); - // Ensure the StdioClientTransport mock constructor returns an object with a close method - vi.mocked(StdioClientTransport).mockImplementation(function ( - this: any, - options: any, - ) { - this.options = options; - this.stderr = { on: mockGlobalStdioStderrOn }; - this.close = vi.fn().mockResolvedValue(undefined); - return this; - }); - mockGlobalStdioStderrOn.mockClear(); // Clear the global mock in beforeEach - - vi.mocked(SSEClientTransport).mockClear(); - // Ensure the SSEClientTransport mock constructor returns an object with a close method - vi.mocked(SSEClientTransport).mockImplementation(function (this: any) { - this.close = vi.fn().mockResolvedValue(undefined); - return this; - }); - - vi.mocked(StreamableHTTPClientTransport).mockClear(); - // Ensure the StreamableHTTPClientTransport mock constructor returns an object with a close method - vi.mocked(StreamableHTTPClientTransport).mockImplementation(function ( - this: any, - ) { - this.close = vi.fn().mockResolvedValue(undefined); - return this; - }); - }); +vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); +vi.mock('@modelcontextprotocol/sdk/client/index.js'); +vi.mock('@google/genai'); +describe('mcp-client', () => { afterEach(() => { vi.restoreAllMocks(); }); - it('should do nothing if no MCP servers or command are configured', async () => { - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1); - expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1); - expect(Client).not.toHaveBeenCalled(); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); - }); - - it('should discover tools via mcpServerCommand', async () => { - const commandString = 'my-mcp-server --start'; - const parsedCommand = ['my-mcp-server', '--start'] as ParseEntry[]; - mockConfig.getMcpServerCommand.mockReturnValue(commandString); - vi.mocked(parse).mockReturnValue(parsedCommand); - - const mockTool = { - name: 'tool1', - description: 'desc1', - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], - }); - - // PRE-MOCK getToolsByServer for the expected server name - // In this case, listTools fails, so no tools are registered. - // The default mock `mockReturnValue([])` from beforeEach should apply. - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(parse).toHaveBeenCalledWith(commandString, process.env); - expect(StdioClientTransport).toHaveBeenCalledWith({ - command: parsedCommand[0], - args: parsedCommand.slice(1), - env: expect.any(Object), - cwd: undefined, - stderr: 'pipe', - }); - expect(Client.prototype.connect).toHaveBeenCalledTimes(1); - expect(Client.prototype.listTools).toHaveBeenCalledTimes(1); - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.any(DiscoveredMCPTool), - ); - const registeredTool = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - expect(registeredTool.name).toBe('tool1'); - expect(registeredTool.serverToolName).toBe('tool1'); - }); - - it('should discover tools via mcpServers config (stdio)', async () => { - const serverConfig: MCPServerConfig = { - command: './mcp-stdio', - args: ['arg1'], - }; - mockConfig.getMcpServers.mockReturnValue({ 'stdio-server': serverConfig }); - - const mockTool = { - name: 'tool-stdio', - description: 'desc-stdio', - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], - }); - - // PRE-MOCK getToolsByServer for the expected server name - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(StdioClientTransport).toHaveBeenCalledWith({ - command: serverConfig.command, - args: serverConfig.args, - env: expect.any(Object), - cwd: undefined, - stderr: 'pipe', - }); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.any(DiscoveredMCPTool), - ); - const registeredTool = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - expect(registeredTool.name).toBe('tool-stdio'); - }); - - it('should discover tools via mcpServers config (sse)', async () => { - const serverConfig: MCPServerConfig = { url: 'http://localhost:1234/sse' }; - mockConfig.getMcpServers.mockReturnValue({ 'sse-server': serverConfig }); - - const mockTool = { - name: 'tool-sse', - description: 'desc-sse', - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], - }); - - // PRE-MOCK getToolsByServer for the expected server name - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(SSEClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.url!), - {}, - ); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.any(DiscoveredMCPTool), - ); - const registeredTool = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - expect(registeredTool.name).toBe('tool-sse'); - }); - - describe('SseClientTransport headers', () => { - const setupSseTest = async (headers?: Record) => { - const serverConfig: MCPServerConfig = { - url: 'http://localhost:1234/sse', - ...(headers && { headers }), - }; - const serverName = headers - ? 'sse-server-with-headers' - : 'sse-server-no-headers'; - const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers'; - - mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig }); - - const mockTool = { - name: toolName, - description: `desc-${toolName}`, - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], - }); - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - return { serverConfig }; - }; - - it('should pass headers when provided', async () => { - const headers = { - Authorization: 'Bearer test-token', - 'X-Custom-Header': 'custom-value', - }; - const { serverConfig } = await setupSseTest(headers); - - expect(SSEClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.url!), - { requestInit: { headers } }, - ); - }); - - it('should work without headers (backwards compatibility)', async () => { - const { serverConfig } = await setupSseTest(); - - expect(SSEClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.url!), - {}, - ); - }); - - it('should pass oauth token when provided', async () => { - const headers = { - Authorization: 'Bearer test-token', - }; - const { serverConfig } = await setupSseTest(headers); - - expect(SSEClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.url!), - { requestInit: { headers } }, - ); - }); - }); - - it('should discover tools via mcpServers config (streamable http)', async () => { - const serverConfig: MCPServerConfig = { - httpUrl: 'http://localhost:3000/mcp', - }; - mockConfig.getMcpServers.mockReturnValue({ 'http-server': serverConfig }); - - const mockTool = { - name: 'tool-http', - description: 'desc-http', - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], - }); - - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.httpUrl!), - {}, - ); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.any(DiscoveredMCPTool), - ); - const registeredTool = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - expect(registeredTool.name).toBe('tool-http'); - }); - - describe('StreamableHTTPClientTransport headers', () => { - const setupHttpTest = async (headers?: Record) => { - const serverConfig: MCPServerConfig = { - httpUrl: 'http://localhost:3000/mcp', - ...(headers && { headers }), - }; - const serverName = headers - ? 'http-server-with-headers' - : 'http-server-no-headers'; - const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers'; - - mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig }); - - const mockTool = { - name: toolName, - description: `desc-${toolName}`, - inputSchema: { type: 'object' as const, properties: {} }, - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], - }); - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - return { serverConfig }; - }; - - it('should pass headers when provided', async () => { - const headers = { - Authorization: 'Bearer test-token', - 'X-Custom-Header': 'custom-value', - }; - const { serverConfig } = await setupHttpTest(headers); - - expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.httpUrl!), - { requestInit: { headers } }, - ); - }); - - it('should work without headers (backwards compatibility)', async () => { - const { serverConfig } = await setupHttpTest(); - - expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.httpUrl!), - {}, - ); - }); - - it('should pass oauth token when provided', async () => { - const headers = { - Authorization: 'Bearer test-token', - }; - const { serverConfig } = await setupHttpTest(headers); - - expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( - new URL(serverConfig.httpUrl!), - { requestInit: { headers } }, - ); - }); - }); - - it('should prefix tool names if multiple MCP servers are configured', async () => { - const serverConfig1: MCPServerConfig = { command: './mcp1' }; - const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' }; - mockConfig.getMcpServers.mockReturnValue({ - server1: serverConfig1, - server2: serverConfig2, - }); - - const mockTool1 = { - name: 'toolA', // Same original name - description: 'd1', - inputSchema: { type: 'object' as const, properties: {} }, - }; - const mockTool2 = { - name: 'toolA', // Same original name - description: 'd2', - inputSchema: { type: 'object' as const, properties: {} }, - }; - const mockToolB = { - name: 'toolB', - description: 'dB', - inputSchema: { type: 'object' as const, properties: {} }, - }; - - vi.mocked(Client.prototype.listTools) - .mockResolvedValueOnce({ tools: [mockTool1, mockToolB] }) // Tools for server1 - .mockResolvedValueOnce({ tools: [mockTool2] }); // Tool for server2 (toolA) - - const effectivelyRegisteredTools = new Map(); - - mockToolRegistry.getTool.mockImplementation((toolName: string) => - effectivelyRegisteredTools.get(toolName), - ); - - // Store the original spy implementation if needed, or just let the new one be the behavior. - // The mockToolRegistry.registerTool is already a vi.fn() from mockToolRegistryInstance. - // We are setting its behavior for this test. - mockToolRegistry.registerTool.mockImplementation((toolToRegister: any) => { - // Simulate the actual registration name being stored for getTool to find - effectivelyRegisteredTools.set(toolToRegister.name, toolToRegister); - // If it's the first time toolA is registered (from server1, not prefixed), - // also make it findable by its original name for the prefixing check of server2/toolA. - if ( - toolToRegister.serverName === 'server1' && - toolToRegister.serverToolName === 'toolA' && - toolToRegister.name === 'toolA' - ) { - effectivelyRegisteredTools.set('toolA', toolToRegister); - } - // The spy call count is inherently tracked by mockToolRegistry.registerTool itself. - }); - - // PRE-MOCK getToolsByServer for the expected server names - // This is for the final check in connectAndDiscover to see if any tools were registered *from that server* - mockToolRegistry.getToolsByServer.mockImplementation( - (serverName: string) => { - if (serverName === 'server1') - return [ - expect.objectContaining({ name: 'toolA' }), - expect.objectContaining({ name: 'toolB' }), - ]; - if (serverName === 'server2') - return [expect.objectContaining({ name: 'server2__toolA' })]; - return []; - }, - ); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3); - const registeredArgs = mockToolRegistry.registerTool.mock.calls.map( - (call) => call[0], - ) as DiscoveredMCPTool[]; - - // The order of server processing by Promise.all is not guaranteed. - // One 'toolA' will be unprefixed, the other will be prefixed. - const toolA_from_server1 = registeredArgs.find( - (t) => t.serverToolName === 'toolA' && t.serverName === 'server1', - ); - const toolA_from_server2 = registeredArgs.find( - (t) => t.serverToolName === 'toolA' && t.serverName === 'server2', - ); - const toolB_from_server1 = registeredArgs.find( - (t) => t.serverToolName === 'toolB' && t.serverName === 'server1', - ); - - expect(toolA_from_server1).toBeDefined(); - expect(toolA_from_server2).toBeDefined(); - expect(toolB_from_server1).toBeDefined(); - - expect(toolB_from_server1?.name).toBe('toolB'); // toolB is unique - - // Check that one of toolA is prefixed and the other is not, and the prefixed one is correct. - if (toolA_from_server1?.name === 'toolA') { - expect(toolA_from_server2?.name).toBe('server2__toolA'); - } else { - expect(toolA_from_server1?.name).toBe('server1__toolA'); - expect(toolA_from_server2?.name).toBe('toolA'); - } - }); - - it('should clean schema properties ($schema, additionalProperties)', async () => { - const serverConfig: MCPServerConfig = { command: './mcp-clean' }; - mockConfig.getMcpServers.mockReturnValue({ 'clean-server': serverConfig }); - - const rawSchema = { - type: 'object' as const, - $schema: 'http://json-schema.org/draft-07/schema#', - additionalProperties: true, - properties: { - prop1: { type: 'string', $schema: 'remove-this' }, - prop2: { - type: 'object' as const, - additionalProperties: false, - properties: { nested: { type: 'number' } }, - }, - }, - }; - const mockTool = { - name: 'cleanTool', - description: 'd', - inputSchema: JSON.parse(JSON.stringify(rawSchema)), - }; - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: [mockTool], - }); - // PRE-MOCK getToolsByServer for the expected server name - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); - const registeredTool = mockToolRegistry.registerTool.mock - .calls[0][0] as DiscoveredMCPTool; - const cleanedParams = registeredTool.schema.parameters as any; - - expect(cleanedParams).not.toHaveProperty('$schema'); - expect(cleanedParams).not.toHaveProperty('additionalProperties'); - expect(cleanedParams.properties.prop1).not.toHaveProperty('$schema'); - expect(cleanedParams.properties.prop2).not.toHaveProperty( - 'additionalProperties', - ); - expect(cleanedParams.properties.prop2.properties.nested).not.toHaveProperty( - '$schema', - ); - expect(cleanedParams.properties.prop2.properties.nested).not.toHaveProperty( - 'additionalProperties', - ); - }); - - it('should handle error if mcpServerCommand parsing fails', async () => { - const commandString = 'my-mcp-server "unterminated quote'; - mockConfig.getMcpServerCommand.mockReturnValue(commandString); - vi.mocked(parse).mockImplementation(() => { - throw new Error('Parsing failed'); - }); - vi.spyOn(console, 'error').mockImplementation(() => {}); - - await expect( - discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ), - ).rejects.toThrow('Parsing failed'); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); - expect(console.error).not.toHaveBeenCalled(); - }); - - it('should log error and skip server if config is invalid (missing url and command)', async () => { - mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any }); - vi.spyOn(console, 'error').mockImplementation(() => {}); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(console.error).toHaveBeenCalledWith( - expect.stringContaining( - "MCP server 'bad-server' has invalid configuration", - ), - ); - // Client constructor should not be called if config is invalid before instantiation - expect(Client).not.toHaveBeenCalled(); - }); - - it('should log error and skip server if mcpClient.connect fails', async () => { - const serverConfig: MCPServerConfig = { command: './mcp-fail-connect' }; - mockConfig.getMcpServers.mockReturnValue({ - 'fail-connect-server': serverConfig, - }); - vi.mocked(Client.prototype.connect).mockRejectedValue( - new Error('Connection refused'), - ); - vi.spyOn(console, 'error').mockImplementation(() => {}); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(console.error).toHaveBeenCalledWith( - expect.stringContaining( - "failed to start or connect to MCP server 'fail-connect-server'", - ), - ); - expect(Client.prototype.listTools).not.toHaveBeenCalled(); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); - }); - - it('should log error and skip server if mcpClient.listTools fails', async () => { - const serverConfig: MCPServerConfig = { command: './mcp-fail-list' }; - mockConfig.getMcpServers.mockReturnValue({ - 'fail-list-server': serverConfig, - }); - vi.mocked(Client.prototype.listTools).mockRejectedValue( - new Error('ListTools error'), - ); - vi.spyOn(console, 'error').mockImplementation(() => {}); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(console.error).toHaveBeenCalledWith( - expect.stringContaining( - "Failed to list or register tools for MCP server 'fail-list-server'", - ), - ); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); - }); - - it('should assign mcpClient.onerror handler', async () => { - const serverConfig: MCPServerConfig = { command: './mcp-onerror' }; - mockConfig.getMcpServers.mockReturnValue({ - 'onerror-server': serverConfig, - }); - // PRE-MOCK getToolsByServer for the expected server name - mockToolRegistry.getToolsByServer.mockReturnValueOnce([ - expect.any(DiscoveredMCPTool), - ]); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - const clientInstances = vi.mocked(Client).mock.results; - expect(clientInstances.length).toBeGreaterThan(0); - const lastClientInstance = - clientInstances[clientInstances.length - 1]?.value; - expect(lastClientInstance?.onerror).toEqual(expect.any(Function)); - }); - - describe('Tool Filtering', () => { - const mockTools = [ - { - name: 'toolA', - description: 'descA', - inputSchema: { type: 'object' as const, properties: {} }, - }, - { - name: 'toolB', - description: 'descB', - inputSchema: { type: 'object' as const, properties: {} }, - }, - { - name: 'toolC', - description: 'descC', - inputSchema: { type: 'object' as const, properties: {} }, - }, - ]; - - beforeEach(() => { - vi.mocked(Client.prototype.listTools).mockResolvedValue({ - tools: mockTools, - }); - mockToolRegistry.getToolsByServer.mockReturnValue([ - expect.any(DiscoveredMCPTool), - ]); - }); - - it('should only include specified tools with includeTools', async () => { - const serverConfig: MCPServerConfig = { - command: './mcp-include', - includeTools: ['toolA', 'toolC'], - }; - mockConfig.getMcpServers.mockReturnValue({ - 'include-server': serverConfig, - }); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolA' }), - ); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolC' }), - ); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolB' }), - ); - }); - - it('should exclude specified tools with excludeTools', async () => { - const serverConfig: MCPServerConfig = { - command: './mcp-exclude', - excludeTools: ['toolB'], - }; - mockConfig.getMcpServers.mockReturnValue({ - 'exclude-server': serverConfig, - }); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolA' }), - ); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolC' }), - ); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolB' }), - ); - }); - - it('should handle both includeTools and excludeTools', async () => { - const serverConfig: MCPServerConfig = { - command: './mcp-both', - includeTools: ['toolA', 'toolB'], - excludeTools: ['toolB'], - }; - mockConfig.getMcpServers.mockReturnValue({ 'both-server': serverConfig }); - - await discoverMcpTools( - mockConfig.getMcpServers() ?? {}, - mockConfig.getMcpServerCommand(), - mockToolRegistry as any, - false, - ); - - expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); - expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolA' }), - ); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolB' }), - ); - expect(mockToolRegistry.registerTool).not.toHaveBeenCalledWith( - expect.objectContaining({ serverToolName: 'toolC' }), - ); - }); - }); -}); - -describe('sanitizeParameters', () => { - it('should do nothing for an undefined schema', () => { - const schema = undefined; - sanitizeParameters(schema); - }); - - it('should remove default when anyOf is present', () => { - const schema: Schema = { - anyOf: [{ type: Type.STRING }, { type: Type.NUMBER }], - default: 'hello', - }; - sanitizeParameters(schema); - expect(schema.default).toBeUndefined(); - }); - - it('should recursively sanitize items in anyOf', () => { - const schema: Schema = { - anyOf: [ - { - anyOf: [{ type: Type.STRING }], - default: 'world', - }, - { type: Type.NUMBER }, - ], - }; - sanitizeParameters(schema); - expect(schema.anyOf![0].default).toBeUndefined(); - }); - - it('should recursively sanitize items in items', () => { - const schema: Schema = { - items: { - anyOf: [{ type: Type.STRING }], - default: 'world', - }, - }; - sanitizeParameters(schema); - expect(schema.items!.default).toBeUndefined(); - }); - - it('should recursively sanitize items in properties', () => { - const schema: Schema = { - properties: { - prop1: { - anyOf: [{ type: Type.STRING }], - default: 'world', - }, - }, - }; - sanitizeParameters(schema); - expect(schema.properties!.prop1.default).toBeUndefined(); - }); - - it('should handle complex nested schemas', () => { - const schema: Schema = { - properties: { - prop1: { - items: { - anyOf: [{ type: Type.STRING }], - default: 'world', - }, - }, - prop2: { - anyOf: [ + describe('discoverTools', () => { + it('should discover tools', async () => { + const mockedClient = {} as unknown as ClientLib.Client; + const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ + tool: () => ({ + functionDeclarations: [ { - properties: { - nestedProp: { - anyOf: [{ type: Type.NUMBER }], - default: 123, - }, - }, + name: 'testFunction', }, ], + }), + } as unknown as GenAiLib.CallableTool); + + const tools = await discoverTools('test-server', {}, mockedClient); + + expect(tools.length).toBe(1); + expect(mockedMcpToTool).toHaveBeenCalledOnce(); + }); + }); + + describe('appendMcpServerCommand', () => { + it('should do nothing if no MCP servers or command are configured', () => { + const out = populateMcpServerCommand({}, undefined); + expect(out).toEqual({}); + }); + + it('should discover tools via mcpServerCommand', () => { + const commandString = 'command --arg1 value1'; + const out = populateMcpServerCommand({}, commandString); + expect(out).toEqual({ + mcp: { + command: 'command', + args: ['--arg1', 'value1'], }, - }, - }; - sanitizeParameters(schema); - expect(schema.properties!.prop1.items!.default).toBeUndefined(); - const nestedProp = - schema.properties!.prop2.anyOf![0].properties!.nestedProp; - expect(nestedProp?.default).toBeUndefined(); + }); + }); + + it('should handle error if mcpServerCommand parsing fails', () => { + expect(() => populateMcpServerCommand({}, 'derp && herp')).toThrowError(); + }); + }); + + describe('createTransport', () => { + const originalEnv = process.env; + + beforeEach(() => { + vi.resetModules(); + process.env = {}; + }); + + afterEach(() => { + process.env = originalEnv; + }); + + describe('should connect via httpUrl', () => { + it('without headers', async () => { + const transport = createTransport( + 'test-server', + { + httpUrl: 'http://test-server', + }, + false, + ); + + expect(transport).toEqual( + new StreamableHTTPClientTransport(new URL('http://test-server'), {}), + ); + }); + + it('with headers', async () => { + const transport = createTransport( + 'test-server', + { + httpUrl: 'http://test-server', + headers: { Authorization: 'derp' }, + }, + false, + ); + + expect(transport).toEqual( + new StreamableHTTPClientTransport(new URL('http://test-server'), { + requestInit: { + headers: { Authorization: 'derp' }, + }, + }), + ); + }); + }); + + describe('should connect via url', () => { + it('without headers', async () => { + const transport = createTransport( + 'test-server', + { + url: 'http://test-server', + }, + false, + ); + expect(transport).toEqual( + new SSEClientTransport(new URL('http://test-server'), {}), + ); + }); + + it('with headers', async () => { + const transport = createTransport( + 'test-server', + { + url: 'http://test-server', + headers: { Authorization: 'derp' }, + }, + false, + ); + + expect(transport).toEqual( + new SSEClientTransport(new URL('http://test-server'), { + requestInit: { + headers: { Authorization: 'derp' }, + }, + }), + ); + }); + }); + + it('should connect via command', () => { + const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport); + + createTransport( + 'test-server', + { + command: 'test-command', + args: ['--foo', 'bar'], + env: { FOO: 'bar' }, + cwd: 'test/cwd', + }, + false, + ); + + expect(mockedTransport).toHaveBeenCalledWith({ + command: 'test-command', + args: ['--foo', 'bar'], + cwd: 'test/cwd', + env: { FOO: 'bar' }, + stderr: 'pipe', + }); + }); + }); + describe('generateValidName', () => { + it('should return a valid name for a simple function', () => { + const funcDecl = { name: 'myFunction' }; + const serverName = 'myServer'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('myServer__myFunction'); + }); + + it('should prepend the server name', () => { + const funcDecl = { name: 'anotherFunction' }; + const serverName = 'production-server'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('production-server__anotherFunction'); + }); + + it('should replace invalid characters with underscores', () => { + const funcDecl = { name: 'invalid-name with spaces' }; + const serverName = 'test_server'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('test_server__invalid-name_with_spaces'); + }); + + it('should truncate long names', () => { + const funcDecl = { + name: 'a_very_long_function_name_that_will_definitely_exceed_the_limit', + }; + const serverName = 'a_long_server_name'; + const result = generateValidName(funcDecl, serverName); + expect(result.length).toBe(63); + expect(result).toBe( + 'a_long_server_name__a_very_l___will_definitely_exceed_the_limit', + ); + }); + + it('should handle names with only invalid characters', () => { + const funcDecl = { name: '!@#$%^&*()' }; + const serverName = 'special-chars'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('special-chars____________'); + }); + + it('should handle names that are already valid', () => { + const funcDecl = { name: 'already_valid' }; + const serverName = 'validator'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('validator__already_valid'); + }); + + it('should handle names with leading/trailing invalid characters', () => { + const funcDecl = { name: '-_invalid-_' }; + const serverName = 'trim-test'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe('trim-test__-_invalid-_'); + }); + + it('should handle names that are exactly 63 characters long', () => { + const longName = 'a'.repeat(45); + const funcDecl = { name: longName }; + const serverName = 'server'; + const result = generateValidName(funcDecl, serverName); + expect(result).toBe(`server__${longName}`); + expect(result.length).toBe(53); + }); + + it('should handle names that are exactly 64 characters long', () => { + const longName = 'a'.repeat(55); + const funcDecl = { name: longName }; + const serverName = 'server'; + const result = generateValidName(funcDecl, serverName); + expect(result.length).toBe(63); + expect(result).toBe( + 'server__aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + ); + }); + + it('should handle names that are longer than 64 characters', () => { + const longName = 'a'.repeat(100); + const funcDecl = { name: longName }; + const serverName = 'long-server'; + const result = generateValidName(funcDecl, serverName); + expect(result.length).toBe(63); + expect(result).toBe( + 'long-server__aaaaaaaaaaaaaaa___aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + ); + }); + }); + describe('isEnabled', () => { + const funcDecl = { name: 'myTool' }; + const serverName = 'myServer'; + + it('should return true if no include or exclude lists are provided', () => { + const mcpServerConfig = {}; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true); + }); + + it('should return false if the tool is in the exclude list', () => { + const mcpServerConfig = { excludeTools: ['myTool'] }; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false); + }); + + it('should return true if the tool is in the include list', () => { + const mcpServerConfig = { includeTools: ['myTool'] }; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true); + }); + + it('should return true if the tool is in the include list with parentheses', () => { + const mcpServerConfig = { includeTools: ['myTool()'] }; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true); + }); + + it('should return false if the include list exists but does not contain the tool', () => { + const mcpServerConfig = { includeTools: ['anotherTool'] }; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false); + }); + + it('should return false if the tool is in both the include and exclude lists', () => { + const mcpServerConfig = { + includeTools: ['myTool'], + excludeTools: ['myTool'], + }; + expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false); + }); + + it('should return false if the function declaration has no name', () => { + const namelessFuncDecl = {}; + const mcpServerConfig = {}; + expect(isEnabled(namelessFuncDecl, serverName, mcpServerConfig)).toBe( + false, + ); + }); }); }); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 6edfbac8..eb82190b 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -5,6 +5,7 @@ */ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; import { SSEClientTransport, @@ -17,7 +18,7 @@ import { import { parse } from 'shell-quote'; import { MCPServerConfig } from '../config/config.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; -import { Type, mcpToTool } from '@google/genai'; +import { FunctionDeclaration, Type, mcpToTool } from '@google/genai'; import { sanitizeParameters, ToolRegistry } from './tool-registry.js'; export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes @@ -123,28 +124,25 @@ export function getMCPDiscoveryState(): MCPDiscoveryState { return mcpDiscoveryState; } +/** + * Discovers tools from all configured MCP servers and registers them with the tool registry. + * It orchestrates the connection and discovery process for each server defined in the + * configuration, as well as any server specified via a command-line argument. + * + * @param mcpServers A record of named MCP server configurations. + * @param mcpServerCommand An optional command string for a dynamically specified MCP server. + * @param toolRegistry The central registry where discovered tools will be registered. + * @returns A promise that resolves when the discovery process has been attempted for all servers. + */ export async function discoverMcpTools( mcpServers: Record, mcpServerCommand: string | undefined, toolRegistry: ToolRegistry, debugMode: boolean, ): Promise { - // Set discovery state to in progress mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS; - try { - if (mcpServerCommand) { - const cmd = mcpServerCommand; - const args = parse(cmd, process.env) as string[]; - if (args.some((arg) => typeof arg !== 'string')) { - throw new Error('failed to parse mcpServerCommand: ' + cmd); - } - // use generic server name 'mcp' - mcpServers['mcp'] = { - command: args[0], - args: args.slice(1), - }; - } + mcpServers = populateMcpServerCommand(mcpServers, mcpServerCommand); const discoveryPromises = Object.entries(mcpServers).map( ([mcpServerName, mcpServerConfig]) => @@ -156,16 +154,31 @@ export async function discoverMcpTools( ), ); await Promise.all(discoveryPromises); - - // Mark discovery as completed + } finally { mcpDiscoveryState = MCPDiscoveryState.COMPLETED; - } catch (error) { - // Still mark as completed even with errors - mcpDiscoveryState = MCPDiscoveryState.COMPLETED; - throw error; } } +/** Visible for Testing */ +export function populateMcpServerCommand( + mcpServers: Record, + mcpServerCommand: string | undefined, +): Record { + if (mcpServerCommand) { + const cmd = mcpServerCommand; + const args = parse(cmd, process.env) as string[]; + if (args.some((arg) => typeof arg !== 'string')) { + throw new Error('failed to parse mcpServerCommand: ' + cmd); + } + // use generic server name 'mcp' + mcpServers['mcp'] = { + command: args[0], + args: args.slice(1), + }; + } + return mcpServers; +} + /** * Connects to an MCP server and discovers available tools, registering them with the tool registry. * This function handles the complete lifecycle of connecting to a server, discovering tools, @@ -176,71 +189,117 @@ export async function discoverMcpTools( * @param toolRegistry The registry to register discovered tools with * @returns Promise that resolves when discovery is complete */ -async function connectAndDiscover( +export async function connectAndDiscover( mcpServerName: string, mcpServerConfig: MCPServerConfig, toolRegistry: ToolRegistry, debugMode: boolean, ): Promise { - // Initialize the server status as connecting updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); - let transport; - if (mcpServerConfig.httpUrl) { - const transportOptions: StreamableHTTPClientTransportOptions = {}; + try { + const mcpClient = await connectToMcpServer( + mcpServerName, + mcpServerConfig, + debugMode, + ); + try { + updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED); - if (mcpServerConfig.headers) { - transportOptions.requestInit = { - headers: mcpServerConfig.headers, + mcpClient.onerror = (error) => { + console.error(`MCP ERROR (${mcpServerName}):`, error.toString()); + updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); }; - } - transport = new StreamableHTTPClientTransport( - new URL(mcpServerConfig.httpUrl), - transportOptions, - ); - } else if (mcpServerConfig.url) { - const transportOptions: SSEClientTransportOptions = {}; - if (mcpServerConfig.headers) { - transportOptions.requestInit = { - headers: mcpServerConfig.headers, - }; + const tools = await discoverTools( + mcpServerName, + mcpServerConfig, + mcpClient, + ); + for (const tool of tools) { + toolRegistry.registerTool(tool); + } + } catch (error) { + mcpClient.close(); + throw error; } - transport = new SSEClientTransport( - new URL(mcpServerConfig.url), - transportOptions, - ); - } else if (mcpServerConfig.command) { - transport = new StdioClientTransport({ - command: mcpServerConfig.command, - args: mcpServerConfig.args || [], - env: { - ...process.env, - ...(mcpServerConfig.env || {}), - } as Record, - cwd: mcpServerConfig.cwd, - stderr: 'pipe', - }); - } else { - console.error( - `MCP server '${mcpServerName}' has invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio). Skipping.`, - ); - // Update status to disconnected + } catch (error) { + console.error(`Error connecting to MCP server '${mcpServerName}':`, error); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - return; } +} - if ( - debugMode && - transport instanceof StdioClientTransport && - transport.stderr - ) { - transport.stderr.on('data', (data) => { - const stderrStr = data.toString().trim(); - console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr); - }); +/** + * Discovers and sanitizes tools from a connected MCP client. + * It retrieves function declarations from the client, filters out disabled tools, + * generates valid names for them, and wraps them in `DiscoveredMCPTool` instances. + * + * @param mcpServerName The name of the MCP server. + * @param mcpServerConfig The configuration for the MCP server. + * @param mcpClient The active MCP client instance. + * @returns A promise that resolves to an array of discovered and enabled tools. + * @throws An error if no enabled tools are found or if the server provides invalid function declarations. + */ +export async function discoverTools( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, + mcpClient: Client, +): Promise { + try { + const mcpCallableTool = mcpToTool(mcpClient); + const tool = await mcpCallableTool.tool(); + + if (!Array.isArray(tool.functionDeclarations)) { + throw new Error(`Server did not return valid function declarations.`); + } + + const discoveredTools: DiscoveredMCPTool[] = []; + for (const funcDecl of tool.functionDeclarations) { + if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) { + continue; + } + + const toolNameForModel = generateValidName(funcDecl, mcpServerName); + + sanitizeParameters(funcDecl.parameters); + + discoveredTools.push( + new DiscoveredMCPTool( + mcpCallableTool, + mcpServerName, + toolNameForModel, + funcDecl.description ?? '', + funcDecl.parameters ?? { type: Type.OBJECT, properties: {} }, + funcDecl.name!, + mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + mcpServerConfig.trust, + ), + ); + } + if (discoveredTools.length === 0) { + throw Error('No enabled tools found'); + } + return discoveredTools; + } catch (error) { + throw new Error(`Error discovering tools: ${error}`); } +} +/** + * Creates and connects an MCP client to a server based on the provided configuration. + * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and + * establishes a connection. It also applies a patch to handle request timeouts. + * + * @param mcpServerName The name of the MCP server, used for logging and identification. + * @param mcpServerConfig The configuration specifying how to connect to the server. + * @returns A promise that resolves to a connected MCP `Client` instance. + * @throws An error if the connection fails or the configuration is invalid. + */ +export async function connectToMcpServer( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, + debugMode: boolean, +): Promise { const mcpClient = new Client({ name: 'gemini-cli-mcp-client', version: '0.0.1', @@ -259,11 +318,20 @@ async function connectAndDiscover( } try { - await mcpClient.connect(transport, { - timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, - }); - // Connection successful - updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED); + const transport = createTransport( + mcpServerName, + mcpServerConfig, + debugMode, + ); + try { + await mcpClient.connect(transport, { + timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + }); + return mcpClient; + } catch (error) { + await transport.close(); + throw error; + } } catch (error) { // Create a safe config object that excludes sensitive information const safeConfig = { @@ -282,131 +350,110 @@ async function connectAndDiscover( if (process.env.SANDBOX) { errorString += `\nMake sure it is available in the sandbox`; } - console.error(errorString); - // Update status to disconnected - updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - return; - } - - mcpClient.onerror = (error) => { - console.error(`MCP ERROR (${mcpServerName}):`, error.toString()); - // Update status to disconnected on error - updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - }; - - try { - const mcpCallableTool = mcpToTool(mcpClient); - const tool = await mcpCallableTool.tool(); - - if (!tool || !Array.isArray(tool.functionDeclarations)) { - console.error( - `MCP server '${mcpServerName}' did not return valid tool function declarations. Skipping.`, - ); - if ( - transport instanceof StdioClientTransport || - transport instanceof SSEClientTransport || - transport instanceof StreamableHTTPClientTransport - ) { - await transport.close(); - } - // Update status to disconnected - updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - return; - } - - for (const funcDecl of tool.functionDeclarations) { - if (!funcDecl.name) { - console.warn( - `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`, - ); - continue; - } - - const { includeTools, excludeTools } = mcpServerConfig; - const toolName = funcDecl.name; - - let isEnabled = false; - if (includeTools === undefined) { - isEnabled = true; - } else { - isEnabled = includeTools.some( - (tool) => tool === toolName || tool.startsWith(`${toolName}(`), - ); - } - - if (excludeTools?.includes(toolName)) { - isEnabled = false; - } - - if (!isEnabled) { - continue; - } - - let toolNameForModel = funcDecl.name; - - // Replace invalid characters (based on 400 error message from Gemini API) with underscores - toolNameForModel = toolNameForModel.replace(/[^a-zA-Z0-9_.-]/g, '_'); - - const existingTool = toolRegistry.getTool(toolNameForModel); - if (existingTool) { - toolNameForModel = mcpServerName + '__' + toolNameForModel; - } - - // If longer than 63 characters, replace middle with '___' - // (Gemini API says max length 64, but actual limit seems to be 63) - if (toolNameForModel.length > 63) { - toolNameForModel = - toolNameForModel.slice(0, 28) + '___' + toolNameForModel.slice(-32); - } - - sanitizeParameters(funcDecl.parameters); - - toolRegistry.registerTool( - new DiscoveredMCPTool( - mcpCallableTool, - mcpServerName, - toolNameForModel, - funcDecl.description ?? '', - funcDecl.parameters ?? { type: Type.OBJECT, properties: {} }, - funcDecl.name, - mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, - mcpServerConfig.trust, - ), - ); - } - } catch (error) { - console.error( - `Failed to list or register tools for MCP server '${mcpServerName}': ${error}`, - ); - // Ensure transport is cleaned up on error too - if ( - transport instanceof StdioClientTransport || - transport instanceof SSEClientTransport || - transport instanceof StreamableHTTPClientTransport - ) { - await transport.close(); - } - // Update status to disconnected - updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - } - - // If no tools were registered from this MCP server, the following 'if' block - // will close the connection. This is done to conserve resources and prevent - // an orphaned connection to a server that isn't providing any usable - // functionality. Connections to servers that did provide tools are kept - // open, as those tools will require the connection to function. - if (toolRegistry.getToolsByServer(mcpServerName).length === 0) { - console.log( - `No tools registered from MCP server '${mcpServerName}'. Closing connection.`, - ); - if ( - transport instanceof StdioClientTransport || - transport instanceof SSEClientTransport || - transport instanceof StreamableHTTPClientTransport - ) { - await transport.close(); - // Update status to disconnected - updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); - } + throw new Error(errorString); } } + +/** Visible for Testing */ +export function createTransport( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, + debugMode: boolean, +): Transport { + if (mcpServerConfig.httpUrl) { + const transportOptions: StreamableHTTPClientTransportOptions = {}; + if (mcpServerConfig.headers) { + transportOptions.requestInit = { + headers: mcpServerConfig.headers, + }; + } + return new StreamableHTTPClientTransport( + new URL(mcpServerConfig.httpUrl), + transportOptions, + ); + } + + if (mcpServerConfig.url) { + const transportOptions: SSEClientTransportOptions = {}; + if (mcpServerConfig.headers) { + transportOptions.requestInit = { + headers: mcpServerConfig.headers, + }; + } + return new SSEClientTransport( + new URL(mcpServerConfig.url), + transportOptions, + ); + } + + if (mcpServerConfig.command) { + const transport = new StdioClientTransport({ + command: mcpServerConfig.command, + args: mcpServerConfig.args || [], + env: { + ...process.env, + ...(mcpServerConfig.env || {}), + } as Record, + cwd: mcpServerConfig.cwd, + stderr: 'pipe', + }); + if (debugMode) { + transport.stderr!.on('data', (data) => { + const stderrStr = data.toString().trim(); + console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr); + }); + } + return transport; + } + + throw new Error( + `Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`, + ); +} + +/** Visible for testing */ +export function generateValidName( + funcDecl: FunctionDeclaration, + mcpServerName: string, +) { + // Replace invalid characters (based on 400 error message from Gemini API) with underscores + let validToolname = funcDecl.name!.replace(/[^a-zA-Z0-9_.-]/g, '_'); + + // Prepend MCP server name to avoid conflicts with other tools + validToolname = mcpServerName + '__' + validToolname; + + // If longer than 63 characters, replace middle with '___' + // (Gemini API says max length 64, but actual limit seems to be 63) + if (validToolname.length > 63) { + validToolname = + validToolname.slice(0, 28) + '___' + validToolname.slice(-32); + } + return validToolname; +} + +/** Visible for testing */ +export function isEnabled( + funcDecl: FunctionDeclaration, + mcpServerName: string, + mcpServerConfig: MCPServerConfig, +): boolean { + if (!funcDecl.name) { + console.warn( + `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`, + ); + return false; + } + const { includeTools, excludeTools } = mcpServerConfig; + + // excludeTools takes precedence over includeTools + if (excludeTools && excludeTools.includes(funcDecl.name)) { + return false; + } + + return ( + !includeTools || + includeTools.some( + (tool) => tool === funcDecl.name || tool.startsWith(`${funcDecl.name}(`), + ) + ); +} diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index fba48c17..853f6458 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -326,6 +326,83 @@ describe('ToolRegistry', () => { }); describe('sanitizeParameters', () => { + it('should remove default when anyOf is present', () => { + const schema: Schema = { + anyOf: [{ type: Type.STRING }, { type: Type.NUMBER }], + default: 'hello', + }; + sanitizeParameters(schema); + expect(schema.default).toBeUndefined(); + }); + + it('should recursively sanitize items in anyOf', () => { + const schema: Schema = { + anyOf: [ + { + anyOf: [{ type: Type.STRING }], + default: 'world', + }, + { type: Type.NUMBER }, + ], + }; + sanitizeParameters(schema); + expect(schema.anyOf![0].default).toBeUndefined(); + }); + + it('should recursively sanitize items in items', () => { + const schema: Schema = { + items: { + anyOf: [{ type: Type.STRING }], + default: 'world', + }, + }; + sanitizeParameters(schema); + expect(schema.items!.default).toBeUndefined(); + }); + + it('should recursively sanitize items in properties', () => { + const schema: Schema = { + properties: { + prop1: { + anyOf: [{ type: Type.STRING }], + default: 'world', + }, + }, + }; + sanitizeParameters(schema); + expect(schema.properties!.prop1.default).toBeUndefined(); + }); + + it('should handle complex nested schemas', () => { + const schema: Schema = { + properties: { + prop1: { + items: { + anyOf: [{ type: Type.STRING }], + default: 'world', + }, + }, + prop2: { + anyOf: [ + { + properties: { + nestedProp: { + anyOf: [{ type: Type.NUMBER }], + default: 123, + }, + }, + }, + ], + }, + }, + }; + sanitizeParameters(schema); + expect(schema.properties!.prop1.items!.default).toBeUndefined(); + const nestedProp = + schema.properties!.prop2.anyOf![0].properties!.nestedProp; + expect(nestedProp?.default).toBeUndefined(); + }); + it('should remove unsupported format from a simple string property', () => { const schema: Schema = { type: Type.OBJECT, @@ -356,25 +433,6 @@ describe('sanitizeParameters', () => { expect(schema).toEqual(originalSchema); }); - it('should handle nested objects recursively', () => { - const schema: Schema = { - type: Type.OBJECT, - properties: { - user: { - type: Type.OBJECT, - properties: { - email: { type: Type.STRING, format: 'email' }, - }, - }, - }, - }; - sanitizeParameters(schema); - expect(schema.properties?.['user']?.properties?.['email']).toHaveProperty( - 'format', - undefined, - ); - }); - it('should handle arrays of objects', () => { const schema: Schema = { type: Type.OBJECT, @@ -414,19 +472,6 @@ describe('sanitizeParameters', () => { expect(() => sanitizeParameters(undefined)).not.toThrow(); }); - it('should handle cyclic schemas without crashing', () => { - const schema: any = { - type: Type.OBJECT, - properties: { - name: { type: Type.STRING, format: 'hostname' }, - }, - }; - schema.properties.self = schema; - - expect(() => sanitizeParameters(schema)).not.toThrow(); - expect(schema.properties.name).toHaveProperty('format', undefined); - }); - it('should handle complex nested schemas with cycles', () => { const userNode: any = { type: Type.OBJECT,