From d74c0f581bf5ba0c74a7b7874f6638db6897f907 Mon Sep 17 00:00:00 2001 From: Taylor Mullen Date: Wed, 28 May 2025 00:43:23 -0700 Subject: [PATCH] refactor: Extract MCP discovery from ToolRegistry - Moves MCP tool discovery logic from ToolRegistry into a new, dedicated MCP client (mcp-client.ts and mcp-tool.ts). - Updates ToolRegistry to utilize the new MCP client. - Adds comprehensive tests for the new MCP client and its integration with ToolRegistry. Part of https://github.com/google-gemini/gemini-cli/issues/577 --- packages/server/src/tools/mcp-client.test.ts | 371 ++++++++++++++++++ packages/server/src/tools/mcp-client.ts | 138 +++++++ packages/server/src/tools/mcp-tool.test.ts | 161 ++++++++ packages/server/src/tools/mcp-tool.ts | 49 +++ .../server/src/tools/tool-registry.test.ts | 26 +- packages/server/src/tools/tool-registry.ts | 154 +------- 6 files changed, 737 insertions(+), 162 deletions(-) create mode 100644 packages/server/src/tools/mcp-client.test.ts create mode 100644 packages/server/src/tools/mcp-client.ts create mode 100644 packages/server/src/tools/mcp-tool.test.ts create mode 100644 packages/server/src/tools/mcp-tool.ts diff --git a/packages/server/src/tools/mcp-client.test.ts b/packages/server/src/tools/mcp-client.test.ts new file mode 100644 index 00000000..4664669d --- /dev/null +++ b/packages/server/src/tools/mcp-client.test.ts @@ -0,0 +1,371 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + Mocked, +} from 'vitest'; +import { discoverMcpTools } from './mcp-client.js'; +import { Config, MCPServerConfig } from '../config/config.js'; +import { ToolRegistry } from './tool-registry.js'; +import { DiscoveredMCPTool } from './mcp-tool.js'; +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; +import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; +import { parse, ParseEntry } from 'shell-quote'; + +// Mock dependencies +vi.mock('shell-quote'); + +vi.mock('@modelcontextprotocol/sdk/client/index.js', () => { + const MockedClient = vi.fn(); + MockedClient.prototype.connect = vi.fn(); + MockedClient.prototype.listTools = vi.fn(); + // Ensure instances have an onerror property that can be spied on or assigned to + MockedClient.mockImplementation(() => ({ + connect: MockedClient.prototype.connect, + listTools: MockedClient.prototype.listTools, + onerror: vi.fn(), // Each instance gets its own onerror mock + })); + return { Client: MockedClient }; +}); + +// Define a global mock for stderr.on that can be cleared and checked +const mockGlobalStdioStderrOn = vi.fn(); + +vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => { + // This is the constructor for StdioClientTransport + const MockedStdioTransport = vi.fn().mockImplementation(function ( + this: any, + options: any, + ) { + // Always return a new object with a fresh reference to the global mock for .on + this.options = options; + this.stderr = { on: mockGlobalStdioStderrOn }; + return this; + }); + return { StdioClientTransport: MockedStdioTransport }; +}); + +vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => { + const MockedSSETransport = vi.fn(); + return { SSEClientTransport: MockedSSETransport }; +}); + +vi.mock('./tool-registry.js'); + +describe('discoverMcpTools', () => { + let mockConfig: Mocked; + let mockToolRegistry: Mocked; + + beforeEach(() => { + mockConfig = { + getMcpServers: vi.fn().mockReturnValue({}), + getMcpServerCommand: vi.fn().mockReturnValue(undefined), + } as any; + + mockToolRegistry = new (ToolRegistry as any)( + mockConfig, + ) as Mocked; + mockToolRegistry.registerTool = vi.fn(); + + vi.mocked(parse).mockClear(); + vi.mocked(Client).mockClear(); + vi.mocked(Client.prototype.connect) + .mockClear() + .mockResolvedValue(undefined); + vi.mocked(Client.prototype.listTools) + .mockClear() + .mockResolvedValue({ tools: [] }); + + vi.mocked(StdioClientTransport).mockClear(); + mockGlobalStdioStderrOn.mockClear(); // Clear the global mock in beforeEach + + vi.mocked(SSEClientTransport).mockClear(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should do nothing if no MCP servers or command are configured', async () => { + await discoverMcpTools(mockConfig, mockToolRegistry); + expect(mockConfig.getMcpServers).toHaveBeenCalledTimes(1); + expect(mockConfig.getMcpServerCommand).toHaveBeenCalledTimes(1); + expect(Client).not.toHaveBeenCalled(); + expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); + }); + + it('should discover tools via mcpServerCommand', async () => { + const commandString = 'my-mcp-server --start'; + const parsedCommand = ['my-mcp-server', '--start'] as ParseEntry[]; + mockConfig.getMcpServerCommand.mockReturnValue(commandString); + vi.mocked(parse).mockReturnValue(parsedCommand); + + const mockTool = { + name: 'tool1', + description: 'desc1', + inputSchema: { type: 'object' as const, properties: {} }, + }; + vi.mocked(Client.prototype.listTools).mockResolvedValue({ + tools: [mockTool], + }); + + await discoverMcpTools(mockConfig, mockToolRegistry); + + expect(parse).toHaveBeenCalledWith(commandString, process.env); + expect(StdioClientTransport).toHaveBeenCalledWith({ + command: parsedCommand[0], + args: parsedCommand.slice(1), + env: expect.any(Object), + cwd: undefined, + stderr: 'pipe', + }); + expect(Client.prototype.connect).toHaveBeenCalledTimes(1); + expect(Client.prototype.listTools).toHaveBeenCalledTimes(1); + expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); + expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( + expect.any(DiscoveredMCPTool), + ); + const registeredTool = mockToolRegistry.registerTool.mock + .calls[0][0] as DiscoveredMCPTool; + expect(registeredTool.name).toBe('tool1'); + expect(registeredTool.serverToolName).toBe('tool1'); + }); + + it('should discover tools via mcpServers config (stdio)', async () => { + const serverConfig: MCPServerConfig = { + command: './mcp-stdio', + args: ['arg1'], + }; + mockConfig.getMcpServers.mockReturnValue({ 'stdio-server': serverConfig }); + + const mockTool = { + name: 'tool-stdio', + description: 'desc-stdio', + inputSchema: { type: 'object' as const, properties: {} }, + }; + vi.mocked(Client.prototype.listTools).mockResolvedValue({ + tools: [mockTool], + }); + + await discoverMcpTools(mockConfig, mockToolRegistry); + + expect(StdioClientTransport).toHaveBeenCalledWith({ + command: serverConfig.command, + args: serverConfig.args, + env: expect.any(Object), + cwd: undefined, + stderr: 'pipe', + }); + expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( + expect.any(DiscoveredMCPTool), + ); + const registeredTool = mockToolRegistry.registerTool.mock + .calls[0][0] as DiscoveredMCPTool; + expect(registeredTool.name).toBe('tool-stdio'); + }); + + it('should discover tools via mcpServers config (sse)', async () => { + const serverConfig: MCPServerConfig = { url: 'http://localhost:1234/sse' }; + mockConfig.getMcpServers.mockReturnValue({ 'sse-server': serverConfig }); + + const mockTool = { + name: 'tool-sse', + description: 'desc-sse', + inputSchema: { type: 'object' as const, properties: {} }, + }; + vi.mocked(Client.prototype.listTools).mockResolvedValue({ + tools: [mockTool], + }); + + await discoverMcpTools(mockConfig, mockToolRegistry); + + expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!)); + expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( + expect.any(DiscoveredMCPTool), + ); + const registeredTool = mockToolRegistry.registerTool.mock + .calls[0][0] as DiscoveredMCPTool; + expect(registeredTool.name).toBe('tool-sse'); + }); + + it('should prefix tool names if multiple MCP servers are configured', async () => { + const serverConfig1: MCPServerConfig = { command: './mcp1' }; + const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' }; + mockConfig.getMcpServers.mockReturnValue({ + server1: serverConfig1, + server2: serverConfig2, + }); + + const mockTool1 = { + name: 'toolA', + description: 'd1', + inputSchema: { type: 'object' as const, properties: {} }, + }; + const mockTool2 = { + name: 'toolB', + description: 'd2', + inputSchema: { type: 'object' as const, properties: {} }, + }; + + vi.mocked(Client.prototype.listTools) + .mockResolvedValueOnce({ tools: [mockTool1] }) + .mockResolvedValueOnce({ tools: [mockTool2] }); + + await discoverMcpTools(mockConfig, mockToolRegistry); + + expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(2); + const registeredTool1 = mockToolRegistry.registerTool.mock + .calls[0][0] as DiscoveredMCPTool; + const registeredTool2 = mockToolRegistry.registerTool.mock + .calls[1][0] as DiscoveredMCPTool; + + expect(registeredTool1.name).toBe('server1__toolA'); + expect(registeredTool1.serverToolName).toBe('toolA'); + expect(registeredTool2.name).toBe('server2__toolB'); + expect(registeredTool2.serverToolName).toBe('toolB'); + }); + + it('should clean schema properties ($schema, additionalProperties)', async () => { + const serverConfig: MCPServerConfig = { command: './mcp-clean' }; + mockConfig.getMcpServers.mockReturnValue({ 'clean-server': serverConfig }); + + const rawSchema = { + type: 'object' as const, + $schema: 'http://json-schema.org/draft-07/schema#', + additionalProperties: true, + properties: { + prop1: { type: 'string', $schema: 'remove-this' }, + prop2: { + type: 'object' as const, + additionalProperties: false, + properties: { nested: { type: 'number' } }, + }, + }, + }; + const mockTool = { + name: 'cleanTool', + description: 'd', + inputSchema: JSON.parse(JSON.stringify(rawSchema)), + }; + vi.mocked(Client.prototype.listTools).mockResolvedValue({ + tools: [mockTool], + }); + + await discoverMcpTools(mockConfig, mockToolRegistry); + + expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1); + const registeredTool = mockToolRegistry.registerTool.mock + .calls[0][0] as DiscoveredMCPTool; + const cleanedParams = registeredTool.schema.parameters as any; + + expect(cleanedParams).not.toHaveProperty('$schema'); + expect(cleanedParams).not.toHaveProperty('additionalProperties'); + expect(cleanedParams.properties.prop1).not.toHaveProperty('$schema'); + expect(cleanedParams.properties.prop2).not.toHaveProperty( + 'additionalProperties', + ); + expect(cleanedParams.properties.prop2.properties.nested).not.toHaveProperty( + '$schema', + ); + expect(cleanedParams.properties.prop2.properties.nested).not.toHaveProperty( + 'additionalProperties', + ); + }); + + it('should handle error if mcpServerCommand parsing fails', async () => { + const commandString = 'my-mcp-server "unterminated quote'; + mockConfig.getMcpServerCommand.mockReturnValue(commandString); + vi.mocked(parse).mockImplementation(() => { + throw new Error('Parsing failed'); + }); + vi.spyOn(console, 'error').mockImplementation(() => {}); + + await expect( + discoverMcpTools(mockConfig, mockToolRegistry), + ).rejects.toThrow('Parsing failed'); + expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); + expect(console.error).not.toHaveBeenCalled(); + }); + + it('should log error and skip server if config is invalid (missing url and command)', async () => { + mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any }); + vi.spyOn(console, 'error').mockImplementation(() => {}); + + await discoverMcpTools(mockConfig, mockToolRegistry); + + expect(console.error).toHaveBeenCalledWith( + expect.stringContaining( + "MCP server 'bad-server' has invalid configuration", + ), + ); + // Client constructor should not be called if config is invalid before instantiation + expect(Client).not.toHaveBeenCalled(); + }); + + it('should log error and skip server if mcpClient.connect fails', async () => { + const serverConfig: MCPServerConfig = { command: './mcp-fail-connect' }; + mockConfig.getMcpServers.mockReturnValue({ + 'fail-connect-server': serverConfig, + }); + vi.mocked(Client.prototype.connect).mockRejectedValue( + new Error('Connection refused'), + ); + vi.spyOn(console, 'error').mockImplementation(() => {}); + + await discoverMcpTools(mockConfig, mockToolRegistry); + + expect(console.error).toHaveBeenCalledWith( + expect.stringContaining( + "failed to start or connect to MCP server 'fail-connect-server'", + ), + ); + expect(Client.prototype.listTools).not.toHaveBeenCalled(); + expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); + }); + + it('should log error and skip server if mcpClient.listTools fails', async () => { + const serverConfig: MCPServerConfig = { command: './mcp-fail-list' }; + mockConfig.getMcpServers.mockReturnValue({ + 'fail-list-server': serverConfig, + }); + vi.mocked(Client.prototype.listTools).mockRejectedValue( + new Error('ListTools error'), + ); + vi.spyOn(console, 'error').mockImplementation(() => {}); + + await discoverMcpTools(mockConfig, mockToolRegistry); + + expect(console.error).toHaveBeenCalledWith( + expect.stringContaining( + "Failed to list or register tools for MCP server 'fail-list-server'", + ), + ); + expect(mockToolRegistry.registerTool).not.toHaveBeenCalled(); + }); + + it('should assign mcpClient.onerror handler', async () => { + const serverConfig: MCPServerConfig = { command: './mcp-onerror' }; + mockConfig.getMcpServers.mockReturnValue({ + 'onerror-server': serverConfig, + }); + + await discoverMcpTools(mockConfig, mockToolRegistry); + + const clientInstances = vi.mocked(Client).mock.results; + expect(clientInstances.length).toBeGreaterThan(0); + const lastClientInstance = + clientInstances[clientInstances.length - 1]?.value; + expect(lastClientInstance?.onerror).toEqual(expect.any(Function)); + }); +}); diff --git a/packages/server/src/tools/mcp-client.ts b/packages/server/src/tools/mcp-client.ts new file mode 100644 index 00000000..8c2b4879 --- /dev/null +++ b/packages/server/src/tools/mcp-client.ts @@ -0,0 +1,138 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; +import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; +import { parse } from 'shell-quote'; +import { Config, MCPServerConfig } from '../config/config.js'; +import { DiscoveredMCPTool } from './mcp-tool.js'; +import { ToolRegistry } from './tool-registry.js'; + +export async function discoverMcpTools( + config: Config, + toolRegistry: ToolRegistry, +): Promise { + const mcpServers = config.getMcpServers() || {}; + + if (config.getMcpServerCommand()) { + const cmd = config.getMcpServerCommand()!; + const args = parse(cmd, process.env) as string[]; + if (args.some((arg) => typeof arg !== 'string')) { + throw new Error('failed to parse mcpServerCommand: ' + cmd); + } + // use generic server name 'mcp' + mcpServers['mcp'] = { + command: args[0], + args: args.slice(1), + }; + } + + const discoveryPromises = Object.entries(mcpServers).map( + ([mcpServerName, mcpServerConfig]) => + connectAndDiscover( + mcpServerName, + mcpServerConfig, + toolRegistry, + mcpServers, + ), + ); + await Promise.all(discoveryPromises); +} + +async function connectAndDiscover( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, + toolRegistry: ToolRegistry, + mcpServers: Record, +): Promise { + let transport; + if (mcpServerConfig.url) { + transport = new SSEClientTransport(new URL(mcpServerConfig.url)); + } else if (mcpServerConfig.command) { + transport = new StdioClientTransport({ + command: mcpServerConfig.command, + args: mcpServerConfig.args || [], + env: { + ...process.env, + ...(mcpServerConfig.env || {}), + } as Record, + cwd: mcpServerConfig.cwd, + stderr: 'pipe', + }); + } else { + console.error( + `MCP server '${mcpServerName}' has invalid configuration: missing both url (for SSE) and command (for stdio). Skipping.`, + ); + return; // Return a resolved promise as this path doesn't throw. + } + + const mcpClient = new Client({ + name: 'gemini-cli-mcp-client', + version: '0.0.1', + }); + + try { + await mcpClient.connect(transport); + } catch (error) { + console.error( + `failed to start or connect to MCP server '${mcpServerName}' ` + + `${JSON.stringify(mcpServerConfig)}; \n${error}`, + ); + return; // Return a resolved promise, let other MCP servers be discovered. + } + + mcpClient.onerror = (error) => { + console.error('MCP ERROR', error.toString()); + }; + + if (transport instanceof StdioClientTransport && transport.stderr) { + transport.stderr.on('data', (data) => { + if (!data.toString().includes('] INFO')) { + console.debug('MCP STDERR', data.toString()); + } + }); + } + + try { + const result = await mcpClient.listTools(); + for (const tool of result.tools) { + // Recursively remove additionalProperties and $schema from the inputSchema + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- This function recursively navigates a deeply nested and potentially heterogeneous JSON schema object. Using 'any' is a pragmatic choice here to avoid overly complex type definitions for all possible schema variations. + const removeSchemaProps = (obj: any) => { + if (typeof obj !== 'object' || obj === null) { + return; + } + if (Array.isArray(obj)) { + obj.forEach(removeSchemaProps); + } else { + delete obj.additionalProperties; + delete obj.$schema; + Object.values(obj).forEach(removeSchemaProps); + } + }; + removeSchemaProps(tool.inputSchema); + + toolRegistry.registerTool( + new DiscoveredMCPTool( + mcpClient, + Object.keys(mcpServers).length > 1 + ? mcpServerName + '__' + tool.name + : tool.name, + tool.description ?? '', + tool.inputSchema, + tool.name, + mcpServerConfig.timeout, + ), + ); + } + } catch (error) { + console.error( + `Failed to list or register tools for MCP server '${mcpServerName}': ${error}`, + ); + // Do not re-throw, allow other servers to proceed. + } +} diff --git a/packages/server/src/tools/mcp-tool.test.ts b/packages/server/src/tools/mcp-tool.test.ts new file mode 100644 index 00000000..e28cf586 --- /dev/null +++ b/packages/server/src/tools/mcp-tool.test.ts @@ -0,0 +1,161 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + Mocked, +} from 'vitest'; +import { + DiscoveredMCPTool, + MCP_TOOL_DEFAULT_TIMEOUT_MSEC, +} from './mcp-tool.js'; +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { ToolResult } from './tools.js'; + +// Mock MCP SDK Client +vi.mock('@modelcontextprotocol/sdk/client/index.js', () => { + const MockClient = vi.fn(); + MockClient.prototype.callTool = vi.fn(); + return { Client: MockClient }; +}); + +describe('DiscoveredMCPTool', () => { + let mockMcpClient: Mocked; + const toolName = 'test-mcp-tool'; + const serverToolName = 'actual-server-tool-name'; + const baseDescription = 'A test MCP tool.'; + const inputSchema = { + type: 'object' as const, + properties: { param: { type: 'string' } }, + }; + + beforeEach(() => { + // Create a new mock client for each test to reset call history + mockMcpClient = new (Client as any)({ + name: 'test-client', + version: '0.0.1', + }) as Mocked; + vi.mocked(mockMcpClient.callTool).mockClear(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('constructor', () => { + it('should set properties correctly and augment description', () => { + const tool = new DiscoveredMCPTool( + mockMcpClient, + toolName, + baseDescription, + inputSchema, + serverToolName, + ); + + expect(tool.name).toBe(toolName); + expect(tool.schema.name).toBe(toolName); + expect(tool.schema.description).toContain(baseDescription); + expect(tool.schema.description).toContain('This MCP tool was discovered'); + // Corrected assertion for backticks and template literal + expect(tool.schema.description).toContain( + `tools/call\` method for tool name \`${toolName}\``, + ); + expect(tool.schema.parameters).toEqual(inputSchema); + expect(tool.serverToolName).toBe(serverToolName); + expect(tool.timeout).toBeUndefined(); + }); + + it('should accept and store a custom timeout', () => { + const customTimeout = 5000; + const tool = new DiscoveredMCPTool( + mockMcpClient, + toolName, + baseDescription, + inputSchema, + serverToolName, + customTimeout, + ); + expect(tool.timeout).toBe(customTimeout); + }); + }); + + describe('execute', () => { + it('should call mcpClient.callTool with correct parameters and default timeout', async () => { + const tool = new DiscoveredMCPTool( + mockMcpClient, + toolName, + baseDescription, + inputSchema, + serverToolName, + ); + const params = { param: 'testValue' }; + const expectedMcpResult = { success: true, details: 'executed' }; + vi.mocked(mockMcpClient.callTool).mockResolvedValue(expectedMcpResult); + + const result: ToolResult = await tool.execute(params); + + expect(mockMcpClient.callTool).toHaveBeenCalledWith( + { + name: serverToolName, + arguments: params, + }, + undefined, + { + timeout: MCP_TOOL_DEFAULT_TIMEOUT_MSEC, + }, + ); + const expectedOutput = JSON.stringify(expectedMcpResult, null, 2); + expect(result.llmContent).toBe(expectedOutput); + expect(result.returnDisplay).toBe(expectedOutput); + }); + + it('should call mcpClient.callTool with custom timeout if provided', async () => { + const customTimeout = 15000; + const tool = new DiscoveredMCPTool( + mockMcpClient, + toolName, + baseDescription, + inputSchema, + serverToolName, + customTimeout, + ); + const params = { param: 'anotherValue' }; + const expectedMcpResult = { result: 'done' }; + vi.mocked(mockMcpClient.callTool).mockResolvedValue(expectedMcpResult); + + await tool.execute(params); + + expect(mockMcpClient.callTool).toHaveBeenCalledWith( + expect.anything(), + undefined, + { + timeout: customTimeout, + }, + ); + }); + + it('should propagate rejection if mcpClient.callTool rejects', async () => { + const tool = new DiscoveredMCPTool( + mockMcpClient, + toolName, + baseDescription, + inputSchema, + serverToolName, + ); + const params = { param: 'failCase' }; + const expectedError = new Error('MCP call failed'); + vi.mocked(mockMcpClient.callTool).mockRejectedValue(expectedError); + + await expect(tool.execute(params)).rejects.toThrow(expectedError); + }); + }); +}); diff --git a/packages/server/src/tools/mcp-tool.ts b/packages/server/src/tools/mcp-tool.ts new file mode 100644 index 00000000..05ad750c --- /dev/null +++ b/packages/server/src/tools/mcp-tool.ts @@ -0,0 +1,49 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import { BaseTool, ToolResult } from './tools.js'; + +type ToolParams = Record; + +export const MCP_TOOL_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes + +export class DiscoveredMCPTool extends BaseTool { + constructor( + private readonly mcpClient: Client, + readonly name: string, + readonly description: string, + readonly parameterSchema: Record, + readonly serverToolName: string, + readonly timeout?: number, + ) { + description += ` + +This MCP tool was discovered from a local MCP server using JSON RPC 2.0 over stdio transport protocol. +When called, this tool will invoke the \`tools/call\` method for tool name \`${name}\`. +MCP servers can be configured in project or user settings. +Returns the MCP server response as a json string. +`; + super(name, name, description, parameterSchema); + } + + async execute(params: ToolParams): Promise { + const result = await this.mcpClient.callTool( + { + name: this.serverToolName, + arguments: params, + }, + undefined, // skip resultSchema to specify options (RequestOptions) + { + timeout: this.timeout ?? MCP_TOOL_DEFAULT_TIMEOUT_MSEC, + }, + ); + return { + llmContent: JSON.stringify(result, null, 2), + returnDisplay: JSON.stringify(result, null, 2), + }; + } +} diff --git a/packages/server/src/tools/tool-registry.test.ts b/packages/server/src/tools/tool-registry.test.ts index 4c2bff38..bb41b35c 100644 --- a/packages/server/src/tools/tool-registry.test.ts +++ b/packages/server/src/tools/tool-registry.test.ts @@ -14,11 +14,8 @@ import { afterEach, Mocked, } from 'vitest'; -import { - ToolRegistry, - DiscoveredTool, - DiscoveredMCPTool, -} from './tool-registry.js'; +import { ToolRegistry, DiscoveredTool } from './tool-registry.js'; +import { DiscoveredMCPTool } from './mcp-tool.js'; import { Config } from '../config/config.js'; import { BaseTool, ToolResult } from './tools.js'; import { FunctionDeclaration } from '@google/genai'; @@ -347,7 +344,7 @@ describe('ToolRegistry', () => { toolRegistry = new ToolRegistry(config); }); - it('should discover tools using discovery command', () => { + it('should discover tools using discovery command', async () => { const discoveryCommand = 'my-discovery-command'; mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand); const mockToolDeclarations: FunctionDeclaration[] = [ @@ -366,7 +363,7 @@ describe('ToolRegistry', () => { ), ); - toolRegistry.discoverTools(); + await toolRegistry.discoverTools(); expect(execSync).toHaveBeenCalledWith(discoveryCommand); const discoveredTool = toolRegistry.getTool('discovered-tool-1'); @@ -376,7 +373,7 @@ describe('ToolRegistry', () => { expect(discoveredTool?.description).toContain(discoveryCommand); }); - it('should remove previously discovered tools before discovering new ones', () => { + it('should remove previously discovered tools before discovering new ones', async () => { const discoveryCommand = 'my-discovery-command'; mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand); mockExecSync.mockReturnValueOnce( @@ -394,7 +391,7 @@ describe('ToolRegistry', () => { ]), ), ); - toolRegistry.discoverTools(); + await toolRegistry.discoverTools(); expect(toolRegistry.getTool('old-discovered-tool')).toBeInstanceOf( DiscoveredTool, ); @@ -414,7 +411,7 @@ describe('ToolRegistry', () => { ]), ), ); - toolRegistry.discoverTools(); + await toolRegistry.discoverTools(); expect(toolRegistry.getTool('old-discovered-tool')).toBeUndefined(); expect(toolRegistry.getTool('new-discovered-tool')).toBeInstanceOf( DiscoveredTool, @@ -457,8 +454,7 @@ describe('ToolRegistry', () => { }); mockMcpClientInstance.connect.mockResolvedValue(undefined); - toolRegistry.discoverTools(); - await new Promise((resolve) => setTimeout(resolve, 100)); // Wait for async operations + await toolRegistry.discoverTools(); expect(Client).toHaveBeenCalledTimes(1); expect(StdioClientTransport).toHaveBeenCalledWith({ @@ -511,8 +507,7 @@ describe('ToolRegistry', () => { }); mockMcpClientInstance.connect.mockResolvedValue(undefined); - toolRegistry.discoverTools(); - await new Promise((resolve) => setTimeout(resolve, 100)); + await toolRegistry.discoverTools(); expect(Client).toHaveBeenCalledTimes(1); expect(StdioClientTransport).toHaveBeenCalledWith({ @@ -544,8 +539,7 @@ describe('ToolRegistry', () => { // Need to await the async IIFE within discoverTools. // Since discoverTools itself isn't async, we can't directly await it. // We'll check the console.error mock. - toolRegistry.discoverTools(); - await new Promise((resolve) => setTimeout(resolve, 100)); // Wait for async operations + await toolRegistry.discoverTools(); expect(console.error).toHaveBeenCalledWith( `failed to start or connect to MCP server 'failing-mcp' ${JSON.stringify({ command: 'fail-cmd' })}; \nError: Connection failed`, diff --git a/packages/server/src/tools/tool-registry.ts b/packages/server/src/tools/tool-registry.ts index 7b75e0f2..a2677e63 100644 --- a/packages/server/src/tools/tool-registry.ts +++ b/packages/server/src/tools/tool-registry.ts @@ -7,15 +7,11 @@ import { FunctionDeclaration } from '@google/genai'; import { Tool, ToolResult, BaseTool } from './tools.js'; import { Config } from '../config/config.js'; -import { parse } from 'shell-quote'; import { spawn, execSync } from 'node:child_process'; -// TODO: remove this dependency once MCP support is built into genai SDK -import { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; -import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; -type ToolParams = Record; +import { discoverMcpTools } from './mcp-client.js'; +import { DiscoveredMCPTool } from './mcp-tool.js'; -const MCP_TOOL_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes +type ToolParams = Record; export class DiscoveredTool extends BaseTool { constructor( @@ -95,43 +91,6 @@ Signal: Signal number or \`(none)\` if no signal was received. } } -export class DiscoveredMCPTool extends BaseTool { - constructor( - private readonly mcpClient: Client, - readonly name: string, - readonly description: string, - readonly parameterSchema: Record, - readonly serverToolName: string, - readonly timeout?: number, - ) { - description += ` - -This MCP tool was discovered from a local MCP server using JSON RPC 2.0 over stdio transport protocol. -When called, this tool will invoke the \`tools/call\` method for tool name \`${name}\`. -MCP servers can be configured in project or user settings. -Returns the MCP server response as a json string. -`; - super(name, name, description, parameterSchema); - } - - async execute(params: ToolParams): Promise { - const result = await this.mcpClient.callTool( - { - name: this.serverToolName, - arguments: params, - }, - undefined, // skip resultSchema to specify options (RequestOptions) - { - timeout: this.timeout ?? MCP_TOOL_DEFAULT_TIMEOUT_MSEC, - }, - ); - return { - llmContent: JSON.stringify(result, null, 2), - returnDisplay: JSON.stringify(result, null, 2), - }; - } -} - export class ToolRegistry { private tools: Map = new Map(); private config: Config; @@ -158,11 +117,13 @@ export class ToolRegistry { * Discovers tools from project, if a discovery command is configured. * Can be called multiple times to update discovered tools. */ - discoverTools(): void { + async discoverTools(): Promise { // remove any previously discovered tools for (const tool of this.tools.values()) { - if (tool instanceof DiscoveredTool) { + if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) { this.tools.delete(tool.name); + } else { + // Keep manually registered tools } } // discover tools using discovery command, if configured @@ -186,106 +147,7 @@ export class ToolRegistry { } } // discover tools using MCP servers, if configured - // convert mcpServerCommand (if any) to StdioServerParameters - const mcpServers = this.config.getMcpServers() || {}; - - if (this.config.getMcpServerCommand()) { - const cmd = this.config.getMcpServerCommand()!; - const args = parse(cmd, process.env) as string[]; - if (args.some((arg) => typeof arg !== 'string')) { - throw new Error('failed to parse mcpServerCommand: ' + cmd); - } - // use generic server name 'mcp' - mcpServers['mcp'] = { - command: args[0], - args: args.slice(1), - }; - } - for (const [mcpServerName, mcpServerConfig] of Object.entries(mcpServers)) { - (async () => { - const mcpClient = new Client({ - name: 'mcp-client', - version: '0.0.1', - }); - let transport; - if (mcpServerConfig.url) { - // SSE transport if URL is provided - transport = new SSEClientTransport(new URL(mcpServerConfig.url)); - } else if (mcpServerConfig.command) { - // Stdio transport if command is provided - transport = new StdioClientTransport({ - command: mcpServerConfig.command, - args: mcpServerConfig.args || [], - env: { - ...process.env, - ...(mcpServerConfig.env || {}), - } as Record, - cwd: mcpServerConfig.cwd, - stderr: 'pipe', - }); - } else { - console.error( - `MCP server '${mcpServerName}' has invalid configuration: missing both url (for SSE) and command (for stdio). Skipping.`, - ); - return; - } - try { - await mcpClient.connect(transport); - } catch (error) { - console.error( - `failed to start or connect to MCP server '${mcpServerName}' ` + - `${JSON.stringify(mcpServerConfig)}; \n${error}`, - ); - // Do not re-throw, let other MCP servers be discovered. - return; // Exit this async IIFE if connection failed - } - mcpClient.onerror = (error) => { - console.error('MCP ERROR', error.toString()); - }; - if (transport instanceof StdioClientTransport && !transport.stderr) { - throw new Error('transport missing stderr stream'); - } - if (transport instanceof StdioClientTransport) { - transport.stderr!.on('data', (data) => { - // filter out INFO messages logged for each request received - if (!data.toString().includes('] INFO')) { - console.debug('MCP STDERR', data.toString()); - } - }); - } - const result = await mcpClient.listTools(); - for (const tool of result.tools) { - // Recursively remove additionalProperties and $schema from the inputSchema - // eslint-disable-next-line @typescript-eslint/no-explicit-any -- This function recursively navigates a deeply nested and potentially heterogeneous JSON schema object. Using 'any' is a pragmatic choice here to avoid overly complex type definitions for all possible schema variations. - const removeSchemaProps = (obj: any) => { - if (typeof obj !== 'object' || obj === null) { - return; - } - if (Array.isArray(obj)) { - obj.forEach(removeSchemaProps); - } else { - delete obj.additionalProperties; - delete obj.$schema; - Object.values(obj).forEach(removeSchemaProps); - } - }; - removeSchemaProps(tool.inputSchema); - - this.registerTool( - new DiscoveredMCPTool( - mcpClient, - Object.keys(mcpServers).length > 1 - ? mcpServerName + '__' + tool.name - : tool.name, - tool.description ?? '', - tool.inputSchema, - tool.name, - mcpServerConfig.timeout, - ), - ); - } - })(); - } + await discoverMcpTools(this.config, this); } /**