feat: add support to remote MCP servers for custom HTTP headers (#2477)

This commit is contained in:
Adam Spiers 2025-06-30 01:09:08 +01:00 committed by GitHub
parent d1eb86581c
commit 0fd602eb43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 149 additions and 2 deletions

View File

@ -87,6 +87,7 @@ Each server configuration supports the following properties:
#### Optional
- **`args`** (string[]): Command-line arguments for Stdio transport
- **`headers`** (object): Custom HTTP headers when using `httpUrl`
- **`env`** (object): Environment variables for the server process. Values can reference environment variables using `$VAR_NAME` or `${VAR_NAME}` syntax
- **`cwd`** (string): Working directory for Stdio transport
- **`timeout`** (number): Request timeout in milliseconds (default: 600,000ms = 10 minutes)
@ -166,6 +167,24 @@ Each server configuration supports the following properties:
}
```
#### HTTP-based MCP Server with Custom Headers
```json
{
"mcpServers": {
"httpServerWithAuth": {
"httpUrl": "http://localhost:3000/mcp",
"headers": {
"Authorization": "Bearer your-api-token",
"X-Custom-Header": "custom-value",
"Content-Type": "application/json"
},
"timeout": 5000
}
}
}
```
## Discovery Process Deep Dive
When the Gemini CLI starts, it performs MCP server discovery through the following detailed process:

View File

@ -76,6 +76,7 @@ export class MCPServerConfig {
readonly url?: string,
// For streamable http transport
readonly httpUrl?: string,
readonly headers?: Record<string, string>,
// For websocket transport
readonly tcp?: string,
// Common

View File

@ -21,6 +21,7 @@ import { DiscoveredMCPTool } from './mcp-tool.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { parse, ParseEntry } from 'shell-quote';
// Mock dependencies
@ -65,6 +66,16 @@ vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
return { SSEClientTransport: MockedSSETransport };
});
vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => {
const MockedStreamableHTTPTransport = vi.fn().mockImplementation(function (
this: any,
) {
this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
return this;
});
return { StreamableHTTPClientTransport: MockedStreamableHTTPTransport };
});
const mockToolRegistryInstance = {
registerTool: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array
@ -129,6 +140,15 @@ describe('discoverMcpTools', () => {
this.close = vi.fn().mockResolvedValue(undefined);
return this;
});
vi.mocked(StreamableHTTPClientTransport).mockClear();
// Ensure the StreamableHTTPClientTransport mock constructor returns an object with a close method
vi.mocked(StreamableHTTPClientTransport).mockImplementation(function (
this: any,
) {
this.close = vi.fn().mockResolvedValue(undefined);
return this;
});
});
afterEach(() => {
@ -267,6 +287,100 @@ describe('discoverMcpTools', () => {
expect(registeredTool.name).toBe('tool-sse');
});
it('should discover tools via mcpServers config (streamable http)', async () => {
const serverConfig: MCPServerConfig = {
httpUrl: 'http://localhost:3000/mcp',
};
mockConfig.getMcpServers.mockReturnValue({ 'http-server': serverConfig });
const mockTool = {
name: 'tool-http',
description: 'desc-http',
inputSchema: { type: 'object' as const, properties: {} },
};
vi.mocked(Client.prototype.listTools).mockResolvedValue({
tools: [mockTool],
});
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
expect.any(DiscoveredMCPTool),
]);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
new URL(serverConfig.httpUrl!),
{},
);
expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
expect.any(DiscoveredMCPTool),
);
const registeredTool = mockToolRegistry.registerTool.mock
.calls[0][0] as DiscoveredMCPTool;
expect(registeredTool.name).toBe('tool-http');
});
describe('StreamableHTTPClientTransport headers', () => {
const setupHttpTest = async (headers?: Record<string, string>) => {
const serverConfig: MCPServerConfig = {
httpUrl: 'http://localhost:3000/mcp',
...(headers && { headers }),
};
const serverName = headers
? 'http-server-with-headers'
: 'http-server-no-headers';
const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers';
mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig });
const mockTool = {
name: toolName,
description: `desc-${toolName}`,
inputSchema: { type: 'object' as const, properties: {} },
};
vi.mocked(Client.prototype.listTools).mockResolvedValue({
tools: [mockTool],
});
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
expect.any(DiscoveredMCPTool),
]);
await discoverMcpTools(
mockConfig.getMcpServers() ?? {},
mockConfig.getMcpServerCommand(),
mockToolRegistry as any,
);
return { serverConfig };
};
it('should pass headers when provided', async () => {
const headers = {
Authorization: 'Bearer test-token',
'X-Custom-Header': 'custom-value',
};
const { serverConfig } = await setupHttpTest(headers);
expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
new URL(serverConfig.httpUrl!),
{ requestInit: { headers } },
);
});
it('should work without headers (backwards compatibility)', async () => {
const { serverConfig } = await setupHttpTest();
expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
new URL(serverConfig.httpUrl!),
{},
);
});
});
it('should prefix tool names if multiple MCP servers are configured', async () => {
const serverConfig1: MCPServerConfig = { command: './mcp1' };
const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' };

View File

@ -7,7 +7,10 @@
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import {
StreamableHTTPClientTransport,
StreamableHTTPClientTransportOptions,
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { parse } from 'shell-quote';
import { MCPServerConfig } from '../config/config.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
@ -169,8 +172,17 @@ async function connectAndDiscover(
let transport;
if (mcpServerConfig.httpUrl) {
const transportOptions: StreamableHTTPClientTransportOptions = {};
if (mcpServerConfig.headers) {
transportOptions.requestInit = {
headers: mcpServerConfig.headers,
};
}
transport = new StreamableHTTPClientTransport(
new URL(mcpServerConfig.httpUrl),
transportOptions,
);
} else if (mcpServerConfig.url) {
transport = new SSEClientTransport(new URL(mcpServerConfig.url));
@ -222,10 +234,11 @@ async function connectAndDiscover(
const safeConfig = {
command: mcpServerConfig.command,
url: mcpServerConfig.url,
httpUrl: mcpServerConfig.httpUrl,
cwd: mcpServerConfig.cwd,
timeout: mcpServerConfig.timeout,
trust: mcpServerConfig.trust,
// Exclude args and env which may contain sensitive data
// Exclude args, env, and headers which may contain sensitive data
};
let errorString =