feat: add support to remote MCP servers for custom HTTP headers (#2477)
This commit is contained in:
parent
d1eb86581c
commit
0fd602eb43
|
@ -87,6 +87,7 @@ Each server configuration supports the following properties:
|
|||
#### Optional
|
||||
|
||||
- **`args`** (string[]): Command-line arguments for Stdio transport
|
||||
- **`headers`** (object): Custom HTTP headers when using `httpUrl`
|
||||
- **`env`** (object): Environment variables for the server process. Values can reference environment variables using `$VAR_NAME` or `${VAR_NAME}` syntax
|
||||
- **`cwd`** (string): Working directory for Stdio transport
|
||||
- **`timeout`** (number): Request timeout in milliseconds (default: 600,000ms = 10 minutes)
|
||||
|
@ -166,6 +167,24 @@ Each server configuration supports the following properties:
|
|||
}
|
||||
```
|
||||
|
||||
#### HTTP-based MCP Server with Custom Headers
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"httpServerWithAuth": {
|
||||
"httpUrl": "http://localhost:3000/mcp",
|
||||
"headers": {
|
||||
"Authorization": "Bearer your-api-token",
|
||||
"X-Custom-Header": "custom-value",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
"timeout": 5000
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Discovery Process Deep Dive
|
||||
|
||||
When the Gemini CLI starts, it performs MCP server discovery through the following detailed process:
|
||||
|
|
|
@ -76,6 +76,7 @@ export class MCPServerConfig {
|
|||
readonly url?: string,
|
||||
// For streamable http transport
|
||||
readonly httpUrl?: string,
|
||||
readonly headers?: Record<string, string>,
|
||||
// For websocket transport
|
||||
readonly tcp?: string,
|
||||
// Common
|
||||
|
|
|
@ -21,6 +21,7 @@ import { DiscoveredMCPTool } from './mcp-tool.js';
|
|||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { parse, ParseEntry } from 'shell-quote';
|
||||
|
||||
// Mock dependencies
|
||||
|
@ -65,6 +66,16 @@ vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => {
|
|||
return { SSEClientTransport: MockedSSETransport };
|
||||
});
|
||||
|
||||
vi.mock('@modelcontextprotocol/sdk/client/streamableHttp.js', () => {
|
||||
const MockedStreamableHTTPTransport = vi.fn().mockImplementation(function (
|
||||
this: any,
|
||||
) {
|
||||
this.close = vi.fn().mockResolvedValue(undefined); // Add mock close method
|
||||
return this;
|
||||
});
|
||||
return { StreamableHTTPClientTransport: MockedStreamableHTTPTransport };
|
||||
});
|
||||
|
||||
const mockToolRegistryInstance = {
|
||||
registerTool: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]), // Default to empty array
|
||||
|
@ -129,6 +140,15 @@ describe('discoverMcpTools', () => {
|
|||
this.close = vi.fn().mockResolvedValue(undefined);
|
||||
return this;
|
||||
});
|
||||
|
||||
vi.mocked(StreamableHTTPClientTransport).mockClear();
|
||||
// Ensure the StreamableHTTPClientTransport mock constructor returns an object with a close method
|
||||
vi.mocked(StreamableHTTPClientTransport).mockImplementation(function (
|
||||
this: any,
|
||||
) {
|
||||
this.close = vi.fn().mockResolvedValue(undefined);
|
||||
return this;
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
@ -267,6 +287,100 @@ describe('discoverMcpTools', () => {
|
|||
expect(registeredTool.name).toBe('tool-sse');
|
||||
});
|
||||
|
||||
it('should discover tools via mcpServers config (streamable http)', async () => {
|
||||
const serverConfig: MCPServerConfig = {
|
||||
httpUrl: 'http://localhost:3000/mcp',
|
||||
};
|
||||
mockConfig.getMcpServers.mockReturnValue({ 'http-server': serverConfig });
|
||||
|
||||
const mockTool = {
|
||||
name: 'tool-http',
|
||||
description: 'desc-http',
|
||||
inputSchema: { type: 'object' as const, properties: {} },
|
||||
};
|
||||
vi.mocked(Client.prototype.listTools).mockResolvedValue({
|
||||
tools: [mockTool],
|
||||
});
|
||||
|
||||
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
|
||||
expect.any(DiscoveredMCPTool),
|
||||
]);
|
||||
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
|
||||
new URL(serverConfig.httpUrl!),
|
||||
{},
|
||||
);
|
||||
expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
|
||||
expect.any(DiscoveredMCPTool),
|
||||
);
|
||||
const registeredTool = mockToolRegistry.registerTool.mock
|
||||
.calls[0][0] as DiscoveredMCPTool;
|
||||
expect(registeredTool.name).toBe('tool-http');
|
||||
});
|
||||
|
||||
describe('StreamableHTTPClientTransport headers', () => {
|
||||
const setupHttpTest = async (headers?: Record<string, string>) => {
|
||||
const serverConfig: MCPServerConfig = {
|
||||
httpUrl: 'http://localhost:3000/mcp',
|
||||
...(headers && { headers }),
|
||||
};
|
||||
const serverName = headers
|
||||
? 'http-server-with-headers'
|
||||
: 'http-server-no-headers';
|
||||
const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers';
|
||||
|
||||
mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig });
|
||||
|
||||
const mockTool = {
|
||||
name: toolName,
|
||||
description: `desc-${toolName}`,
|
||||
inputSchema: { type: 'object' as const, properties: {} },
|
||||
};
|
||||
vi.mocked(Client.prototype.listTools).mockResolvedValue({
|
||||
tools: [mockTool],
|
||||
});
|
||||
mockToolRegistry.getToolsByServer.mockReturnValueOnce([
|
||||
expect.any(DiscoveredMCPTool),
|
||||
]);
|
||||
|
||||
await discoverMcpTools(
|
||||
mockConfig.getMcpServers() ?? {},
|
||||
mockConfig.getMcpServerCommand(),
|
||||
mockToolRegistry as any,
|
||||
);
|
||||
|
||||
return { serverConfig };
|
||||
};
|
||||
|
||||
it('should pass headers when provided', async () => {
|
||||
const headers = {
|
||||
Authorization: 'Bearer test-token',
|
||||
'X-Custom-Header': 'custom-value',
|
||||
};
|
||||
const { serverConfig } = await setupHttpTest(headers);
|
||||
|
||||
expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
|
||||
new URL(serverConfig.httpUrl!),
|
||||
{ requestInit: { headers } },
|
||||
);
|
||||
});
|
||||
|
||||
it('should work without headers (backwards compatibility)', async () => {
|
||||
const { serverConfig } = await setupHttpTest();
|
||||
|
||||
expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(
|
||||
new URL(serverConfig.httpUrl!),
|
||||
{},
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should prefix tool names if multiple MCP servers are configured', async () => {
|
||||
const serverConfig1: MCPServerConfig = { command: './mcp1' };
|
||||
const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' };
|
||||
|
|
|
@ -7,7 +7,10 @@
|
|||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import {
|
||||
StreamableHTTPClientTransport,
|
||||
StreamableHTTPClientTransportOptions,
|
||||
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import { MCPServerConfig } from '../config/config.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
|
@ -169,8 +172,17 @@ async function connectAndDiscover(
|
|||
|
||||
let transport;
|
||||
if (mcpServerConfig.httpUrl) {
|
||||
const transportOptions: StreamableHTTPClientTransportOptions = {};
|
||||
|
||||
if (mcpServerConfig.headers) {
|
||||
transportOptions.requestInit = {
|
||||
headers: mcpServerConfig.headers,
|
||||
};
|
||||
}
|
||||
|
||||
transport = new StreamableHTTPClientTransport(
|
||||
new URL(mcpServerConfig.httpUrl),
|
||||
transportOptions,
|
||||
);
|
||||
} else if (mcpServerConfig.url) {
|
||||
transport = new SSEClientTransport(new URL(mcpServerConfig.url));
|
||||
|
@ -222,10 +234,11 @@ async function connectAndDiscover(
|
|||
const safeConfig = {
|
||||
command: mcpServerConfig.command,
|
||||
url: mcpServerConfig.url,
|
||||
httpUrl: mcpServerConfig.httpUrl,
|
||||
cwd: mcpServerConfig.cwd,
|
||||
timeout: mcpServerConfig.timeout,
|
||||
trust: mcpServerConfig.trust,
|
||||
// Exclude args and env which may contain sensitive data
|
||||
// Exclude args, env, and headers which may contain sensitive data
|
||||
};
|
||||
|
||||
let errorString =
|
||||
|
|
Loading…
Reference in New Issue