diff --git a/docs/tools/mcp-server.md b/docs/tools/mcp-server.md index 0b26f89a..f764d75a 100644 --- a/docs/tools/mcp-server.md +++ b/docs/tools/mcp-server.md @@ -87,7 +87,7 @@ Each server configuration supports the following properties: #### Optional - **`args`** (string[]): Command-line arguments for Stdio transport -- **`headers`** (object): Custom HTTP headers when using `httpUrl` +- **`headers`** (object): Custom HTTP headers when using `url` or `httpUrl` - **`env`** (object): Environment variables for the server process. Values can reference environment variables using `$VAR_NAME` or `${VAR_NAME}` syntax - **`cwd`** (string): Working directory for Stdio transport - **`timeout`** (number): Request timeout in milliseconds (default: 600,000ms = 10 minutes) diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 1413a4f8..a70a0db2 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -284,7 +284,10 @@ describe('discoverMcpTools', () => { mockToolRegistry as any, ); - expect(SSEClientTransport).toHaveBeenCalledWith(new URL(serverConfig.url!)); + expect(SSEClientTransport).toHaveBeenCalledWith( + new URL(serverConfig.url!), + {}, + ); expect(mockToolRegistry.registerTool).toHaveBeenCalledWith( expect.any(DiscoveredMCPTool), ); @@ -293,6 +296,75 @@ describe('discoverMcpTools', () => { expect(registeredTool.name).toBe('tool-sse'); }); + describe('SseClientTransport headers', () => { + const setupSseTest = async (headers?: Record) => { + const serverConfig: MCPServerConfig = { + url: 'http://localhost:1234/sse', + ...(headers && { headers }), + }; + const serverName = headers + ? 'sse-server-with-headers' + : 'sse-server-no-headers'; + const toolName = headers ? 'tool-http-headers' : 'tool-http-no-headers'; + + mockConfig.getMcpServers.mockReturnValue({ [serverName]: serverConfig }); + + const mockTool = { + name: toolName, + description: `desc-${toolName}`, + inputSchema: { type: 'object' as const, properties: {} }, + }; + vi.mocked(Client.prototype.listTools).mockResolvedValue({ + tools: [mockTool], + }); + mockToolRegistry.getToolsByServer.mockReturnValueOnce([ + expect.any(DiscoveredMCPTool), + ]); + + await discoverMcpTools( + mockConfig.getMcpServers() ?? {}, + mockConfig.getMcpServerCommand(), + mockToolRegistry as any, + ); + + return { serverConfig }; + }; + + it('should pass headers when provided', async () => { + const headers = { + Authorization: 'Bearer test-token', + 'X-Custom-Header': 'custom-value', + }; + const { serverConfig } = await setupSseTest(headers); + + expect(SSEClientTransport).toHaveBeenCalledWith( + new URL(serverConfig.url!), + { requestInit: { headers } }, + ); + }); + + it('should work without headers (backwards compatibility)', async () => { + const { serverConfig } = await setupSseTest(); + + expect(SSEClientTransport).toHaveBeenCalledWith( + new URL(serverConfig.url!), + {}, + ); + }); + + it('should pass oauth token when provided', async () => { + const headers = { + Authorization: 'Bearer test-token', + }; + const { serverConfig } = await setupSseTest(headers); + + expect(SSEClientTransport).toHaveBeenCalledWith( + new URL(serverConfig.url!), + { requestInit: { headers } }, + ); + }); + }); + it('should discover tools via mcpServers config (streamable http)', async () => { const serverConfig: MCPServerConfig = { httpUrl: 'http://localhost:3000/mcp', diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 89e97963..e4a87b68 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -6,7 +6,10 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; -import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; +import { + SSEClientTransport, + SSEClientTransportOptions, +} from '@modelcontextprotocol/sdk/client/sse.js'; import { StreamableHTTPClientTransport, StreamableHTTPClientTransportOptions, @@ -190,7 +193,16 @@ async function connectAndDiscover( transportOptions, ); } else if (mcpServerConfig.url) { - transport = new SSEClientTransport(new URL(mcpServerConfig.url)); + const transportOptions: SSEClientTransportOptions = {}; + if (mcpServerConfig.headers) { + transportOptions.requestInit = { + headers: mcpServerConfig.headers, + }; + } + transport = new SSEClientTransport( + new URL(mcpServerConfig.url), + transportOptions, + ); } else if (mcpServerConfig.command) { transport = new StdioClientTransport({ command: mcpServerConfig.command,