Add MCP Roots support (#5856)

Co-authored-by: Jacob Richman <jacob314@gmail.com>
This commit is contained in:
Jacob MacDonald 2025-08-08 16:29:06 -07:00 committed by GitHub
parent c03ae43777
commit f35921a771
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 87 additions and 1 deletions

View File

@ -13,6 +13,7 @@ import {
discoverTools, discoverTools,
discoverPrompts, discoverPrompts,
hasValidTypes, hasValidTypes,
connectToMcpServer,
} from './mcp-client.js'; } from './mcp-client.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@ -23,6 +24,8 @@ import { AuthProviderType } from '../config/config.js';
import { PromptRegistry } from '../prompts/prompt-registry.js'; import { PromptRegistry } from '../prompts/prompt-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js'; import { DiscoveredMCPTool } from './mcp-tool.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
import { pathToFileURL } from 'node:url';
vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
vi.mock('@modelcontextprotocol/sdk/client/index.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js');
@ -276,6 +279,56 @@ describe('mcp-client', () => {
}); });
}); });
describe('connectToMcpServer', () => {
it('should register a roots/list handler', async () => {
const mockedClient = {
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
callTool: vi.fn(),
connect: vi.fn(),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockWorkspaceContext = {
getDirectories: vi
.fn()
.mockReturnValue(['/test/dir', '/another/project']),
} as unknown as WorkspaceContext;
await connectToMcpServer(
'test-server',
{
command: 'test-command',
},
false,
mockWorkspaceContext,
);
expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
roots: {},
});
expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
const handler = mockedClient.setRequestHandler.mock.calls[0][1];
const roots = await handler();
expect(roots).toEqual({
roots: [
{
uri: pathToFileURL('/test/dir').toString(),
name: 'dir',
},
{
uri: pathToFileURL('/another/project').toString(),
name: 'project',
},
],
});
});
});
describe('discoverPrompts', () => { describe('discoverPrompts', () => {
const mockedPromptRegistry = { const mockedPromptRegistry = {
registerPrompt: vi.fn(), registerPrompt: vi.fn(),
@ -486,7 +539,9 @@ describe('mcp-client', () => {
}); });
it('should connect via command', async () => { it('should connect via command', async () => {
const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport); const mockedTransport = vi
.spyOn(SdkClientStdioLib, 'StdioClientTransport')
.mockReturnValue({} as SdkClientStdioLib.StdioClientTransport);
await createTransport( await createTransport(
'test-server', 'test-server',

View File

@ -20,6 +20,7 @@ import {
ListPromptsResultSchema, ListPromptsResultSchema,
GetPromptResult, GetPromptResult,
GetPromptResultSchema, GetPromptResultSchema,
ListRootsRequestSchema,
} from '@modelcontextprotocol/sdk/types.js'; } from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote'; import { parse } from 'shell-quote';
import { AuthProviderType, MCPServerConfig } from '../config/config.js'; import { AuthProviderType, MCPServerConfig } from '../config/config.js';
@ -33,6 +34,9 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { OAuthUtils } from '../mcp/oauth-utils.js'; import { OAuthUtils } from '../mcp/oauth-utils.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
import { getErrorMessage } from '../utils/errors.js'; import { getErrorMessage } from '../utils/errors.js';
import { basename } from 'node:path';
import { pathToFileURL } from 'node:url';
import { WorkspaceContext } from '../utils/workspaceContext.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
@ -306,6 +310,7 @@ export async function discoverMcpTools(
toolRegistry: ToolRegistry, toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry, promptRegistry: PromptRegistry,
debugMode: boolean, debugMode: boolean,
workspaceContext: WorkspaceContext,
): Promise<void> { ): Promise<void> {
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS; mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
try { try {
@ -319,6 +324,7 @@ export async function discoverMcpTools(
toolRegistry, toolRegistry,
promptRegistry, promptRegistry,
debugMode, debugMode,
workspaceContext,
), ),
); );
await Promise.all(discoveryPromises); await Promise.all(discoveryPromises);
@ -363,6 +369,7 @@ export async function connectAndDiscover(
toolRegistry: ToolRegistry, toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry, promptRegistry: PromptRegistry,
debugMode: boolean, debugMode: boolean,
workspaceContext: WorkspaceContext,
): Promise<void> { ): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@ -372,6 +379,7 @@ export async function connectAndDiscover(
mcpServerName, mcpServerName,
mcpServerConfig, mcpServerConfig,
debugMode, debugMode,
workspaceContext,
); );
mcpClient.onerror = (error) => { mcpClient.onerror = (error) => {
@ -655,12 +663,30 @@ export async function connectToMcpServer(
mcpServerName: string, mcpServerName: string,
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
debugMode: boolean, debugMode: boolean,
workspaceContext: WorkspaceContext,
): Promise<Client> { ): Promise<Client> {
const mcpClient = new Client({ const mcpClient = new Client({
name: 'gemini-cli-mcp-client', name: 'gemini-cli-mcp-client',
version: '0.0.1', version: '0.0.1',
}); });
mcpClient.registerCapabilities({
roots: {},
});
mcpClient.setRequestHandler(ListRootsRequestSchema, async () => {
const roots = [];
for (const dir of workspaceContext.getDirectories()) {
roots.push({
uri: pathToFileURL(dir).toString(),
name: basename(dir),
});
}
return {
roots,
};
});
// patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it // patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it
// TODO: remove this hack once GenAI SDK does callTool with request options // TODO: remove this hack once GenAI SDK does callTool with request options
if ('callTool' in mcpClient) { if ('callTool' in mcpClient) {

View File

@ -336,6 +336,7 @@ describe('ToolRegistry', () => {
toolRegistry, toolRegistry,
config.getPromptRegistry(), config.getPromptRegistry(),
false, false,
expect.any(Object),
); );
}); });
@ -359,6 +360,7 @@ describe('ToolRegistry', () => {
toolRegistry, toolRegistry,
config.getPromptRegistry(), config.getPromptRegistry(),
false, false,
expect.any(Object),
); );
}); });
}); });

View File

@ -178,6 +178,7 @@ export class ToolRegistry {
this, this,
this.config.getPromptRegistry(), this.config.getPromptRegistry(),
this.config.getDebugMode(), this.config.getDebugMode(),
this.config.getWorkspaceContext(),
); );
} }
@ -199,6 +200,7 @@ export class ToolRegistry {
this, this,
this.config.getPromptRegistry(), this.config.getPromptRegistry(),
this.config.getDebugMode(), this.config.getDebugMode(),
this.config.getWorkspaceContext(),
); );
} }
@ -225,6 +227,7 @@ export class ToolRegistry {
this, this,
this.config.getPromptRegistry(), this.config.getPromptRegistry(),
this.config.getDebugMode(), this.config.getDebugMode(),
this.config.getWorkspaceContext(),
); );
} }
} }