From 0fd602eb43eea7abca980dc2ae3fd7bf2ba76a2a Mon Sep 17 00:00:00 2001 From: Adam Spiers Date: Mon, 30 Jun 2025 01:09:08 +0100 Subject: [PATCH] feat: add support to remote MCP servers for custom HTTP headers (#2477) --- docs/tools/mcp-server.md | 19 ++++ packages/core/src/config/config.ts | 1 + packages/core/src/tools/mcp-client.test.ts | 114 +++++++++++++++++++++ packages/core/src/tools/mcp-client.ts | 17 ++- 4 files changed, 149 insertions(+), 2 deletions(-) diff --git a/docs/tools/mcp-server.md b/docs/tools/mcp-server.md index ebce6160..0be9a34b 100644 --- a/docs/tools/mcp-server.md +++ b/docs/tools/mcp-server.md @@ -87,6 +87,7 @@ Each server configuration supports the following properties: #### Optional - **`args`** (string[]): Command-line arguments for Stdio transport +- **`headers`** (object): Custom HTTP headers when using `httpUrl` - **`env`** (object): Environment variables for the server process. Values can reference environment variables using `$VAR_NAME` or `${VAR_NAME}` syntax - **`cwd`** (string): Working directory for Stdio transport - **`timeout`** (number): Request timeout in milliseconds (default: 600,000ms = 10 minutes) @@ -166,6 +167,24 @@ Each server configuration supports the following properties: } ``` +#### HTTP-based MCP Server with Custom Headers + +```json +{ + "mcpServers": { + "httpServerWithAuth": { + "httpUrl": "http://localhost:3000/mcp", + "headers": { + "Authorization": "Bearer your-api-token", + "X-Custom-Header": "custom-value", + "Content-Type": "application/json" + }, + "timeout": 5000 + } + } +} +``` + ## Discovery Process Deep Dive When the Gemini CLI starts, it performs MCP server discovery through the following detailed process: diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 4ee2d23f..3bb5b85e 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -76,6 +76,7 @@ export class MCPServerConfig { readonly url?: string, // For streamable http transport readonly httpUrl?: string, + readonly headers?: Record, // For websocket transport readonly tcp?: string, // Common diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index f963a060..91524a2f 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -21,6 +21,7 @@ import { DiscoveredMCPTool } from './mcp-tool.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; +import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { parse, ParseEntry } from 'shell-quote'; // Mock dependencies @@ -65,6 +66,16 @@ vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => { return { SSEClientTransport: MockedSSETransport }; }); +vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => { + const MockedStreamableHTTPTransport = vi.fn().mockImplementation(function ( + this: any, + ) { + this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method + return this; + }); + return { StreamableHTTPClientTransport: MockedStreamableHTTPTransport }; +}); + const mockToolRegistryInstance = { registerTool: vi.fn(), getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array @@ -129,6 +140,15 @@ describe('discoverMcpTools', () => { this.close = vi.fn().mockResolvedValue(undefined); return this; }); + + vi.mocked(StreamableHTTPClientTransport).mockClear(); + // Ensure the StreamableHTTPClientTransport mock constructor returns an object with a close method + vi.mocked(StreamableHTTPClientTransport).mockImplementation(function ( + this: any, + ) { + this.close = vi.fn().mockResolvedValue(undefined); + return this; + }); }); afterEach(() => { @@ -267,6 +287,100 @@ describe('discoverMcpTools', () => { expect(registeredTool.name).toBe('tool-sse'); }); + it('should discover tools via mcpServers config (streamable http)', async () => { + const serverConfig: MCPServerConfig = { + httpUrl: 'http://localhost:3000/mcp', + }; + mockConfig.getMcpServers.mockReturnValue({ 'http-server': serverConfig }); + + const mockTool = { + name: 'tool-http', + description: 'desc-http', + inputSchema: { type: 'object' as const, properties: {} }, + }; + vi.mocked(Client.prototype.listTools).mockResolvedValue({ + tools: [mockTool], + }); + + mockToolRegistry.getToolsByServer.mockReturnValueOnce([ + expect.any(DiscoveredMCPTool), + ]); + + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); + + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( + new URL(serverConfig.httpUrl!), + {}, + ); + expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( + expect.any(DiscoveredMCPTool), + ); + const registeredTool = mockToolRegistry.registerTool.mock + .calls[0][0] as DiscoveredMCPTool; + expect(registeredTool.name).toBe('tool-http'); + }); + + describe('StreamableHTTPClientTransport headers', () => { + const setupHttpTest = async (headers?: Record) => { + const serverConfig: MCPServerConfig = { + httpUrl: 'http://localhost:3000/mcp', + ...(headers && { headers }), + }; + const serverName = headers + ? 'http-server-with-headers' + : 'http-server-no-headers'; + const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers'; + + mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig }); + + const mockTool = { + name: toolName, + description: `desc-${toolName}`, + inputSchema: { type: 'object' as const, properties: {} }, + }; + vi.mocked(Client.prototype.listTools).mockResolvedValue({ + tools: [mockTool], + }); + mockToolRegistry.getToolsByServer.mockReturnValueOnce([ + expect.any(DiscoveredMCPTool), + ]); + + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); + + return { serverConfig }; + }; + + it('should pass headers when provided', async () => { + const headers = { + Authorization: 'Bearer test-token', + 'X-Custom-Header': 'custom-value', + }; + const { serverConfig } = await setupHttpTest(headers); + + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( + new URL(serverConfig.httpUrl!), + { requestInit: { headers } }, + ); + }); + + it('should work without headers (backwards compatibility)', async () => { + const { serverConfig } = await setupHttpTest(); + + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith( + new URL(serverConfig.httpUrl!), + {}, + ); + }); + }); + it('should prefix tool names if multiple MCP servers are configured', async () => { const serverConfig1: MCPServerConfig = { command: './mcp1' }; const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' }; diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 72382ac1..52196b80 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -7,7 +7,10 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; -import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { + StreamableHTTPClientTransport, + StreamableHTTPClientTransportOptions, +} from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { parse } from 'shell-quote'; import { MCPServerConfig } from '../config/config.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; @@ -169,8 +172,17 @@ async function connectAndDiscover( let transport; if (mcpServerConfig.httpUrl) { + const transportOptions: StreamableHTTPClientTransportOptions = {}; + + if (mcpServerConfig.headers) { + transportOptions.requestInit = { + headers: mcpServerConfig.headers, + }; + } + transport = new StreamableHTTPClientTransport( new URL(mcpServerConfig.httpUrl), + transportOptions, ); } else if (mcpServerConfig.url) { transport = new SSEClientTransport(new URL(mcpServerConfig.url)); @@ -222,10 +234,11 @@ async function connectAndDiscover( const safeConfig = { command: mcpServerConfig.command, url: mcpServerConfig.url, + httpUrl: mcpServerConfig.httpUrl, cwd: mcpServerConfig.cwd, timeout: mcpServerConfig.timeout, trust: mcpServerConfig.trust, - // Exclude args and env which may contain sensitive data + // Exclude args, env, and headers which may contain sensitive data }; let errorString =