From 2e57989aec569055a11f21762f72b961377281ab Mon Sep 17 00:00:00 2001 From: Olcan Date: Fri, 30 May 2025 15:32:21 -0700 Subject: [PATCH] confirm mcp tool executions from untrusted servers (per "trust" setting) (#631) --- .../messages/ToolConfirmationMessage.tsx | 30 +++++++++++- packages/server/src/config/config.ts | 1 + packages/server/src/tools/mcp-client.ts | 2 + packages/server/src/tools/mcp-tool.test.ts | 5 ++ packages/server/src/tools/mcp-tool.ts | 47 ++++++++++++++++++- .../server/src/tools/tool-registry.test.ts | 2 + packages/server/src/tools/tools.ts | 20 ++++++-- 7 files changed, 101 insertions(+), 6 deletions(-) diff --git a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx index 0606856f..65030309 100644 --- a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx @@ -12,6 +12,7 @@ import { ToolCallConfirmationDetails, ToolConfirmationOutcome, ToolExecuteConfirmationDetails, + ToolMcpConfirmationDetails, } from '@gemini-code/server'; import { RadioButtonSelect, @@ -64,7 +65,7 @@ export const ToolConfirmationMessage: React.FC< }, { label: 'No (esc)', value: ToolConfirmationOutcome.Cancel }, ); - } else { + } else if (confirmationDetails.type === 'exec') { const executionProps = confirmationDetails as ToolExecuteConfirmationDetails; @@ -88,6 +89,33 @@ export const ToolConfirmationMessage: React.FC< }, { label: 'No (esc)', value: ToolConfirmationOutcome.Cancel }, ); + } else { + // mcp tool confirmation + const mcpProps = confirmationDetails as ToolMcpConfirmationDetails; + + bodyContent = ( + + MCP Server: {mcpProps.serverName} + Tool: {mcpProps.toolName} + + ); + + question = `Allow execution of MCP tool "${mcpProps.toolName}" from server "${mcpProps.serverName}"?`; + options.push( + { + label: 'Yes, allow once', + value: ToolConfirmationOutcome.ProceedOnce, + }, + { + label: `Yes, always allow tool "${mcpProps.toolName}" from server "${mcpProps.serverName}"`, + value: ToolConfirmationOutcome.ProceedAlwaysTool, // Cast until types are updated + }, + { + label: `Yes, always allow all tools from server "${mcpProps.serverName}"`, + value: ToolConfirmationOutcome.ProceedAlwaysServer, + }, + { label: 'No (esc)', value: ToolConfirmationOutcome.Cancel }, + ); } return ( diff --git a/packages/server/src/config/config.ts b/packages/server/src/config/config.ts index 9c03a5c1..0cd7a4fa 100644 --- a/packages/server/src/config/config.ts +++ b/packages/server/src/config/config.ts @@ -33,6 +33,7 @@ export class MCPServerConfig { readonly url?: string, // Common readonly timeout?: number, + readonly trust?: boolean, ) {} } diff --git a/packages/server/src/tools/mcp-client.ts b/packages/server/src/tools/mcp-client.ts index 3b55f5e3..97a73289 100644 --- a/packages/server/src/tools/mcp-client.ts +++ b/packages/server/src/tools/mcp-client.ts @@ -134,11 +134,13 @@ async function connectAndDiscover( toolRegistry.registerTool( new DiscoveredMCPTool( mcpClient, + mcpServerName, toolNameForModel, tool.description ?? '', tool.inputSchema, tool.name, mcpServerConfig.timeout, + mcpServerConfig.trust, ), ); } diff --git a/packages/server/src/tools/mcp-tool.test.ts b/packages/server/src/tools/mcp-tool.test.ts index e28cf586..331696f7 100644 --- a/packages/server/src/tools/mcp-tool.test.ts +++ b/packages/server/src/tools/mcp-tool.test.ts @@ -55,6 +55,7 @@ describe('DiscoveredMCPTool', () => { it('should set properties correctly and augment description', () => { const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, baseDescription, inputSchema, @@ -78,6 +79,7 @@ describe('DiscoveredMCPTool', () => { const customTimeout = 5000; const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, baseDescription, inputSchema, @@ -92,6 +94,7 @@ describe('DiscoveredMCPTool', () => { it('should call mcpClient.callTool with correct parameters and default timeout', async () => { const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, baseDescription, inputSchema, @@ -122,6 +125,7 @@ describe('DiscoveredMCPTool', () => { const customTimeout = 15000; const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, baseDescription, inputSchema, @@ -146,6 +150,7 @@ describe('DiscoveredMCPTool', () => { it('should propagate rejection if mcpClient.callTool rejects', async () => { const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, baseDescription, inputSchema, diff --git a/packages/server/src/tools/mcp-tool.ts b/packages/server/src/tools/mcp-tool.ts index 2a561179..80e6bbde 100644 --- a/packages/server/src/tools/mcp-tool.ts +++ b/packages/server/src/tools/mcp-tool.ts @@ -5,20 +5,30 @@ */ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { BaseTool, ToolResult } from './tools.js'; +import { + BaseTool, + ToolResult, + ToolCallConfirmationDetails, + ToolConfirmationOutcome, + ToolMcpConfirmationDetails, +} from './tools.js'; type ToolParams = Record; export const MCP_TOOL_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes export class DiscoveredMCPTool extends BaseTool { + private static readonly whitelist: Set = new Set(); + constructor( private readonly mcpClient: Client, + private readonly serverName: string, // Added for server identification readonly name: string, readonly description: string, readonly parameterSchema: Record, readonly serverToolName: string, readonly timeout?: number, + readonly trust?: boolean, ) { description += ` @@ -37,6 +47,41 @@ Returns the MCP server response as a json string. ); } + async shouldConfirmExecute( + _params: ToolParams, + _abortSignal: AbortSignal, + ): Promise { + const serverWhitelistKey = this.serverName; + const toolWhitelistKey = `${this.serverName}.${this.serverToolName}`; + + if (this.trust) { + return false; // server is trusted, no confirmation needed + } + + if ( + DiscoveredMCPTool.whitelist.has(serverWhitelistKey) || + DiscoveredMCPTool.whitelist.has(toolWhitelistKey) + ) { + return false; // server and/or tool already whitelisted + } + + const confirmationDetails: ToolMcpConfirmationDetails = { + type: 'mcp', + title: 'Confirm MCP Tool Execution', + serverName: this.serverName, + toolName: this.serverToolName, + toolDisplayName: this.name, + onConfirm: async (outcome: ToolConfirmationOutcome) => { + if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) { + DiscoveredMCPTool.whitelist.add(serverWhitelistKey); + } else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) { + DiscoveredMCPTool.whitelist.add(toolWhitelistKey); + } + }, + }; + return confirmationDetails; + } + async execute(params: ToolParams): Promise { const result = await this.mcpClient.callTool( { diff --git a/packages/server/src/tools/tool-registry.test.ts b/packages/server/src/tools/tool-registry.test.ts index 6a960a27..c93109ae 100644 --- a/packages/server/src/tools/tool-registry.test.ts +++ b/packages/server/src/tools/tool-registry.test.ts @@ -729,6 +729,7 @@ describe('DiscoveredMCPTool', () => { it('constructor should set up properties correctly and enhance description', () => { const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, toolDescription, toolInputSchema, @@ -744,6 +745,7 @@ describe('DiscoveredMCPTool', () => { it('execute should call mcpClient.callTool with correct params and return serialized result', async () => { const tool = new DiscoveredMCPTool( mockMcpClient, + 'mock-mcp-server', toolName, toolDescription, toolInputSchema, diff --git a/packages/server/src/tools/tools.ts b/packages/server/src/tools/tools.ts index e5d0c7cf..a2e7fa06 100644 --- a/packages/server/src/tools/tools.ts +++ b/packages/server/src/tools/tools.ts @@ -212,12 +212,24 @@ export interface ToolExecuteConfirmationDetails { rootCommand: string; } +export interface ToolMcpConfirmationDetails { + type: 'mcp'; + title: string; + serverName: string; + toolName: string; + toolDisplayName: string; + onConfirm: (outcome: ToolConfirmationOutcome) => Promise | void; +} + export type ToolCallConfirmationDetails = | ToolEditConfirmationDetails - | ToolExecuteConfirmationDetails; + | ToolExecuteConfirmationDetails + | ToolMcpConfirmationDetails; export enum ToolConfirmationOutcome { - ProceedOnce, - ProceedAlways, - Cancel, + ProceedOnce = 'proceed_once', + ProceedAlways = 'proceed_always', + ProceedAlwaysServer = 'proceed_always_server', + ProceedAlwaysTool = 'proceed_always_tool', + Cancel = 'cancel', }