From b24c5887c45edde8690b4d73d8961e63eee13a34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ram=C3=B3n=20Medrano=20Llamas?= <45878745+rmedranollamas@users.noreply.github.com> Date: Tue, 19 Aug 2025 21:03:19 +0200 Subject: [PATCH] feat: restart MCP servers on /mcp refresh (#5479) Co-authored-by: Brian Ray <62354532+emeryray2002@users.noreply.github.com> Co-authored-by: N. Taylor Mullen --- .../cli/src/ui/commands/mcpCommand.test.ts | 5 +- packages/cli/src/ui/commands/mcpCommand.ts | 6 +- .../src/ui/hooks/atCommandProcessor.test.ts | 6 + .../core/src/tools/mcp-client-manager.test.ts | 54 ++ packages/core/src/tools/mcp-client-manager.ts | 115 ++++ packages/core/src/tools/mcp-client.test.ts | 503 ++++-------------- packages/core/src/tools/mcp-client.ts | 130 ++++- packages/core/src/tools/tool-registry.test.ts | 52 +- packages/core/src/tools/tool-registry.ts | 43 +- 9 files changed, 447 insertions(+), 467 deletions(-) create mode 100644 packages/core/src/tools/mcp-client-manager.test.ts create mode 100644 packages/core/src/tools/mcp-client-manager.ts diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index 6e48c2f9..09b97bb0 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -972,6 +972,7 @@ describe('mcpCommand', () => { it('should refresh the list of tools and display the status', async () => { const mockToolRegistry = { discoverMcpTools: vi.fn(), + restartMcpServers: vi.fn(), getAllTools: vi.fn().mockReturnValue([]), }; const mockGeminiClient = { @@ -1004,11 +1005,11 @@ describe('mcpCommand', () => { expect(context.ui.addItem).toHaveBeenCalledWith( { type: 'info', - text: 'Refreshing MCP servers and tools...', + text: 'Restarting MCP servers...', }, expect.any(Number), ); - expect(mockToolRegistry.discoverMcpTools).toHaveBeenCalled(); + expect(mockToolRegistry.restartMcpServers).toHaveBeenCalled(); expect(mockGeminiClient.setTools).toHaveBeenCalled(); expect(context.ui.reloadCommands).toHaveBeenCalledTimes(1); diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index 686102be..9e321937 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -471,7 +471,7 @@ const listCommand: SlashCommand = { const refreshCommand: SlashCommand = { name: 'refresh', - description: 'Refresh the list of MCP servers and tools', + description: 'Restarts MCP servers.', kind: CommandKind.BUILT_IN, action: async ( context: CommandContext, @@ -497,12 +497,12 @@ const refreshCommand: SlashCommand = { context.ui.addItem( { type: 'info', - text: 'Refreshing MCP servers and tools...', + text: 'Restarting MCP servers...', }, Date.now(), ); - await toolRegistry.discoverMcpTools(); + await toolRegistry.restartMcpServers(); // Update the client with the new tools const geminiClient = config.getGeminiClient(); diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts index 5509d9ff..7403f788 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts @@ -63,6 +63,12 @@ describe('handleAtCommand', () => { isPathWithinWorkspace: () => true, getDirectories: () => [testRootDir], }), + getMcpServers: () => ({}), + getMcpServerCommand: () => undefined, + getPromptRegistry: () => ({ + getPromptsByServer: () => [], + }), + getDebugMode: () => false, } as unknown as Config; const registry = new ToolRegistry(mockConfig); diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts new file mode 100644 index 00000000..3dba197f --- /dev/null +++ b/packages/core/src/tools/mcp-client-manager.test.ts @@ -0,0 +1,54 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { afterEach, describe, expect, it, vi } from 'vitest'; +import { McpClientManager } from './mcp-client-manager.js'; +import { McpClient } from './mcp-client.js'; +import { ToolRegistry } from './tool-registry.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; +import { WorkspaceContext } from '../utils/workspaceContext.js'; + +vi.mock('./mcp-client.js', async () => { + const originalModule = await vi.importActual('./mcp-client.js'); + return { + ...originalModule, + McpClient: vi.fn(), + populateMcpServerCommand: vi.fn(() => ({ + 'test-server': {}, + })), + }; +}); + +describe('McpClientManager', () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should discover tools from all servers', async () => { + const mockedMcpClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + }; + vi.mocked(McpClient).mockReturnValue( + mockedMcpClient as unknown as McpClient, + ); + const manager = new McpClientManager( + { + 'test-server': {}, + }, + '', + {} as ToolRegistry, + {} as PromptRegistry, + false, + {} as WorkspaceContext, + ); + await manager.discoverAllMcpTools(); + expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + }); +}); diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts new file mode 100644 index 00000000..c22afb8f --- /dev/null +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -0,0 +1,115 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { MCPServerConfig } from '../config/config.js'; +import { ToolRegistry } from './tool-registry.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; +import { + McpClient, + MCPDiscoveryState, + populateMcpServerCommand, +} from './mcp-client.js'; +import { getErrorMessage } from '../utils/errors.js'; +import { WorkspaceContext } from '../utils/workspaceContext.js'; + +/** + * Manages the lifecycle of multiple MCP clients, including local child processes. + * This class is responsible for starting, stopping, and discovering tools from + * a collection of MCP servers defined in the configuration. + */ +export class McpClientManager { + private clients: Map = new Map(); + private readonly mcpServers: Record; + private readonly mcpServerCommand: string | undefined; + private readonly toolRegistry: ToolRegistry; + private readonly promptRegistry: PromptRegistry; + private readonly debugMode: boolean; + private readonly workspaceContext: WorkspaceContext; + private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED; + + constructor( + mcpServers: Record, + mcpServerCommand: string | undefined, + toolRegistry: ToolRegistry, + promptRegistry: PromptRegistry, + debugMode: boolean, + workspaceContext: WorkspaceContext, + ) { + this.mcpServers = mcpServers; + this.mcpServerCommand = mcpServerCommand; + this.toolRegistry = toolRegistry; + this.promptRegistry = promptRegistry; + this.debugMode = debugMode; + this.workspaceContext = workspaceContext; + } + + /** + * Initiates the tool discovery process for all configured MCP servers. + * It connects to each server, discovers its available tools, and registers + * them with the `ToolRegistry`. + */ + async discoverAllMcpTools(): Promise { + await this.stop(); + this.discoveryState = MCPDiscoveryState.IN_PROGRESS; + const servers = populateMcpServerCommand( + this.mcpServers, + this.mcpServerCommand, + ); + + const discoveryPromises = Object.entries(servers).map( + async ([name, config]) => { + const client = new McpClient( + name, + config, + this.toolRegistry, + this.promptRegistry, + this.workspaceContext, + this.debugMode, + ); + this.clients.set(name, client); + try { + await client.connect(); + await client.discover(); + } catch (error) { + // Log the error but don't let a single failed server stop the others + console.error( + `Error during discovery for server '${name}': ${getErrorMessage( + error, + )}`, + ); + } + }, + ); + + await Promise.all(discoveryPromises); + this.discoveryState = MCPDiscoveryState.COMPLETED; + } + + /** + * Stops all running local MCP servers and closes all client connections. + * This is the cleanup method to be called on application exit. + */ + async stop(): Promise { + const disconnectionPromises = Array.from(this.clients.entries()).map( + async ([name, client]) => { + try { + await client.disconnect(); + } catch (error) { + console.error( + `Error stopping client '${name}': ${getErrorMessage(error)}`, + ); + } + }, + ); + + await Promise.all(disconnectionPromises); + this.clients.clear(); + } + + getDiscoveryState(): MCPDiscoveryState { + return this.discoveryState; + } +} diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 3467ad95..b8f61856 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -4,16 +4,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest'; +import { afterEach, describe, expect, it, vi } from 'vitest'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { populateMcpServerCommand, createTransport, isEnabled, - discoverTools, - discoverPrompts, hasValidTypes, - connectToMcpServer, + McpClient, } from './mcp-client.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; @@ -22,26 +20,36 @@ import * as GenAiLib from '@google/genai'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { AuthProviderType } from '../config/config.js'; import { PromptRegistry } from '../prompts/prompt-registry.js'; - -import { DiscoveredMCPTool } from './mcp-tool.js'; +import { ToolRegistry } from './tool-registry.js'; import { WorkspaceContext } from '../utils/workspaceContext.js'; -import { pathToFileURL } from 'node:url'; vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js'); vi.mock('@google/genai'); vi.mock('../mcp/oauth-provider.js'); vi.mock('../mcp/oauth-token-storage.js'); -vi.mock('./mcp-tool.js'); describe('mcp-client', () => { afterEach(() => { vi.restoreAllMocks(); }); - describe('discoverTools', () => { + describe('McpClient', () => { it('should discover tools', async () => { - const mockedClient = {} as unknown as ClientLib.Client; + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ tool: () => ({ functionDeclarations: [ @@ -51,62 +59,43 @@ describe('mcp-client', () => { ], }), } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(1); + const mockedToolRegistry = { + registerTool: vi.fn(), + } as unknown as ToolRegistry; + const client = new McpClient( + 'test-server', + { + command: 'test-command', + }, + mockedToolRegistry, + {} as PromptRegistry, + {} as WorkspaceContext, + false, + ); + await client.connect(); + await client.discover(); expect(mockedMcpToTool).toHaveBeenCalledOnce(); }); - it('should log an error if there is an error discovering a tool', async () => { - const mockedClient = {} as unknown as ClientLib.Client; - const consoleErrorSpy = vi - .spyOn(console, 'error') - .mockImplementation(() => {}); - - const testError = new Error('Invalid tool name'); - vi.mocked(DiscoveredMCPTool).mockImplementation( - ( - _mcpCallableTool: GenAiLib.CallableTool, - _serverName: string, - name: string, - ) => { - if (name === 'invalid tool name') { - throw testError; - } - return { name: 'validTool' } as DiscoveredMCPTool; - }, - ); - - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'validTool', - }, - { - name: 'invalid tool name', // this will fail validation - }, - ], - }), - } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(1); - expect(tools[0].name).toBe('validTool'); - expect(consoleErrorSpy).toHaveBeenCalledOnce(); - expect(consoleErrorSpy).toHaveBeenCalledWith( - `Error discovering tool: 'invalid tool name' from MCP server 'test-server': ${testError.message}`, - ); - }); - it('should skip tools if a parameter is missing a type', async () => { - const mockedClient = {} as unknown as ClientLib.Client; const consoleWarnSpy = vi .spyOn(console, 'warn') .mockImplementation(() => {}); + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + tool: vi.fn(), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ tool: () => Promise.resolve({ @@ -132,352 +121,73 @@ describe('mcp-client', () => { ], }), } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(1); - expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool'); - expect(consoleWarnSpy).toHaveBeenCalledOnce(); - expect(consoleWarnSpy).toHaveBeenCalledWith( - `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` + - `missing types in its parameter schema. Please file an issue with the owner of the MCP server.`, - ); - consoleWarnSpy.mockRestore(); - }); - - it('should skip tools if a nested parameter is missing a type', async () => { - const mockedClient = {} as unknown as ClientLib.Client; - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'invalidTool', - parametersJsonSchema: { - type: 'object', - properties: { - param1: { - type: 'object', - properties: { - nestedParam: { - description: 'a nested param with no type', - }, - }, - }, - }, - }, - }, - ], - }), - } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(0); - expect(consoleWarnSpy).toHaveBeenCalledOnce(); - expect(consoleWarnSpy).toHaveBeenCalledWith( - `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` + - `missing types in its parameter schema. Please file an issue with the owner of the MCP server.`, - ); - consoleWarnSpy.mockRestore(); - }); - - it('should skip tool if an array item is missing a type', async () => { - const mockedClient = {} as unknown as ClientLib.Client; - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'invalidTool', - parametersJsonSchema: { - type: 'object', - properties: { - param1: { - type: 'array', - items: { - description: 'an array item with no type', - }, - }, - }, - }, - }, - ], - }), - } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(0); - expect(consoleWarnSpy).toHaveBeenCalledOnce(); - expect(consoleWarnSpy).toHaveBeenCalledWith( - `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` + - `missing types in its parameter schema. Please file an issue with the owner of the MCP server.`, - ); - consoleWarnSpy.mockRestore(); - }); - - it('should discover tool with no properties in schema', async () => { - const mockedClient = {} as unknown as ClientLib.Client; - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'validTool', - parametersJsonSchema: { - type: 'object', - }, - }, - ], - }), - } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(1); - expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool'); - expect(consoleWarnSpy).not.toHaveBeenCalled(); - consoleWarnSpy.mockRestore(); - }); - - it('should discover tool with empty properties object in schema', async () => { - const mockedClient = {} as unknown as ClientLib.Client; - const consoleWarnSpy = vi - .spyOn(console, 'warn') - .mockImplementation(() => {}); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'validTool', - parametersJsonSchema: { - type: 'object', - properties: {}, - }, - }, - ], - }), - } as unknown as GenAiLib.CallableTool); - - const tools = await discoverTools('test-server', {}, mockedClient); - - expect(tools.length).toBe(1); - expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool'); - expect(consoleWarnSpy).not.toHaveBeenCalled(); - consoleWarnSpy.mockRestore(); - }); - }); - - describe('connectToMcpServer', () => { - it('should send a notification when directories change', async () => { - const mockedClient = { - registerCapabilities: vi.fn(), - setRequestHandler: vi.fn(), - notification: vi.fn(), - callTool: vi.fn(), - connect: vi.fn(), - }; - vi.mocked(ClientLib.Client).mockReturnValue( - mockedClient as unknown as ClientLib.Client, - ); - vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( - {} as SdkClientStdioLib.StdioClientTransport, - ); - let onDirectoriesChangedCallback: () => void = () => {}; - const mockWorkspaceContext = { - getDirectories: vi - .fn() - .mockReturnValue(['/test/dir', '/another/project']), - onDirectoriesChanged: vi.fn().mockImplementation((callback) => { - onDirectoriesChangedCallback = callback; - }), - } as unknown as WorkspaceContext; - - await connectToMcpServer( + const mockedToolRegistry = { + registerTool: vi.fn(), + } as unknown as ToolRegistry; + const client = new McpClient( 'test-server', { command: 'test-command', }, + mockedToolRegistry, + {} as PromptRegistry, + {} as WorkspaceContext, false, - mockWorkspaceContext, ); - - onDirectoriesChangedCallback(); - - expect(mockedClient.notification).toHaveBeenCalledWith({ - method: 'notifications/roots/list_changed', - }); + await client.connect(); + await client.discover(); + expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); + expect(consoleWarnSpy).toHaveBeenCalledOnce(); + expect(consoleWarnSpy).toHaveBeenCalledWith( + `Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` + + `missing types in its parameter schema. Please file an issue with the owner of the MCP server.`, + ); + consoleWarnSpy.mockRestore(); }); - it('should register a roots/list handler', async () => { - const mockedClient = { - registerCapabilities: vi.fn(), - setRequestHandler: vi.fn(), - callTool: vi.fn(), - connect: vi.fn(), - }; - vi.mocked(ClientLib.Client).mockReturnValue( - mockedClient as unknown as ClientLib.Client, - ); - vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( - {} as SdkClientStdioLib.StdioClientTransport, - ); - const mockWorkspaceContext = { - getDirectories: vi - .fn() - .mockReturnValue(['/test/dir', '/another/project']), - onDirectoriesChanged: vi.fn(), - } as unknown as WorkspaceContext; - - await connectToMcpServer( - 'test-server', - { - command: 'test-command', - }, - false, - mockWorkspaceContext, - ); - - expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({ - roots: { - listChanged: true, - }, - }); - expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce(); - const handler = mockedClient.setRequestHandler.mock.calls[0][1]; - const roots = await handler(); - expect(roots).toEqual({ - roots: [ - { - uri: pathToFileURL('/test/dir').toString(), - name: 'dir', - }, - { - uri: pathToFileURL('/another/project').toString(), - name: 'project', - }, - ], - }); - }); - }); - - describe('discoverPrompts', () => { - const mockedPromptRegistry = { - registerPrompt: vi.fn(), - } as unknown as PromptRegistry; - - it('should discover and log prompts', async () => { - const mockRequest = vi.fn().mockResolvedValue({ - prompts: [ - { name: 'prompt1', description: 'desc1' }, - { name: 'prompt2' }, - ], - }); - const mockGetServerCapabilities = vi.fn().mockReturnValue({ - prompts: {}, - }); - const mockedClient = { - getServerCapabilities: mockGetServerCapabilities, - request: mockRequest, - } as unknown as ClientLib.Client; - - await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); - - expect(mockGetServerCapabilities).toHaveBeenCalledOnce(); - expect(mockRequest).toHaveBeenCalledWith( - { method: 'prompts/list', params: {} }, - expect.anything(), - ); - }); - - it('should do nothing if no prompts are discovered', async () => { - const mockRequest = vi.fn().mockResolvedValue({ - prompts: [], - }); - const mockGetServerCapabilities = vi.fn().mockReturnValue({ - prompts: {}, - }); - - const mockedClient = { - getServerCapabilities: mockGetServerCapabilities, - request: mockRequest, - } as unknown as ClientLib.Client; - - const consoleLogSpy = vi - .spyOn(console, 'debug') - .mockImplementation(() => {}); - - await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); - - expect(mockGetServerCapabilities).toHaveBeenCalledOnce(); - expect(mockRequest).toHaveBeenCalledOnce(); - expect(consoleLogSpy).not.toHaveBeenCalled(); - - consoleLogSpy.mockRestore(); - }); - - it('should do nothing if the server has no prompt support', async () => { - const mockRequest = vi.fn().mockResolvedValue({ - prompts: [], - }); - const mockGetServerCapabilities = vi.fn().mockReturnValue({}); - - const mockedClient = { - getServerCapabilities: mockGetServerCapabilities, - request: mockRequest, - } as unknown as ClientLib.Client; - - const consoleLogSpy = vi - .spyOn(console, 'debug') - .mockImplementation(() => {}); - - await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); - - expect(mockGetServerCapabilities).toHaveBeenCalledOnce(); - expect(mockRequest).not.toHaveBeenCalled(); - expect(consoleLogSpy).not.toHaveBeenCalled(); - - consoleLogSpy.mockRestore(); - }); - - it('should log an error if discovery fails', async () => { - const testError = new Error('test error'); - testError.message = 'test error'; - const mockRequest = vi.fn().mockRejectedValue(testError); - const mockGetServerCapabilities = vi.fn().mockReturnValue({ - prompts: {}, - }); - const mockedClient = { - getServerCapabilities: mockGetServerCapabilities, - request: mockRequest, - } as unknown as ClientLib.Client; - + it('should handle errors when discovering prompts', async () => { const consoleErrorSpy = vi .spyOn(console, 'error') .mockImplementation(() => {}); - - await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); - - expect(mockRequest).toHaveBeenCalledOnce(); - expect(consoleErrorSpy).toHaveBeenCalledWith( - `Error discovering prompts from test-server: ${testError.message}`, + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }), + request: vi.fn().mockRejectedValue(new Error('Test error')), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ + tool: () => Promise.resolve({ functionDeclarations: [] }), + } as unknown as GenAiLib.CallableTool); + const client = new McpClient( + 'test-server', + { + command: 'test-command', + }, + {} as ToolRegistry, + {} as PromptRegistry, + {} as WorkspaceContext, + false, + ); + await client.connect(); + await expect(client.discover()).rejects.toThrow( + 'No prompts or tools found on the server.', + ); + expect(consoleErrorSpy).toHaveBeenCalledWith( + `Error discovering prompts from test-server: Test error`, ); - consoleErrorSpy.mockRestore(); }); }); - describe('appendMcpServerCommand', () => { it('should do nothing if no MCP servers or command are configured', () => { const out = populateMcpServerCommand({}, undefined); @@ -501,17 +211,6 @@ describe('mcp-client', () => { }); describe('createTransport', () => { - const originalEnv = process.env; - - beforeEach(() => { - vi.resetModules(); - process.env = {}; - }); - - afterEach(() => { - process.env = originalEnv; - }); - describe('should connect via httpUrl', () => { it('without headers', async () => { const transport = await createTransport( @@ -601,7 +300,7 @@ describe('mcp-client', () => { command: 'test-command', args: ['--foo', 'bar'], cwd: 'test/cwd', - env: { FOO: 'bar' }, + env: { ...process.env, FOO: 'bar' }, stderr: 'pipe', }); }); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index e9001466..ede0d036 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -69,6 +69,134 @@ export enum MCPDiscoveryState { COMPLETED = 'completed', } +/** + * A client for a single MCP server. + * + * This class is responsible for connecting to, discovering tools from, and + * managing the state of a single MCP server. + */ +export class McpClient { + private client: Client; + private transport: Transport | undefined; + private status: MCPServerStatus = MCPServerStatus.DISCONNECTED; + private isDisconnecting = false; + + constructor( + private readonly serverName: string, + private readonly serverConfig: MCPServerConfig, + private readonly toolRegistry: ToolRegistry, + private readonly promptRegistry: PromptRegistry, + private readonly workspaceContext: WorkspaceContext, + private readonly debugMode: boolean, + ) { + this.client = new Client({ + name: `gemini-cli-mcp-client-${this.serverName}`, + version: '0.0.1', + }); + } + + /** + * Connects to the MCP server. + */ + async connect(): Promise { + this.isDisconnecting = false; + this.updateStatus(MCPServerStatus.CONNECTING); + try { + this.transport = await this.createTransport(); + + this.client.onerror = (error) => { + if (this.isDisconnecting) { + return; + } + console.error(`MCP ERROR (${this.serverName}):`, error.toString()); + this.updateStatus(MCPServerStatus.DISCONNECTED); + }; + + this.client.registerCapabilities({ + roots: {}, + }); + + this.client.setRequestHandler(ListRootsRequestSchema, async () => { + const roots = []; + for (const dir of this.workspaceContext.getDirectories()) { + roots.push({ + uri: pathToFileURL(dir).toString(), + name: basename(dir), + }); + } + return { + roots, + }; + }); + + await this.client.connect(this.transport, { + timeout: this.serverConfig.timeout, + }); + + this.updateStatus(MCPServerStatus.CONNECTED); + } catch (error) { + this.updateStatus(MCPServerStatus.DISCONNECTED); + throw error; + } + } + + /** + * Discovers tools and prompts from the MCP server. + */ + async discover(): Promise { + if (this.status !== MCPServerStatus.CONNECTED) { + throw new Error('Client is not connected.'); + } + + const prompts = await this.discoverPrompts(); + const tools = await this.discoverTools(); + + if (prompts.length === 0 && tools.length === 0) { + throw new Error('No prompts or tools found on the server.'); + } + + for (const tool of tools) { + this.toolRegistry.registerTool(tool); + } + } + + /** + * Disconnects from the MCP server. + */ + async disconnect(): Promise { + this.isDisconnecting = true; + if (this.transport) { + await this.transport.close(); + } + this.client.close(); + this.updateStatus(MCPServerStatus.DISCONNECTED); + } + + /** + * Returns the current status of the client. + */ + getStatus(): MCPServerStatus { + return this.status; + } + + private updateStatus(status: MCPServerStatus): void { + this.status = status; + updateMCPServerStatus(this.serverName, status); + } + + private async createTransport(): Promise { + return createTransport(this.serverName, this.serverConfig, this.debugMode); + } + + private async discoverTools(): Promise { + return discoverTools(this.serverName, this.serverConfig, this.client); + } + + private async discoverPrompts(): Promise { + return discoverPrompts(this.serverName, this.client, this.promptRegistry); + } +} + /** * Map to track the status of each MCP server within the core package */ @@ -117,7 +245,7 @@ export function removeMCPStatusChangeListener( /** * Update the status of an MCP server */ -function updateMCPServerStatus( +export function updateMCPServerStatus( serverName: string, status: MCPServerStatus, ): void { diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index 13dff08c..cccf011f 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -23,15 +23,17 @@ import { spawn } from 'node:child_process'; import fs from 'node:fs'; import { MockTool } from '../test-utils/tools.js'; +import { McpClientManager } from './mcp-client-manager.js'; + vi.mock('node:fs'); -// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory -const mockDiscoverMcpTools = vi.hoisted(() => vi.fn()); - // Mock ./mcp-client.js to control its behavior within tool-registry tests -vi.mock('./mcp-client.js', () => ({ - discoverMcpTools: mockDiscoverMcpTools, -})); +vi.mock('./mcp-client.js', async () => { + const originalModule = await vi.importActual('./mcp-client.js'); + return { + ...originalModule, + }; +}); // Mock node:child_process vi.mock('node:child_process', async () => { @@ -143,7 +145,6 @@ describe('ToolRegistry', () => { clear: vi.fn(), removePromptsByServer: vi.fn(), } as any); - mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined); }); afterEach(() => { @@ -311,6 +312,10 @@ describe('ToolRegistry', () => { }); it('should discover tools using MCP servers defined in getMcpServers', async () => { + const discoverSpy = vi.spyOn( + McpClientManager.prototype, + 'discoverAllMcpTools', + ); mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined); const mcpServerConfigVal = { @@ -324,38 +329,7 @@ describe('ToolRegistry', () => { await toolRegistry.discoverAllTools(); - expect(mockDiscoverMcpTools).toHaveBeenCalledWith( - mcpServerConfigVal, - undefined, - toolRegistry, - config.getPromptRegistry(), - false, - expect.any(Object), - ); - }); - - it('should discover tools using MCP servers defined in getMcpServers', async () => { - mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined); - vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined); - const mcpServerConfigVal = { - 'my-mcp-server': { - command: 'mcp-server-cmd', - args: ['--port', '1234'], - trust: true, - }, - }; - vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal); - - await toolRegistry.discoverAllTools(); - - expect(mockDiscoverMcpTools).toHaveBeenCalledWith( - mcpServerConfigVal, - undefined, - toolRegistry, - config.getPromptRegistry(), - false, - expect.any(Object), - ); + expect(discoverSpy).toHaveBeenCalled(); }); }); }); diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index ff155679..90531742 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -16,7 +16,8 @@ import { import { Config } from '../config/config.js'; import { spawn } from 'node:child_process'; import { StringDecoder } from 'node:string_decoder'; -import { discoverMcpTools } from './mcp-client.js'; +import { connectAndDiscover } from './mcp-client.js'; +import { McpClientManager } from './mcp-client-manager.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; import { parse } from 'shell-quote'; @@ -163,9 +164,18 @@ Signal: Signal number or \`(none)\` if no signal was received. export class ToolRegistry { private tools: Map = new Map(); private config: Config; + private mcpClientManager: McpClientManager; constructor(config: Config) { this.config = config; + this.mcpClientManager = new McpClientManager( + this.config.getMcpServers() ?? {}, + this.config.getMcpServerCommand(), + this, + this.config.getPromptRegistry(), + this.config.getDebugMode(), + this.config.getWorkspaceContext(), + ); } /** @@ -220,14 +230,7 @@ export class ToolRegistry { await this.discoverAndRegisterToolsFromCommand(); // discover tools using MCP servers, if configured - await discoverMcpTools( - this.config.getMcpServers() ?? {}, - this.config.getMcpServerCommand(), - this, - this.config.getPromptRegistry(), - this.config.getDebugMode(), - this.config.getWorkspaceContext(), - ); + await this.mcpClientManager.discoverAllMcpTools(); } /** @@ -242,14 +245,14 @@ export class ToolRegistry { this.config.getPromptRegistry().clear(); // discover tools using MCP servers, if configured - await discoverMcpTools( - this.config.getMcpServers() ?? {}, - this.config.getMcpServerCommand(), - this, - this.config.getPromptRegistry(), - this.config.getDebugMode(), - this.config.getWorkspaceContext(), - ); + await this.mcpClientManager.discoverAllMcpTools(); + } + + /** + * Restarts all MCP servers and re-discovers tools. + */ + async restartMcpServers(): Promise { + await this.discoverMcpTools(); } /** @@ -269,9 +272,9 @@ export class ToolRegistry { const mcpServers = this.config.getMcpServers() ?? {}; const serverConfig = mcpServers[serverName]; if (serverConfig) { - await discoverMcpTools( - { [serverName]: serverConfig }, - undefined, + await connectAndDiscover( + serverName, + serverConfig, this, this.config.getPromptRegistry(), this.config.getDebugMode(),