MCP OAuth Part 2 - MCP Client Integration (#4318)
Co-authored-by: Greg Shikhman <shikhman@google.com>
This commit is contained in:
parent
138ff73821
commit
258c848909
|
@ -45,6 +45,10 @@ import {
|
||||||
} from './models.js';
|
} from './models.js';
|
||||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||||
import { shouldAttemptBrowserLaunch } from '../utils/browser.js';
|
import { shouldAttemptBrowserLaunch } from '../utils/browser.js';
|
||||||
|
import { MCPOAuthConfig } from '../mcp/oauth-provider.js';
|
||||||
|
|
||||||
|
// Re-export OAuth config type
|
||||||
|
export type { MCPOAuthConfig };
|
||||||
|
|
||||||
export enum ApprovalMode {
|
export enum ApprovalMode {
|
||||||
DEFAULT = 'default',
|
DEFAULT = 'default',
|
||||||
|
@ -112,6 +116,8 @@ export class MCPServerConfig {
|
||||||
readonly includeTools?: string[],
|
readonly includeTools?: string[],
|
||||||
readonly excludeTools?: string[],
|
readonly excludeTools?: string[],
|
||||||
readonly extensionName?: string,
|
readonly extensionName?: string,
|
||||||
|
// OAuth configuration
|
||||||
|
readonly oauth?: MCPOAuthConfig,
|
||||||
) {}
|
) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -158,6 +158,13 @@ export class GeminiClient {
|
||||||
this.getChat().setHistory(history);
|
this.getChat().setHistory(history);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async setTools(): Promise<void> {
|
||||||
|
const toolRegistry = await this.config.getToolRegistry();
|
||||||
|
const toolDeclarations = toolRegistry.getFunctionDeclarations();
|
||||||
|
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
|
||||||
|
this.getChat().setTools(tools);
|
||||||
|
}
|
||||||
|
|
||||||
async resetChat(): Promise<void> {
|
async resetChat(): Promise<void> {
|
||||||
this.chat = await this.startChat();
|
this.chat = await this.startChat();
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ import {
|
||||||
createUserContent,
|
createUserContent,
|
||||||
Part,
|
Part,
|
||||||
GenerateContentResponseUsageMetadata,
|
GenerateContentResponseUsageMetadata,
|
||||||
|
Tool,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { retryWithBackoff } from '../utils/retry.js';
|
import { retryWithBackoff } from '../utils/retry.js';
|
||||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||||
|
@ -498,6 +499,10 @@ export class GeminiChat {
|
||||||
this.history = history;
|
this.history = history;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setTools(tools: Tool[]): void {
|
||||||
|
this.generationConfig.tools = tools;
|
||||||
|
}
|
||||||
|
|
||||||
getFinalUsageMetadata(
|
getFinalUsageMetadata(
|
||||||
chunks: GenerateContentResponse[],
|
chunks: GenerateContentResponse[],
|
||||||
): GenerateContentResponseUsageMetadata | undefined {
|
): GenerateContentResponseUsageMetadata | undefined {
|
||||||
|
|
|
@ -20,6 +20,8 @@ import * as GenAiLib from '@google/genai';
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
||||||
vi.mock('@google/genai');
|
vi.mock('@google/genai');
|
||||||
|
vi.mock('../mcp/oauth-provider.js');
|
||||||
|
vi.mock('../mcp/oauth-token-storage.js');
|
||||||
|
|
||||||
describe('mcp-client', () => {
|
describe('mcp-client', () => {
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
@ -82,7 +84,7 @@ describe('mcp-client', () => {
|
||||||
|
|
||||||
describe('should connect via httpUrl', () => {
|
describe('should connect via httpUrl', () => {
|
||||||
it('without headers', async () => {
|
it('without headers', async () => {
|
||||||
const transport = createTransport(
|
const transport = await createTransport(
|
||||||
'test-server',
|
'test-server',
|
||||||
{
|
{
|
||||||
httpUrl: 'http://test-server',
|
httpUrl: 'http://test-server',
|
||||||
|
@ -96,7 +98,7 @@ describe('mcp-client', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('with headers', async () => {
|
it('with headers', async () => {
|
||||||
const transport = createTransport(
|
const transport = await createTransport(
|
||||||
'test-server',
|
'test-server',
|
||||||
{
|
{
|
||||||
httpUrl: 'http://test-server',
|
httpUrl: 'http://test-server',
|
||||||
|
@ -117,7 +119,7 @@ describe('mcp-client', () => {
|
||||||
|
|
||||||
describe('should connect via url', () => {
|
describe('should connect via url', () => {
|
||||||
it('without headers', async () => {
|
it('without headers', async () => {
|
||||||
const transport = createTransport(
|
const transport = await createTransport(
|
||||||
'test-server',
|
'test-server',
|
||||||
{
|
{
|
||||||
url: 'http://test-server',
|
url: 'http://test-server',
|
||||||
|
@ -130,7 +132,7 @@ describe('mcp-client', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('with headers', async () => {
|
it('with headers', async () => {
|
||||||
const transport = createTransport(
|
const transport = await createTransport(
|
||||||
'test-server',
|
'test-server',
|
||||||
{
|
{
|
||||||
url: 'http://test-server',
|
url: 'http://test-server',
|
||||||
|
@ -149,10 +151,10 @@ describe('mcp-client', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should connect via command', () => {
|
it('should connect via command', async () => {
|
||||||
const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport);
|
const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport);
|
||||||
|
|
||||||
createTransport(
|
await createTransport(
|
||||||
'test-server',
|
'test-server',
|
||||||
{
|
{
|
||||||
command: 'test-command',
|
command: 'test-command',
|
||||||
|
|
|
@ -18,9 +18,11 @@ import {
|
||||||
import { parse } from 'shell-quote';
|
import { parse } from 'shell-quote';
|
||||||
import { MCPServerConfig } from '../config/config.js';
|
import { MCPServerConfig } from '../config/config.js';
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||||
|
|
||||||
import { FunctionDeclaration, mcpToTool } from '@google/genai';
|
import { FunctionDeclaration, mcpToTool } from '@google/genai';
|
||||||
import { ToolRegistry } from './tool-registry.js';
|
import { ToolRegistry } from './tool-registry.js';
|
||||||
|
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||||
|
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||||
|
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||||
import {
|
import {
|
||||||
OpenFilesNotificationSchema,
|
OpenFilesNotificationSchema,
|
||||||
IDE_SERVER_NAME,
|
IDE_SERVER_NAME,
|
||||||
|
@ -64,6 +66,11 @@ const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();
|
||||||
*/
|
*/
|
||||||
let mcpDiscoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
|
let mcpDiscoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Map to track which MCP servers have been discovered to require OAuth
|
||||||
|
*/
|
||||||
|
export const mcpServerRequiresOAuth: Map<string, boolean> = new Map();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Event listeners for MCP server status changes
|
* Event listeners for MCP server status changes
|
||||||
*/
|
*/
|
||||||
|
@ -131,6 +138,165 @@ export function getMCPDiscoveryState(): MCPDiscoveryState {
|
||||||
return mcpDiscoveryState;
|
return mcpDiscoveryState;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Parse www-authenticate header to extract OAuth metadata URI.
|
||||||
|
*
|
||||||
|
* @param wwwAuthenticate The www-authenticate header value
|
||||||
|
* @returns The resource metadata URI if found, null otherwise
|
||||||
|
*/
|
||||||
|
function _parseWWWAuthenticate(wwwAuthenticate: string): string | null {
|
||||||
|
// Parse header like: Bearer realm="MCP Server", resource_metadata_uri="https://..."
|
||||||
|
const resourceMetadataMatch = wwwAuthenticate.match(
|
||||||
|
/resource_metadata_uri="([^"]+)"/,
|
||||||
|
);
|
||||||
|
return resourceMetadataMatch ? resourceMetadataMatch[1] : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract WWW-Authenticate header from error message string.
|
||||||
|
* This is a more robust approach than regex matching.
|
||||||
|
*
|
||||||
|
* @param errorString The error message string
|
||||||
|
* @returns The www-authenticate header value if found, null otherwise
|
||||||
|
*/
|
||||||
|
function extractWWWAuthenticateHeader(errorString: string): string | null {
|
||||||
|
// Try multiple patterns to extract the header
|
||||||
|
const patterns = [
|
||||||
|
/www-authenticate:\s*([^\n\r]+)/i,
|
||||||
|
/WWW-Authenticate:\s*([^\n\r]+)/i,
|
||||||
|
/"www-authenticate":\s*"([^"]+)"/i,
|
||||||
|
/'www-authenticate':\s*'([^']+)'/i,
|
||||||
|
];
|
||||||
|
|
||||||
|
for (const pattern of patterns) {
|
||||||
|
const match = errorString.match(pattern);
|
||||||
|
if (match) {
|
||||||
|
return match[1].trim();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle automatic OAuth discovery and authentication for a server.
|
||||||
|
*
|
||||||
|
* @param mcpServerName The name of the MCP server
|
||||||
|
* @param mcpServerConfig The MCP server configuration
|
||||||
|
* @param wwwAuthenticate The www-authenticate header value
|
||||||
|
* @returns True if OAuth was successfully configured and authenticated, false otherwise
|
||||||
|
*/
|
||||||
|
async function handleAutomaticOAuth(
|
||||||
|
mcpServerName: string,
|
||||||
|
mcpServerConfig: MCPServerConfig,
|
||||||
|
wwwAuthenticate: string,
|
||||||
|
): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
console.log(`🔐 '${mcpServerName}' requires OAuth authentication`);
|
||||||
|
|
||||||
|
// Always try to parse the resource metadata URI from the www-authenticate header
|
||||||
|
let oauthConfig;
|
||||||
|
const resourceMetadataUri =
|
||||||
|
OAuthUtils.parseWWWAuthenticateHeader(wwwAuthenticate);
|
||||||
|
if (resourceMetadataUri) {
|
||||||
|
oauthConfig = await OAuthUtils.discoverOAuthConfig(resourceMetadataUri);
|
||||||
|
} else if (mcpServerConfig.url) {
|
||||||
|
// Fallback: try to discover OAuth config from the base URL for SSE
|
||||||
|
const sseUrl = new URL(mcpServerConfig.url);
|
||||||
|
const baseUrl = `${sseUrl.protocol}//${sseUrl.host}`;
|
||||||
|
oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl);
|
||||||
|
} else if (mcpServerConfig.httpUrl) {
|
||||||
|
// Fallback: try to discover OAuth config from the base URL for HTTP
|
||||||
|
const httpUrl = new URL(mcpServerConfig.httpUrl);
|
||||||
|
const baseUrl = `${httpUrl.protocol}//${httpUrl.host}`;
|
||||||
|
oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!oauthConfig) {
|
||||||
|
console.error(
|
||||||
|
`❌ Could not configure OAuth for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||||
|
);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth configuration discovered - proceed with authentication
|
||||||
|
|
||||||
|
// Create OAuth configuration for authentication
|
||||||
|
const oauthAuthConfig = {
|
||||||
|
enabled: true,
|
||||||
|
authorizationUrl: oauthConfig.authorizationUrl,
|
||||||
|
tokenUrl: oauthConfig.tokenUrl,
|
||||||
|
scopes: oauthConfig.scopes || [],
|
||||||
|
};
|
||||||
|
|
||||||
|
// Perform OAuth authentication
|
||||||
|
console.log(
|
||||||
|
`Starting OAuth authentication for server '${mcpServerName}'...`,
|
||||||
|
);
|
||||||
|
await MCPOAuthProvider.authenticate(mcpServerName, oauthAuthConfig);
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
`OAuth authentication successful for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
console.error(
|
||||||
|
`Failed to handle automatic OAuth for server '${mcpServerName}': ${getErrorMessage(error)}`,
|
||||||
|
);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a transport with OAuth token for the given server configuration.
|
||||||
|
*
|
||||||
|
* @param mcpServerName The name of the MCP server
|
||||||
|
* @param mcpServerConfig The MCP server configuration
|
||||||
|
* @param accessToken The OAuth access token
|
||||||
|
* @returns The transport with OAuth token, or null if creation fails
|
||||||
|
*/
|
||||||
|
async function createTransportWithOAuth(
|
||||||
|
mcpServerName: string,
|
||||||
|
mcpServerConfig: MCPServerConfig,
|
||||||
|
accessToken: string,
|
||||||
|
): Promise<StreamableHTTPClientTransport | SSEClientTransport | null> {
|
||||||
|
try {
|
||||||
|
if (mcpServerConfig.httpUrl) {
|
||||||
|
// Create HTTP transport with OAuth token
|
||||||
|
const oauthTransportOptions: StreamableHTTPClientTransportOptions = {
|
||||||
|
requestInit: {
|
||||||
|
headers: {
|
||||||
|
...mcpServerConfig.headers,
|
||||||
|
Authorization: `Bearer ${accessToken}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
return new StreamableHTTPClientTransport(
|
||||||
|
new URL(mcpServerConfig.httpUrl),
|
||||||
|
oauthTransportOptions,
|
||||||
|
);
|
||||||
|
} else if (mcpServerConfig.url) {
|
||||||
|
// Create SSE transport with OAuth token in Authorization header
|
||||||
|
return new SSEClientTransport(new URL(mcpServerConfig.url), {
|
||||||
|
requestInit: {
|
||||||
|
headers: {
|
||||||
|
...mcpServerConfig.headers,
|
||||||
|
Authorization: `Bearer ${accessToken}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
} catch (error) {
|
||||||
|
console.error(
|
||||||
|
`Failed to create OAuth transport for server '${mcpServerName}': ${getErrorMessage(error)}`,
|
||||||
|
);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Discovers tools from all configured MCP servers and registers them with the tool registry.
|
* Discovers tools from all configured MCP servers and registers them with the tool registry.
|
||||||
* It orchestrates the connection and discovery process for each server defined in the
|
* It orchestrates the connection and discovery process for each server defined in the
|
||||||
|
@ -334,7 +500,7 @@ export async function connectToMcpServer(
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const transport = createTransport(
|
const transport = await createTransport(
|
||||||
mcpServerName,
|
mcpServerName,
|
||||||
mcpServerConfig,
|
mcpServerConfig,
|
||||||
debugMode,
|
debugMode,
|
||||||
|
@ -349,40 +515,396 @@ export async function connectToMcpServer(
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// Create a safe config object that excludes sensitive information
|
// Check if this is a 401 error that might indicate OAuth is required
|
||||||
const safeConfig = {
|
const errorString = String(error);
|
||||||
command: mcpServerConfig.command,
|
if (
|
||||||
url: mcpServerConfig.url,
|
errorString.includes('401') &&
|
||||||
httpUrl: mcpServerConfig.httpUrl,
|
(mcpServerConfig.httpUrl || mcpServerConfig.url)
|
||||||
cwd: mcpServerConfig.cwd,
|
) {
|
||||||
timeout: mcpServerConfig.timeout,
|
mcpServerRequiresOAuth.set(mcpServerName, true);
|
||||||
trust: mcpServerConfig.trust,
|
// Only trigger automatic OAuth discovery for HTTP servers or when OAuth is explicitly configured
|
||||||
// Exclude args, env, and headers which may contain sensitive data
|
// For SSE servers, we should not trigger new OAuth flows automatically
|
||||||
|
const shouldTriggerOAuth =
|
||||||
|
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
|
||||||
|
|
||||||
|
if (!shouldTriggerOAuth) {
|
||||||
|
// For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately.
|
||||||
|
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||||
|
if (credentials) {
|
||||||
|
const hasStoredTokens = await MCPOAuthProvider.getValidToken(
|
||||||
|
mcpServerName,
|
||||||
|
{
|
||||||
|
// Pass client ID if available
|
||||||
|
clientId: credentials.clientId,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
if (hasStoredTokens) {
|
||||||
|
console.log(
|
||||||
|
`Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` +
|
||||||
|
`Please re-authenticate using: /mcp auth ${mcpServerName}`,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
console.log(
|
||||||
|
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||||
|
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new Error(
|
||||||
|
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||||
|
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to extract www-authenticate header from the error
|
||||||
|
let wwwAuthenticate = extractWWWAuthenticateHeader(errorString);
|
||||||
|
|
||||||
|
// If we didn't get the header from the error string, try to get it from the server
|
||||||
|
if (!wwwAuthenticate && mcpServerConfig.url) {
|
||||||
|
console.log(
|
||||||
|
`No www-authenticate header in error, trying to fetch it from server...`,
|
||||||
|
);
|
||||||
|
try {
|
||||||
|
const response = await fetch(mcpServerConfig.url, {
|
||||||
|
method: 'HEAD',
|
||||||
|
headers: {
|
||||||
|
Accept: 'text/event-stream',
|
||||||
|
},
|
||||||
|
signal: AbortSignal.timeout(5000),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (response.status === 401) {
|
||||||
|
wwwAuthenticate = response.headers.get('www-authenticate');
|
||||||
|
if (wwwAuthenticate) {
|
||||||
|
console.log(
|
||||||
|
`Found www-authenticate header from server: ${wwwAuthenticate}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (fetchError) {
|
||||||
|
console.debug(
|
||||||
|
`Failed to fetch www-authenticate header: ${getErrorMessage(fetchError)}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (wwwAuthenticate) {
|
||||||
|
console.log(
|
||||||
|
`Received 401 with www-authenticate header: ${wwwAuthenticate}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Try automatic OAuth discovery and authentication
|
||||||
|
const oauthSuccess = await handleAutomaticOAuth(
|
||||||
|
mcpServerName,
|
||||||
|
mcpServerConfig,
|
||||||
|
wwwAuthenticate,
|
||||||
|
);
|
||||||
|
if (oauthSuccess) {
|
||||||
|
// Retry connection with OAuth token
|
||||||
|
console.log(
|
||||||
|
`Retrying connection to '${mcpServerName}' with OAuth token...`,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Get the valid token - we need to create a proper OAuth config
|
||||||
|
// The token should already be available from the authentication process
|
||||||
|
const credentials =
|
||||||
|
await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||||
|
if (credentials) {
|
||||||
|
const accessToken = await MCPOAuthProvider.getValidToken(
|
||||||
|
mcpServerName,
|
||||||
|
{
|
||||||
|
// Pass client ID if available
|
||||||
|
clientId: credentials.clientId,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
if (accessToken) {
|
||||||
|
// Create transport with OAuth token
|
||||||
|
const oauthTransport = await createTransportWithOAuth(
|
||||||
|
mcpServerName,
|
||||||
|
mcpServerConfig,
|
||||||
|
accessToken,
|
||||||
|
);
|
||||||
|
if (oauthTransport) {
|
||||||
|
try {
|
||||||
|
await mcpClient.connect(oauthTransport, {
|
||||||
|
timeout:
|
||||||
|
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||||
|
});
|
||||||
|
// Connection successful with OAuth
|
||||||
|
return mcpClient;
|
||||||
|
} catch (retryError) {
|
||||||
|
console.error(
|
||||||
|
`Failed to connect with OAuth token: ${getErrorMessage(
|
||||||
|
retryError,
|
||||||
|
)}`,
|
||||||
|
);
|
||||||
|
throw retryError;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
throw new Error(
|
||||||
|
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
throw new Error(
|
||||||
|
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`Failed to get credentials for server '${mcpServerName}' after successful OAuth authentication`,
|
||||||
|
);
|
||||||
|
throw new Error(
|
||||||
|
`Failed to get credentials for server '${mcpServerName}' after successful OAuth authentication`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
throw new Error(
|
||||||
|
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No www-authenticate header found, but we got a 401
|
||||||
|
// Only try OAuth discovery for HTTP servers or when OAuth is explicitly configured
|
||||||
|
// For SSE servers, we should not trigger new OAuth flows automatically
|
||||||
|
const shouldTryDiscovery =
|
||||||
|
mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled;
|
||||||
|
|
||||||
|
if (!shouldTryDiscovery) {
|
||||||
|
const credentials =
|
||||||
|
await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||||
|
if (credentials) {
|
||||||
|
const hasStoredTokens = await MCPOAuthProvider.getValidToken(
|
||||||
|
mcpServerName,
|
||||||
|
{
|
||||||
|
// Pass client ID if available
|
||||||
|
clientId: credentials.clientId,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
if (hasStoredTokens) {
|
||||||
|
console.log(
|
||||||
|
`Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` +
|
||||||
|
`Please re-authenticate using: /mcp auth ${mcpServerName}`,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
console.log(
|
||||||
|
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||||
|
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
throw new Error(
|
||||||
|
`401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` +
|
||||||
|
`Please authenticate using: /mcp auth ${mcpServerName}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// For SSE servers, try to discover OAuth configuration from the base URL
|
||||||
|
console.log(`🔍 Attempting OAuth discovery for '${mcpServerName}'...`);
|
||||||
|
|
||||||
|
if (mcpServerConfig.url) {
|
||||||
|
const sseUrl = new URL(mcpServerConfig.url);
|
||||||
|
const baseUrl = `${sseUrl.protocol}//${sseUrl.host}`;
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Try to discover OAuth configuration from the base URL
|
||||||
|
const oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl);
|
||||||
|
if (oauthConfig) {
|
||||||
|
console.log(
|
||||||
|
`Discovered OAuth configuration from base URL for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create OAuth configuration for authentication
|
||||||
|
const oauthAuthConfig = {
|
||||||
|
enabled: true,
|
||||||
|
authorizationUrl: oauthConfig.authorizationUrl,
|
||||||
|
tokenUrl: oauthConfig.tokenUrl,
|
||||||
|
scopes: oauthConfig.scopes || [],
|
||||||
};
|
};
|
||||||
|
|
||||||
let errorString =
|
// Perform OAuth authentication
|
||||||
`failed to start or connect to MCP server '${mcpServerName}' ` +
|
console.log(
|
||||||
`${JSON.stringify(safeConfig)}; \n${error}`;
|
`Starting OAuth authentication for server '${mcpServerName}'...`,
|
||||||
if (process.env.SANDBOX) {
|
);
|
||||||
errorString += `\nMake sure it is available in the sandbox`;
|
await MCPOAuthProvider.authenticate(
|
||||||
|
mcpServerName,
|
||||||
|
oauthAuthConfig,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Retry connection with OAuth token
|
||||||
|
const credentials =
|
||||||
|
await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||||
|
if (credentials) {
|
||||||
|
const accessToken = await MCPOAuthProvider.getValidToken(
|
||||||
|
mcpServerName,
|
||||||
|
{
|
||||||
|
// Pass client ID if available
|
||||||
|
clientId: credentials.clientId,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
if (accessToken) {
|
||||||
|
// Create transport with OAuth token
|
||||||
|
const oauthTransport = await createTransportWithOAuth(
|
||||||
|
mcpServerName,
|
||||||
|
mcpServerConfig,
|
||||||
|
accessToken,
|
||||||
|
);
|
||||||
|
if (oauthTransport) {
|
||||||
|
try {
|
||||||
|
await mcpClient.connect(oauthTransport, {
|
||||||
|
timeout:
|
||||||
|
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||||
|
});
|
||||||
|
// Connection successful with OAuth
|
||||||
|
return mcpClient;
|
||||||
|
} catch (retryError) {
|
||||||
|
console.error(
|
||||||
|
`Failed to connect with OAuth token: ${getErrorMessage(
|
||||||
|
retryError,
|
||||||
|
)}`,
|
||||||
|
);
|
||||||
|
throw retryError;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
throw new Error(
|
||||||
|
`Failed to create OAuth transport for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
throw new Error(
|
||||||
|
`Failed to get OAuth token for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`Failed to get stored credentials for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
throw new Error(
|
||||||
|
`Failed to get stored credentials for server '${mcpServerName}'`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`❌ Could not configure OAuth for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||||
|
);
|
||||||
|
throw new Error(
|
||||||
|
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} catch (discoveryError) {
|
||||||
|
console.error(
|
||||||
|
`❌ OAuth discovery failed for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||||
|
);
|
||||||
|
throw discoveryError;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.error(
|
||||||
|
`❌ '${mcpServerName}' requires authentication but no OAuth configuration found`,
|
||||||
|
);
|
||||||
|
throw new Error(
|
||||||
|
`MCP server '${mcpServerName}' requires authentication. Please configure OAuth or check server settings.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Handle other connection errors
|
||||||
|
// Create a concise error message
|
||||||
|
const errorMessage = (error as Error).message || String(error);
|
||||||
|
const isNetworkError =
|
||||||
|
errorMessage.includes('ENOTFOUND') ||
|
||||||
|
errorMessage.includes('ECONNREFUSED');
|
||||||
|
|
||||||
|
let conciseError: string;
|
||||||
|
if (isNetworkError) {
|
||||||
|
conciseError = `Cannot connect to '${mcpServerName}' - server may be down or URL incorrect`;
|
||||||
|
} else {
|
||||||
|
conciseError = `Connection failed for '${mcpServerName}': ${errorMessage}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (process.env.SANDBOX) {
|
||||||
|
conciseError += ` (check sandbox availability)`;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error(conciseError);
|
||||||
}
|
}
|
||||||
throw new Error(errorString);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Visible for Testing */
|
/** Visible for Testing */
|
||||||
export function createTransport(
|
export async function createTransport(
|
||||||
mcpServerName: string,
|
mcpServerName: string,
|
||||||
mcpServerConfig: MCPServerConfig,
|
mcpServerConfig: MCPServerConfig,
|
||||||
debugMode: boolean,
|
debugMode: boolean,
|
||||||
): Transport {
|
): Promise<Transport> {
|
||||||
|
// Check if we have OAuth configuration or stored tokens
|
||||||
|
let accessToken: string | null = null;
|
||||||
|
let hasOAuthConfig = mcpServerConfig.oauth?.enabled;
|
||||||
|
|
||||||
|
if (hasOAuthConfig && mcpServerConfig.oauth) {
|
||||||
|
accessToken = await MCPOAuthProvider.getValidToken(
|
||||||
|
mcpServerName,
|
||||||
|
mcpServerConfig.oauth,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!accessToken) {
|
||||||
|
console.error(
|
||||||
|
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
|
||||||
|
`Please authenticate using the /mcp auth command.`,
|
||||||
|
);
|
||||||
|
throw new Error(
|
||||||
|
`MCP server '${mcpServerName}' requires OAuth authentication. ` +
|
||||||
|
`Please authenticate using the /mcp auth command.`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Check if we have stored OAuth tokens for this server (from previous authentication)
|
||||||
|
const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName);
|
||||||
|
if (credentials) {
|
||||||
|
accessToken = await MCPOAuthProvider.getValidToken(mcpServerName, {
|
||||||
|
// Pass client ID if available
|
||||||
|
clientId: credentials.clientId,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (accessToken) {
|
||||||
|
hasOAuthConfig = true;
|
||||||
|
console.log(`Found stored OAuth token for server '${mcpServerName}'`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (mcpServerConfig.httpUrl) {
|
if (mcpServerConfig.httpUrl) {
|
||||||
const transportOptions: StreamableHTTPClientTransportOptions = {};
|
const transportOptions: StreamableHTTPClientTransportOptions = {};
|
||||||
if (mcpServerConfig.headers) {
|
|
||||||
|
// Set up headers with OAuth token if available
|
||||||
|
if (hasOAuthConfig && accessToken) {
|
||||||
|
transportOptions.requestInit = {
|
||||||
|
headers: {
|
||||||
|
...mcpServerConfig.headers,
|
||||||
|
Authorization: `Bearer ${accessToken}`,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} else if (mcpServerConfig.headers) {
|
||||||
transportOptions.requestInit = {
|
transportOptions.requestInit = {
|
||||||
headers: mcpServerConfig.headers,
|
headers: mcpServerConfig.headers,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return new StreamableHTTPClientTransport(
|
return new StreamableHTTPClientTransport(
|
||||||
new URL(mcpServerConfig.httpUrl),
|
new URL(mcpServerConfig.httpUrl),
|
||||||
transportOptions,
|
transportOptions,
|
||||||
|
@ -391,11 +913,21 @@ export function createTransport(
|
||||||
|
|
||||||
if (mcpServerConfig.url) {
|
if (mcpServerConfig.url) {
|
||||||
const transportOptions: SSEClientTransportOptions = {};
|
const transportOptions: SSEClientTransportOptions = {};
|
||||||
if (mcpServerConfig.headers) {
|
|
||||||
|
// Set up headers with OAuth token if available
|
||||||
|
if (hasOAuthConfig && accessToken) {
|
||||||
|
transportOptions.requestInit = {
|
||||||
|
headers: {
|
||||||
|
...mcpServerConfig.headers,
|
||||||
|
Authorization: `Bearer ${accessToken}`,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} else if (mcpServerConfig.headers) {
|
||||||
transportOptions.requestInit = {
|
transportOptions.requestInit = {
|
||||||
headers: mcpServerConfig.headers,
|
headers: mcpServerConfig.headers,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return new SSEClientTransport(
|
return new SSEClientTransport(
|
||||||
new URL(mcpServerConfig.url),
|
new URL(mcpServerConfig.url),
|
||||||
transportOptions,
|
transportOptions,
|
||||||
|
|
|
@ -325,5 +325,77 @@ describe('DiscoveredMCPTool', () => {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should handle Cancel confirmation outcome', async () => {
|
||||||
|
const tool = new DiscoveredMCPTool(
|
||||||
|
mockCallableToolInstance,
|
||||||
|
serverName,
|
||||||
|
serverToolName,
|
||||||
|
baseDescription,
|
||||||
|
inputSchema,
|
||||||
|
);
|
||||||
|
const confirmation = await tool.shouldConfirmExecute(
|
||||||
|
{},
|
||||||
|
new AbortController().signal,
|
||||||
|
);
|
||||||
|
expect(confirmation).not.toBe(false);
|
||||||
|
if (
|
||||||
|
confirmation &&
|
||||||
|
typeof confirmation === 'object' &&
|
||||||
|
'onConfirm' in confirmation &&
|
||||||
|
typeof confirmation.onConfirm === 'function'
|
||||||
|
) {
|
||||||
|
// Cancel should not add anything to allowlist
|
||||||
|
await confirmation.onConfirm(ToolConfirmationOutcome.Cancel);
|
||||||
|
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
(DiscoveredMCPTool as any).allowlist.has(
|
||||||
|
`${serverName}.${serverToolName}`,
|
||||||
|
),
|
||||||
|
).toBe(false);
|
||||||
|
} else {
|
||||||
|
throw new Error(
|
||||||
|
'Confirmation details or onConfirm not in expected format',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle ProceedOnce confirmation outcome', async () => {
|
||||||
|
const tool = new DiscoveredMCPTool(
|
||||||
|
mockCallableToolInstance,
|
||||||
|
serverName,
|
||||||
|
serverToolName,
|
||||||
|
baseDescription,
|
||||||
|
inputSchema,
|
||||||
|
);
|
||||||
|
const confirmation = await tool.shouldConfirmExecute(
|
||||||
|
{},
|
||||||
|
new AbortController().signal,
|
||||||
|
);
|
||||||
|
expect(confirmation).not.toBe(false);
|
||||||
|
if (
|
||||||
|
confirmation &&
|
||||||
|
typeof confirmation === 'object' &&
|
||||||
|
'onConfirm' in confirmation &&
|
||||||
|
typeof confirmation.onConfirm === 'function'
|
||||||
|
) {
|
||||||
|
// ProceedOnce should not add anything to allowlist
|
||||||
|
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce);
|
||||||
|
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
expect(
|
||||||
|
(DiscoveredMCPTool as any).allowlist.has(
|
||||||
|
`${serverName}.${serverToolName}`,
|
||||||
|
),
|
||||||
|
).toBe(false);
|
||||||
|
} else {
|
||||||
|
throw new Error(
|
||||||
|
'Confirmation details or onConfirm not in expected format',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -173,6 +173,30 @@ export class ToolRegistry {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Discover or re-discover tools for a single MCP server.
|
||||||
|
* @param serverName - The name of the server to discover tools from.
|
||||||
|
*/
|
||||||
|
async discoverToolsForServer(serverName: string): Promise<void> {
|
||||||
|
// Remove any previously discovered tools from this server
|
||||||
|
for (const [name, tool] of this.tools.entries()) {
|
||||||
|
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
|
||||||
|
this.tools.delete(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const mcpServers = this.config.getMcpServers() ?? {};
|
||||||
|
const serverConfig = mcpServers[serverName];
|
||||||
|
if (serverConfig) {
|
||||||
|
await discoverMcpTools(
|
||||||
|
{ [serverName]: serverConfig },
|
||||||
|
undefined,
|
||||||
|
this,
|
||||||
|
this.config.getDebugMode(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private async discoverAndRegisterToolsFromCommand(): Promise<void> {
|
private async discoverAndRegisterToolsFromCommand(): Promise<void> {
|
||||||
const discoveryCmd = this.config.getToolDiscoveryCommand();
|
const discoveryCmd = this.config.getToolDiscoveryCommand();
|
||||||
if (!discoveryCmd) {
|
if (!discoveryCmd) {
|
||||||
|
@ -386,6 +410,19 @@ function _sanitizeParameters(schema: Schema | undefined, visited: Set<Schema>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle enum values - Gemini API only allows enum for STRING type
|
||||||
|
if (schema.enum && Array.isArray(schema.enum)) {
|
||||||
|
if (schema.type !== Type.STRING) {
|
||||||
|
// If enum is present but type is not STRING, convert type to STRING
|
||||||
|
schema.type = Type.STRING;
|
||||||
|
}
|
||||||
|
// Filter out null and undefined values, then convert remaining values to strings for Gemini API compatibility
|
||||||
|
schema.enum = schema.enum
|
||||||
|
.filter((value: unknown) => value !== null && value !== undefined)
|
||||||
|
.map((value: unknown) => String(value));
|
||||||
|
}
|
||||||
|
|
||||||
// Vertex AI only supports 'enum' and 'date-time' for STRING format.
|
// Vertex AI only supports 'enum' and 'date-time' for STRING format.
|
||||||
if (schema.type === Type.STRING) {
|
if (schema.type === Type.STRING) {
|
||||||
if (
|
if (
|
||||||
|
|
Loading…
Reference in New Issue