From f35921a77171d011d244cba1b2da0531f9749332 Mon Sep 17 00:00:00 2001 From: Jacob MacDonald Date: Fri, 8 Aug 2025 16:29:06 -0700 Subject: [PATCH] Add MCP Roots support (#5856) Co-authored-by: Jacob Richman --- packages/core/src/tools/mcp-client.test.ts | 57 ++++++++++++++++++- packages/core/src/tools/mcp-client.ts | 26 +++++++++ packages/core/src/tools/tool-registry.test.ts | 2 + packages/core/src/tools/tool-registry.ts | 3 + 4 files changed, 87 insertions(+), 1 deletion(-) diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 1ccba76a..d37c6eae 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -13,6 +13,7 @@ import { discoverTools, discoverPrompts, hasValidTypes, + connectToMcpServer, } from './mcp-client.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; @@ -23,6 +24,8 @@ import { AuthProviderType } from '../config/config.js'; import { PromptRegistry } from '../prompts/prompt-registry.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; +import { WorkspaceContext } from '../utils/workspaceContext.js'; +import { pathToFileURL } from 'node:url'; vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js'); @@ -276,6 +279,56 @@ describe('mcp-client', () => { }); }); + describe('connectToMcpServer', () => { + it('should register a roots/list handler', async () => { + const mockedClient = { + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + callTool: vi.fn(), + connect: vi.fn(), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + const mockWorkspaceContext = { + getDirectories: vi + .fn() + .mockReturnValue(['/test/dir', '/another/project']), + } as unknown as WorkspaceContext; + + await connectToMcpServer( + 'test-server', + { + command: 'test-command', + }, + false, + mockWorkspaceContext, + ); + + expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({ + roots: {}, + }); + expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce(); + const handler = mockedClient.setRequestHandler.mock.calls[0][1]; + const roots = await handler(); + expect(roots).toEqual({ + roots: [ + { + uri: pathToFileURL('/test/dir').toString(), + name: 'dir', + }, + { + uri: pathToFileURL('/another/project').toString(), + name: 'project', + }, + ], + }); + }); + }); + describe('discoverPrompts', () => { const mockedPromptRegistry = { registerPrompt: vi.fn(), @@ -486,7 +539,9 @@ describe('mcp-client', () => { }); it('should connect via command', async () => { - const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport); + const mockedTransport = vi + .spyOn(SdkClientStdioLib, 'StdioClientTransport') + .mockReturnValue({} as SdkClientStdioLib.StdioClientTransport); await createTransport( 'test-server', diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 9a35b84e..83bc4024 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -20,6 +20,7 @@ import { ListPromptsResultSchema, GetPromptResult, GetPromptResultSchema, + ListRootsRequestSchema, } from '@modelcontextprotocol/sdk/types.js'; import { parse } from 'shell-quote'; import { AuthProviderType, MCPServerConfig } from '../config/config.js'; @@ -33,6 +34,9 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; import { getErrorMessage } from '../utils/errors.js'; +import { basename } from 'node:path'; +import { pathToFileURL } from 'node:url'; +import { WorkspaceContext } from '../utils/workspaceContext.js'; export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes @@ -306,6 +310,7 @@ export async function discoverMcpTools( toolRegistry: ToolRegistry, promptRegistry: PromptRegistry, debugMode: boolean, + workspaceContext: WorkspaceContext, ): Promise { mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS; try { @@ -319,6 +324,7 @@ export async function discoverMcpTools( toolRegistry, promptRegistry, debugMode, + workspaceContext, ), ); await Promise.all(discoveryPromises); @@ -363,6 +369,7 @@ export async function connectAndDiscover( toolRegistry: ToolRegistry, promptRegistry: PromptRegistry, debugMode: boolean, + workspaceContext: WorkspaceContext, ): Promise { updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); @@ -372,6 +379,7 @@ export async function connectAndDiscover( mcpServerName, mcpServerConfig, debugMode, + workspaceContext, ); mcpClient.onerror = (error) => { @@ -655,12 +663,30 @@ export async function connectToMcpServer( mcpServerName: string, mcpServerConfig: MCPServerConfig, debugMode: boolean, + workspaceContext: WorkspaceContext, ): Promise { const mcpClient = new Client({ name: 'gemini-cli-mcp-client', version: '0.0.1', }); + mcpClient.registerCapabilities({ + roots: {}, + }); + + mcpClient.setRequestHandler(ListRootsRequestSchema, async () => { + const roots = []; + for (const dir of workspaceContext.getDirectories()) { + roots.push({ + uri: pathToFileURL(dir).toString(), + name: basename(dir), + }); + } + return { + roots, + }; + }); + // patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it // TODO: remove this hack once GenAI SDK does callTool with request options if ('callTool' in mcpClient) { diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index e7c71e14..d8e536b7 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -336,6 +336,7 @@ describe('ToolRegistry', () => { toolRegistry, config.getPromptRegistry(), false, + expect.any(Object), ); }); @@ -359,6 +360,7 @@ describe('ToolRegistry', () => { toolRegistry, config.getPromptRegistry(), false, + expect.any(Object), ); }); }); diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index c77fab8c..70226052 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -178,6 +178,7 @@ export class ToolRegistry { this, this.config.getPromptRegistry(), this.config.getDebugMode(), + this.config.getWorkspaceContext(), ); } @@ -199,6 +200,7 @@ export class ToolRegistry { this, this.config.getPromptRegistry(), this.config.getDebugMode(), + this.config.getWorkspaceContext(), ); } @@ -225,6 +227,7 @@ export class ToolRegistry { this, this.config.getPromptRegistry(), this.config.getDebugMode(), + this.config.getWorkspaceContext(), ); } }