Add MCP Roots support (#5856)
Co-authored-by: Jacob Richman <jacob314@gmail.com>
This commit is contained in:
parent
c03ae43777
commit
f35921a771
|
@ -13,6 +13,7 @@ import {
|
||||||
discoverTools,
|
discoverTools,
|
||||||
discoverPrompts,
|
discoverPrompts,
|
||||||
hasValidTypes,
|
hasValidTypes,
|
||||||
|
connectToMcpServer,
|
||||||
} from './mcp-client.js';
|
} from './mcp-client.js';
|
||||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||||
|
@ -23,6 +24,8 @@ import { AuthProviderType } from '../config/config.js';
|
||||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||||
|
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||||
|
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||||
|
import { pathToFileURL } from 'node:url';
|
||||||
|
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
||||||
|
@ -276,6 +279,56 @@ describe('mcp-client', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('connectToMcpServer', () => {
|
||||||
|
it('should register a roots/list handler', async () => {
|
||||||
|
const mockedClient = {
|
||||||
|
registerCapabilities: vi.fn(),
|
||||||
|
setRequestHandler: vi.fn(),
|
||||||
|
callTool: vi.fn(),
|
||||||
|
connect: vi.fn(),
|
||||||
|
};
|
||||||
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
|
mockedClient as unknown as ClientLib.Client,
|
||||||
|
);
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
|
const mockWorkspaceContext = {
|
||||||
|
getDirectories: vi
|
||||||
|
.fn()
|
||||||
|
.mockReturnValue(['/test/dir', '/another/project']),
|
||||||
|
} as unknown as WorkspaceContext;
|
||||||
|
|
||||||
|
await connectToMcpServer(
|
||||||
|
'test-server',
|
||||||
|
{
|
||||||
|
command: 'test-command',
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
mockWorkspaceContext,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
|
||||||
|
roots: {},
|
||||||
|
});
|
||||||
|
expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
|
||||||
|
const handler = mockedClient.setRequestHandler.mock.calls[0][1];
|
||||||
|
const roots = await handler();
|
||||||
|
expect(roots).toEqual({
|
||||||
|
roots: [
|
||||||
|
{
|
||||||
|
uri: pathToFileURL('/test/dir').toString(),
|
||||||
|
name: 'dir',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
uri: pathToFileURL('/another/project').toString(),
|
||||||
|
name: 'project',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('discoverPrompts', () => {
|
describe('discoverPrompts', () => {
|
||||||
const mockedPromptRegistry = {
|
const mockedPromptRegistry = {
|
||||||
registerPrompt: vi.fn(),
|
registerPrompt: vi.fn(),
|
||||||
|
@ -486,7 +539,9 @@ describe('mcp-client', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should connect via command', async () => {
|
it('should connect via command', async () => {
|
||||||
const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport);
|
const mockedTransport = vi
|
||||||
|
.spyOn(SdkClientStdioLib, 'StdioClientTransport')
|
||||||
|
.mockReturnValue({} as SdkClientStdioLib.StdioClientTransport);
|
||||||
|
|
||||||
await createTransport(
|
await createTransport(
|
||||||
'test-server',
|
'test-server',
|
||||||
|
|
|
@ -20,6 +20,7 @@ import {
|
||||||
ListPromptsResultSchema,
|
ListPromptsResultSchema,
|
||||||
GetPromptResult,
|
GetPromptResult,
|
||||||
GetPromptResultSchema,
|
GetPromptResultSchema,
|
||||||
|
ListRootsRequestSchema,
|
||||||
} from '@modelcontextprotocol/sdk/types.js';
|
} from '@modelcontextprotocol/sdk/types.js';
|
||||||
import { parse } from 'shell-quote';
|
import { parse } from 'shell-quote';
|
||||||
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
|
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
|
||||||
|
@ -33,6 +34,9 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||||
import { getErrorMessage } from '../utils/errors.js';
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
|
import { basename } from 'node:path';
|
||||||
|
import { pathToFileURL } from 'node:url';
|
||||||
|
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||||
|
|
||||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||||
|
|
||||||
|
@ -306,6 +310,7 @@ export async function discoverMcpTools(
|
||||||
toolRegistry: ToolRegistry,
|
toolRegistry: ToolRegistry,
|
||||||
promptRegistry: PromptRegistry,
|
promptRegistry: PromptRegistry,
|
||||||
debugMode: boolean,
|
debugMode: boolean,
|
||||||
|
workspaceContext: WorkspaceContext,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||||
try {
|
try {
|
||||||
|
@ -319,6 +324,7 @@ export async function discoverMcpTools(
|
||||||
toolRegistry,
|
toolRegistry,
|
||||||
promptRegistry,
|
promptRegistry,
|
||||||
debugMode,
|
debugMode,
|
||||||
|
workspaceContext,
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
await Promise.all(discoveryPromises);
|
await Promise.all(discoveryPromises);
|
||||||
|
@ -363,6 +369,7 @@ export async function connectAndDiscover(
|
||||||
toolRegistry: ToolRegistry,
|
toolRegistry: ToolRegistry,
|
||||||
promptRegistry: PromptRegistry,
|
promptRegistry: PromptRegistry,
|
||||||
debugMode: boolean,
|
debugMode: boolean,
|
||||||
|
workspaceContext: WorkspaceContext,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||||
|
|
||||||
|
@ -372,6 +379,7 @@ export async function connectAndDiscover(
|
||||||
mcpServerName,
|
mcpServerName,
|
||||||
mcpServerConfig,
|
mcpServerConfig,
|
||||||
debugMode,
|
debugMode,
|
||||||
|
workspaceContext,
|
||||||
);
|
);
|
||||||
|
|
||||||
mcpClient.onerror = (error) => {
|
mcpClient.onerror = (error) => {
|
||||||
|
@ -655,12 +663,30 @@ export async function connectToMcpServer(
|
||||||
mcpServerName: string,
|
mcpServerName: string,
|
||||||
mcpServerConfig: MCPServerConfig,
|
mcpServerConfig: MCPServerConfig,
|
||||||
debugMode: boolean,
|
debugMode: boolean,
|
||||||
|
workspaceContext: WorkspaceContext,
|
||||||
): Promise<Client> {
|
): Promise<Client> {
|
||||||
const mcpClient = new Client({
|
const mcpClient = new Client({
|
||||||
name: 'gemini-cli-mcp-client',
|
name: 'gemini-cli-mcp-client',
|
||||||
version: '0.0.1',
|
version: '0.0.1',
|
||||||
});
|
});
|
||||||
|
|
||||||
|
mcpClient.registerCapabilities({
|
||||||
|
roots: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
mcpClient.setRequestHandler(ListRootsRequestSchema, async () => {
|
||||||
|
const roots = [];
|
||||||
|
for (const dir of workspaceContext.getDirectories()) {
|
||||||
|
roots.push({
|
||||||
|
uri: pathToFileURL(dir).toString(),
|
||||||
|
name: basename(dir),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
roots,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
// patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it
|
// patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it
|
||||||
// TODO: remove this hack once GenAI SDK does callTool with request options
|
// TODO: remove this hack once GenAI SDK does callTool with request options
|
||||||
if ('callTool' in mcpClient) {
|
if ('callTool' in mcpClient) {
|
||||||
|
|
|
@ -336,6 +336,7 @@ describe('ToolRegistry', () => {
|
||||||
toolRegistry,
|
toolRegistry,
|
||||||
config.getPromptRegistry(),
|
config.getPromptRegistry(),
|
||||||
false,
|
false,
|
||||||
|
expect.any(Object),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -359,6 +360,7 @@ describe('ToolRegistry', () => {
|
||||||
toolRegistry,
|
toolRegistry,
|
||||||
config.getPromptRegistry(),
|
config.getPromptRegistry(),
|
||||||
false,
|
false,
|
||||||
|
expect.any(Object),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -178,6 +178,7 @@ export class ToolRegistry {
|
||||||
this,
|
this,
|
||||||
this.config.getPromptRegistry(),
|
this.config.getPromptRegistry(),
|
||||||
this.config.getDebugMode(),
|
this.config.getDebugMode(),
|
||||||
|
this.config.getWorkspaceContext(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -199,6 +200,7 @@ export class ToolRegistry {
|
||||||
this,
|
this,
|
||||||
this.config.getPromptRegistry(),
|
this.config.getPromptRegistry(),
|
||||||
this.config.getDebugMode(),
|
this.config.getDebugMode(),
|
||||||
|
this.config.getWorkspaceContext(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,6 +227,7 @@ export class ToolRegistry {
|
||||||
this,
|
this,
|
||||||
this.config.getPromptRegistry(),
|
this.config.getPromptRegistry(),
|
||||||
this.config.getDebugMode(),
|
this.config.getDebugMode(),
|
||||||
|
this.config.getWorkspaceContext(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue