/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import {
  SSEClientTransport,
  SSEClientTransportOptions,
} from '@modelcontextprotocol/sdk/client/sse.js';
import {
  StreamableHTTPClientTransport,
  StreamableHTTPClientTransportOptions,
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { parse } from 'shell-quote';
import { MCPServerConfig } from '../config/config.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import { FunctionDeclaration, Type, mcpToTool } from '@google/genai';
import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
import {
  ActiveFileNotificationSchema,
  IDE_SERVER_NAME,
  ideContext,
} from '../services/ideContext.js';

export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes

/**
 * Enum representing the connection status of an MCP server
 */
export enum MCPServerStatus {
  /** Server is disconnected or experiencing errors */
  DISCONNECTED = 'disconnected',
  /** Server is in the process of connecting */
  CONNECTING = 'connecting',
  /** Server is connected and ready to use */
  CONNECTED = 'connected',
}

/**
 * Enum representing the overall MCP discovery state
 */
export enum MCPDiscoveryState {
  /** Discovery has not started yet */
  NOT_STARTED = 'not_started',
  /** Discovery is currently in progress */
  IN_PROGRESS = 'in_progress',
  /** Discovery has completed (with or without errors) */
  COMPLETED = 'completed',
}

/**
 * Last reported status of each MCP server within the core package, keyed by
 * server name. Only mutated through {@link updateMCPServerStatus}.
 */
const mcpServerStatusesInternal = new Map<string, MCPServerStatus>();

/**
 * Track the overall MCP discovery state
 */
let mcpDiscoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;

/**
 * Callback invoked whenever an MCP server's status is updated.
 */
type StatusChangeListener = (
  serverName: string,
  status: MCPServerStatus,
) => void;

const statusChangeListeners: StatusChangeListener[] = [];

/**
 * Add a listener for MCP server status changes
 */
export function addMCPStatusChangeListener(
  listener: StatusChangeListener,
): void {
  statusChangeListeners.push(listener);
}

/**
 * Remove a listener for MCP server status changes. Removing a listener that
 * was never added is a no-op.
 */
export function removeMCPStatusChangeListener(
  listener: StatusChangeListener,
): void {
  const index = statusChangeListeners.indexOf(listener);
  if (index !== -1) {
    statusChangeListeners.splice(index, 1);
  }
}

/**
 * Record the new status of an MCP server and notify every registered listener.
 *
 * @param serverName Name of the server whose status changed.
 * @param status The new status.
 */
function updateMCPServerStatus(
  serverName: string,
  status: MCPServerStatus,
): void {
  mcpServerStatusesInternal.set(serverName, status);
  // Notify from a snapshot: a listener that removes itself (or others) while
  // being notified would otherwise shift the live array under the iterator
  // and cause the following listener to be skipped.
  for (const listener of [...statusChangeListeners]) {
    listener(serverName, status);
  }
}

/**
 * Get the current status of an MCP server. Servers that were never reported
 * are treated as DISCONNECTED.
 */
export function getMCPServerStatus(serverName: string): MCPServerStatus {
  return (
    mcpServerStatusesInternal.get(serverName) ?? MCPServerStatus.DISCONNECTED
  );
}

/**
 * Get all MCP server statuses. Returns a copy, so callers cannot mutate the
 * internal state.
 */
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
  return new Map(mcpServerStatusesInternal);
}

/**
 * Get the current MCP discovery state
 */
export function getMCPDiscoveryState(): MCPDiscoveryState {
  return mcpDiscoveryState;
}

/**
 * Discovers tools from all configured MCP servers and registers them with the tool registry.
 * It orchestrates the connection and discovery process for each server defined in the
 * configuration, as well as any server specified via a command-line argument.
 *
 * @param mcpServers A record of named MCP server configurations.
 * @param mcpServerCommand An optional command string for a dynamically specified MCP server.
 * @param toolRegistry The central registry where discovered tools will be registered.
 * @param debugMode Forwarded to every server connection; when true, stderr of stdio servers is logged via console.debug.
 * @returns A promise that resolves when the discovery process has been attempted for all servers.
*/
export async function discoverMcpTools(
  mcpServers: Record<string, MCPServerConfig>,
  mcpServerCommand: string | undefined,
  toolRegistry: ToolRegistry,
  debugMode: boolean,
): Promise<void> {
  mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
  try {
    const servers = populateMcpServerCommand(mcpServers, mcpServerCommand);
    // Servers are contacted in parallel; connectAndDiscover logs its own
    // connection/discovery failures instead of rethrowing them.
    await Promise.all(
      Object.entries(servers).map(([mcpServerName, mcpServerConfig]) =>
        connectAndDiscover(
          mcpServerName,
          mcpServerConfig,
          toolRegistry,
          debugMode,
        ),
      ),
    );
  } finally {
    // Marked complete even when something threw, so observers of the
    // discovery state are not left waiting on IN_PROGRESS.
    mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
  }
}

/**
 * Returns the configured servers, extended with a server named `mcp` built
 * from `mcpServerCommand` when one is given. The command string is tokenized
 * with shell-quote's `parse`, with `process.env` supplied for variable
 * interpolation. An existing `mcp` entry is overwritten.
 *
 * The input record is never modified; a new record is returned when the
 * `mcp` server is added.
 *
 * Visible for Testing
 *
 * @throws Error when the command yields no tokens or contains non-word
 *   tokens (shell operators, globs, comments).
 */
export function populateMcpServerCommand(
  mcpServers: Record<string, MCPServerConfig>,
  mcpServerCommand: string | undefined,
): Record<string, MCPServerConfig> {
  if (!mcpServerCommand) {
    return mcpServers;
  }
  const tokens = parse(mcpServerCommand, process.env);
  // shell-quote yields objects (not strings) for operators, globs and
  // comments; only plain words can be used as an argv. An empty token list
  // would otherwise produce a server without a command.
  if (
    tokens.length === 0 ||
    tokens.some((token) => typeof token !== 'string')
  ) {
    throw new Error('failed to parse mcpServerCommand: ' + mcpServerCommand);
  }
  const [command, ...args] = tokens as string[];
  // use generic server name 'mcp'
  return { ...mcpServers, mcp: { command, args } };
}

/**
 * Connects to an MCP server and discovers available tools, registering them with the tool registry.
 * This function handles the complete lifecycle of connecting to a server, discovering tools,
 * and closing the client again if discovery fails (including when no enabled tools are found).
 * Errors are logged and the server is marked DISCONNECTED rather than being rethrown.
*
 * @param mcpServerName The name identifier for this MCP server
 * @param mcpServerConfig Configuration object containing connection details
 * @param toolRegistry The registry to register discovered tools with
 * @param debugMode Forwarded to connectToMcpServer; when true, stderr of a stdio server is logged via console.debug
 * @returns Promise that resolves when discovery is complete
 */
export async function connectAndDiscover(
  mcpServerName: string,
  mcpServerConfig: MCPServerConfig,
  toolRegistry: ToolRegistry,
  debugMode: boolean,
): Promise<void> {
  updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
  try {
    const mcpClient = await connectToMcpServer(
      mcpServerName,
      mcpServerConfig,
      debugMode,
    );
    try {
      updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);

      mcpClient.onerror = (error) => {
        console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
        updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
      };

      // For the IDE server, forward active-file notifications into ideContext.
      if (mcpServerName === IDE_SERVER_NAME) {
        mcpClient.setNotificationHandler(
          ActiveFileNotificationSchema,
          (notification) => {
            ideContext.setActiveFileContext(notification.params);
          },
        );
      }

      const tools = await discoverTools(
        mcpServerName,
        mcpServerConfig,
        mcpClient,
      );
      for (const tool of tools) {
        toolRegistry.registerTool(tool);
      }
    } catch (error) {
      // Await the close so a rejecting close() cannot surface as an unhandled
      // promise rejection, and ignore its failure so the original error is
      // the one that gets reported below.
      try {
        await mcpClient.close();
      } catch {
        // Best-effort cleanup only.
      }
      throw error;
    }
  } catch (error) {
    console.error(`Error connecting to MCP server '${mcpServerName}':`, error);
    updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
  }
}

/**
 * Discovers and sanitizes tools from a connected MCP client.
 * It retrieves function declarations from the client, filters out disabled tools,
 * generates valid names for them, and wraps them in `DiscoveredMCPTool` instances.
 *
 * @param mcpServerName The name of the MCP server.
 * @param mcpServerConfig The configuration for the MCP server.
 * @param mcpClient The active MCP client instance.
 * @returns A promise that resolves to an array of discovered and enabled tools.
 * @throws An error if no enabled tools are found or if the server provides invalid function declarations.
*/
export async function discoverTools(
  mcpServerName: string,
  mcpServerConfig: MCPServerConfig,
  mcpClient: Client,
): Promise<DiscoveredMCPTool[]> {
  try {
    const mcpCallableTool = mcpToTool(mcpClient);
    const tool = await mcpCallableTool.tool();

    if (!Array.isArray(tool.functionDeclarations)) {
      throw new Error(`Server did not return valid function declarations.`);
    }

    // Same for every tool of this server, so resolve it once.
    const timeout = mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC;
    const discoveredTools: DiscoveredMCPTool[] = [];
    for (const funcDecl of tool.functionDeclarations) {
      if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) {
        continue;
      }

      const toolNameForModel = generateValidName(funcDecl, mcpServerName);
      // Mutates funcDecl.parameters in place before it is handed to the tool.
      sanitizeParameters(funcDecl.parameters);

      discoveredTools.push(
        new DiscoveredMCPTool(
          mcpCallableTool,
          mcpServerName,
          toolNameForModel,
          funcDecl.description ?? '',
          funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
          funcDecl.name!, // isEnabled() has already rejected nameless declarations
          timeout,
          mcpServerConfig.trust,
        ),
      );
    }

    if (discoveredTools.length === 0) {
      throw new Error('No enabled tools found');
    }
    return discoveredTools;
  } catch (error) {
    throw new Error(`Error discovering tools: ${error}`);
  }
}

/**
 * Creates and connects an MCP client to a server based on the provided configuration.
 * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
 * establishes a connection. It also applies a patch to handle request timeouts.
 *
 * @param mcpServerName The name of the MCP server, used for logging and identification.
 * @param mcpServerConfig The configuration specifying how to connect to the server.
 * @param debugMode When true, stderr of a stdio server is logged via console.debug.
 * @returns A promise that resolves to a connected MCP `Client` instance.
 * @throws An error if the connection fails or the configuration is invalid.
*/
export async function connectToMcpServer(
  mcpServerName: string,
  mcpServerConfig: MCPServerConfig,
  debugMode: boolean,
): Promise<Client> {
  const mcpClient = new Client({
    name: 'gemini-cli-mcp-client',
    version: '0.0.1',
  });

  // patch Client.callTool to use request timeout as genai McpCallTool.callTool does not do it
  // TODO: remove this hack once GenAI SDK does callTool with request options
  if ('callTool' in mcpClient) {
    const origCallTool = mcpClient.callTool.bind(mcpClient);
    mcpClient.callTool = function (params, resultSchema, options) {
      return origCallTool(params, resultSchema, {
        ...options,
        timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
      });
    };
  }

  try {
    const transport = createTransport(
      mcpServerName,
      mcpServerConfig,
      debugMode,
    );
    try {
      await mcpClient.connect(transport, {
        timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
      });
      return mcpClient;
    } catch (error) {
      // Best-effort cleanup: a failing close() must not replace the connect
      // error, which is the one the user needs to see.
      try {
        await transport.close();
      } catch {
        // ignored
      }
      throw error;
    }
  } catch (error) {
    // Create a safe config object that excludes sensitive information
    const safeConfig = {
      command: mcpServerConfig.command,
      url: mcpServerConfig.url,
      httpUrl: mcpServerConfig.httpUrl,
      cwd: mcpServerConfig.cwd,
      timeout: mcpServerConfig.timeout,
      trust: mcpServerConfig.trust,
      // Exclude args, env, and headers which may contain sensitive data
    };

    let errorString =
      `failed to start or connect to MCP server '${mcpServerName}' ` +
      `${JSON.stringify(safeConfig)}; \n${error}`;
    if (process.env.SANDBOX) {
      errorString += `\nMake sure it is available in the sandbox`;
    }
    throw new Error(errorString);
  }
}

/**
 * Builds the transport for a server config. Precedence is `httpUrl`
 * (Streamable HTTP), then `url` (SSE), then `command` (stdio). Configured
 * `headers` are passed to the HTTP-based transports via `requestInit`; the
 * stdio transport inherits `process.env` overlaid with the configured `env`.
 *
 * Visible for Testing
 *
 * @throws Error when none of `httpUrl`, `url` or `command` is set.
 */
export function createTransport(
  mcpServerName: string,
  mcpServerConfig: MCPServerConfig,
  debugMode: boolean,
): Transport {
  if (mcpServerConfig.httpUrl) {
    const transportOptions: StreamableHTTPClientTransportOptions = {};
    if (mcpServerConfig.headers) {
      transportOptions.requestInit = {
        headers: mcpServerConfig.headers,
      };
    }
    return new StreamableHTTPClientTransport(
      new URL(mcpServerConfig.httpUrl),
      transportOptions,
    );
  }

  if (mcpServerConfig.url) {
    const transportOptions: SSEClientTransportOptions = {};
    if (mcpServerConfig.headers) {
      transportOptions.requestInit = {
        headers: mcpServerConfig.headers,
      };
    }
    return new SSEClientTransport(
      new URL(mcpServerConfig.url),
      transportOptions,
    );
  }

  if (mcpServerConfig.command) {
    const transport = new StdioClientTransport({
      command: mcpServerConfig.command,
      args: mcpServerConfig.args || [],
      env: {
        ...process.env,
        ...(mcpServerConfig.env || {}),
      } as Record<string, string>,
      cwd: mcpServerConfig.cwd,
      stderr: 'pipe',
    });
    // stderr is piped, so it must always be consumed: an unread pipe fills
    // up and a server that writes enough to stderr would then block. Log it
    // in debug mode, discard it otherwise.
    transport.stderr?.on('data', (data) => {
      if (debugMode) {
        const stderrStr = data.toString().trim();
        console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr);
      }
    });
    return transport;
  }

  throw new Error(
    `Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`,
  );
}

/**
 * Builds the name under which an MCP tool is exposed to the model:
 * `<mcpServerName>__<toolName>`, where characters of the tool name outside
 * `[a-zA-Z0-9_.-]` are replaced by `_` (the server name is not sanitized).
 * Results longer than 63 characters keep their first 28 and last 32
 * characters joined by `___`, i.e. exactly 63 characters.
 *
 * Assumes `funcDecl.name` is set; callers filter with isEnabled() first.
 *
 * Visible for testing
 */
export function generateValidName(
  funcDecl: FunctionDeclaration,
  mcpServerName: string,
): string {
  // Replace invalid characters (based on 400 error message from Gemini API) with underscores
  let validToolname = funcDecl.name!.replace(/[^a-zA-Z0-9_.-]/g, '_');

  // Prepend MCP server name to avoid conflicts with other tools
  validToolname = mcpServerName + '__' + validToolname;

  // If longer than 63 characters, replace middle with '___'
  // (Gemini API says max length 64, but actual limit seems to be 63)
  if (validToolname.length > 63) {
    validToolname =
      validToolname.slice(0, 28) + '___' + validToolname.slice(-32);
  }
  return validToolname;
}

/**
 * Decides whether a discovered function declaration should be exposed.
 * Declarations without a name are skipped with a warning. `excludeTools`
 * takes precedence over `includeTools`; when `includeTools` is set, a tool is
 * enabled only if an entry equals its name or starts with `<name>(`.
 *
 * Visible for testing
 */
export function isEnabled(
  funcDecl: FunctionDeclaration,
  mcpServerName: string,
  mcpServerConfig: MCPServerConfig,
): boolean {
  if (!funcDecl.name) {
    console.warn(
      `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`,
    );
    return false;
  }
  const { includeTools, excludeTools } = mcpServerConfig;

  // excludeTools takes precedence over includeTools
  if (excludeTools && excludeTools.includes(funcDecl.name)) {
    return false;
  }

  return (
    !includeTools ||
    includeTools.some(
      (tool) => tool === funcDecl.name || tool.startsWith(`${funcDecl.name}(`),
    )
  );
}