From 258c8489092c0970db0f693ddf0956e17051316c Mon Sep 17 00:00:00 2001 From: Brian Ray <62354532+emeryray2002@users.noreply.github.com> Date: Tue, 22 Jul 2025 09:34:56 -0400 Subject: [PATCH] MCP OAuth Part 2 - MCP Client Integration (#4318) Co-authored-by: Greg Shikhman --- packages/core/src/config/config.ts | 6 + packages/core/src/core/client.ts | 7 + packages/core/src/core/geminiChat.ts | 5 + packages/core/src/tools/mcp-client.test.ts | 14 +- packages/core/src/tools/mcp-client.ts | 576 ++++++++++++++++++++- packages/core/src/tools/mcp-tool.test.ts | 72 +++ packages/core/src/tools/tool-registry.ts | 37 ++ 7 files changed, 689 insertions(+), 28 deletions(-) diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 3f406f85..0e3171bf 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -45,6 +45,10 @@ import { } from './models.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; import { shouldAttemptBrowserLaunch } from '../utils/browser.js'; +import { MCPOAuthConfig } from '../mcp/oauth-provider.js'; + +// Re-export OAuth config type +export type { MCPOAuthConfig }; export enum ApprovalMode { DEFAULT = 'default', @@ -112,6 +116,8 @@ export class MCPServerConfig { readonly includeTools?: string[], readonly excludeTools?: string[], readonly extensionName?: string, + // OAuth configuration + readonly oauth?: MCPOAuthConfig, ) {} } diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 9a3acae3..340b3dae 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -158,6 +158,13 @@ export class GeminiClient { this.getChat().setHistory(history); } + async setTools(): Promise { + const toolRegistry = await this.config.getToolRegistry(); + const toolDeclarations = toolRegistry.getFunctionDeclarations(); + const tools: Tool[] = [{ functionDeclarations: toolDeclarations }]; + this.getChat().setTools(tools); + } + async resetChat(): Promise { this.chat = await this.startChat(); } diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index e963b781..d1a7bdec 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -15,6 +15,7 @@ import { createUserContent, Part, GenerateContentResponseUsageMetadata, + Tool, } from '@google/genai'; import { retryWithBackoff } from '../utils/retry.js'; import { isFunctionResponse } from '../utils/messageInspectors.js'; @@ -498,6 +499,10 @@ export class GeminiChat { this.history = history; } + setTools(tools: Tool[]): void { + this.generationConfig.tools = tools; + } + getFinalUsageMetadata( chunks: GenerateContentResponse[], ): GenerateContentResponseUsageMetadata | undefined { diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 09614442..fbd2a2d4 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -20,6 +20,8 @@ import * as GenAiLib from '@google/genai'; vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js'); vi.mock('@google/genai'); +vi.mock('../mcp/oauth-provider.js'); +vi.mock('../mcp/oauth-token-storage.js'); describe('mcp-client', () => { afterEach(() => { @@ -82,7 +84,7 @@ describe('mcp-client', () => { describe('should connect via httpUrl', () => { it('without headers', async () => { - const transport = createTransport( + const transport = await createTransport( 'test-server', { httpUrl: 'http://test-server', @@ -96,7 +98,7 @@ describe('mcp-client', () => { }); it('with headers', async () => { - const transport = createTransport( + const transport = await createTransport( 'test-server', { httpUrl: 'http://test-server', @@ -117,7 +119,7 @@ describe('mcp-client', () => { describe('should connect via url', () => { it('without headers', async () => { - const transport = createTransport( + const transport = await createTransport( 'test-server', { url: 'http://test-server', @@ -130,7 +132,7 @@ describe('mcp-client', () => { }); it('with headers', async () => { - const transport = createTransport( + const transport = await createTransport( 'test-server', { url: 'http://test-server', @@ -149,10 +151,10 @@ describe('mcp-client', () => { }); }); - it('should connect via command', () => { + it('should connect via command', async () => { const mockedTransport = vi.mocked(SdkClientStdioLib.StdioClientTransport); - createTransport( + await createTransport( 'test-server', { command: 'test-command', diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index b1786af0..457259e5 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -18,9 +18,11 @@ import { import { parse } from 'shell-quote'; import { MCPServerConfig } from '../config/config.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; - import { FunctionDeclaration, mcpToTool } from '@google/genai'; import { ToolRegistry } from './tool-registry.js'; +import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; +import { OAuthUtils } from '../mcp/oauth-utils.js'; +import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; import { OpenFilesNotificationSchema, IDE_SERVER_NAME, @@ -64,6 +66,11 @@ const mcpServerStatusesInternal: Map = new Map(); */ let mcpDiscoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED; +/** + * Map to track which MCP servers have been discovered to require OAuth + */ +export const mcpServerRequiresOAuth: Map = new Map(); + /** * Event listeners for MCP server status changes */ @@ -131,6 +138,165 @@ export function getMCPDiscoveryState(): MCPDiscoveryState { return mcpDiscoveryState; } +/** + * Parse www-authenticate header to extract OAuth metadata URI. + * + * @param wwwAuthenticate The www-authenticate header value + * @returns The resource metadata URI if found, null otherwise + */ +function _parseWWWAuthenticate(wwwAuthenticate: string): string | null { + // Parse header like: Bearer realm="MCP Server", resource_metadata_uri="https://..." + const resourceMetadataMatch = wwwAuthenticate.match( + /resource_metadata_uri="([^"]+)"/, + ); + return resourceMetadataMatch ? resourceMetadataMatch[1] : null; +} + +/** + * Extract WWW-Authenticate header from error message string. + * This is a more robust approach than regex matching. + * + * @param errorString The error message string + * @returns The www-authenticate header value if found, null otherwise + */ +function extractWWWAuthenticateHeader(errorString: string): string | null { + // Try multiple patterns to extract the header + const patterns = [ + /www-authenticate:\s*([^\n\r]+)/i, + /WWW-Authenticate:\s*([^\n\r]+)/i, + /"www-authenticate":\s*"([^"]+)"/i, + /'www-authenticate':\s*'([^']+)'/i, + ]; + + for (const pattern of patterns) { + const match = errorString.match(pattern); + if (match) { + return match[1].trim(); + } + } + + return null; +} + +/** + * Handle automatic OAuth discovery and authentication for a server. + * + * @param mcpServerName The name of the MCP server + * @param mcpServerConfig The MCP server configuration + * @param wwwAuthenticate The www-authenticate header value + * @returns True if OAuth was successfully configured and authenticated, false otherwise + */ +async function handleAutomaticOAuth( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, + wwwAuthenticate: string, +): Promise { + try { + console.log(`🔐 '${mcpServerName}' requires OAuth authentication`); + + // Always try to parse the resource metadata URI from the www-authenticate header + let oauthConfig; + const resourceMetadataUri = + OAuthUtils.parseWWWAuthenticateHeader(wwwAuthenticate); + if (resourceMetadataUri) { + oauthConfig = await OAuthUtils.discoverOAuthConfig(resourceMetadataUri); + } else if (mcpServerConfig.url) { + // Fallback: try to discover OAuth config from the base URL for SSE + const sseUrl = new URL(mcpServerConfig.url); + const baseUrl = `${sseUrl.protocol}//${sseUrl.host}`; + oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl); + } else if (mcpServerConfig.httpUrl) { + // Fallback: try to discover OAuth config from the base URL for HTTP + const httpUrl = new URL(mcpServerConfig.httpUrl); + const baseUrl = `${httpUrl.protocol}//${httpUrl.host}`; + oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl); + } + + if (!oauthConfig) { + console.error( + `❌ Could not configure OAuth for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`, + ); + return false; + } + + // OAuth configuration discovered - proceed with authentication + + // Create OAuth configuration for authentication + const oauthAuthConfig = { + enabled: true, + authorizationUrl: oauthConfig.authorizationUrl, + tokenUrl: oauthConfig.tokenUrl, + scopes: oauthConfig.scopes || [], + }; + + // Perform OAuth authentication + console.log( + `Starting OAuth authentication for server '${mcpServerName}'...`, + ); + await MCPOAuthProvider.authenticate(mcpServerName, oauthAuthConfig); + + console.log( + `OAuth authentication successful for server '${mcpServerName}'`, + ); + return true; + } catch (error) { + console.error( + `Failed to handle automatic OAuth for server '${mcpServerName}': ${getErrorMessage(error)}`, + ); + return false; + } +} + +/** + * Create a transport with OAuth token for the given server configuration. + * + * @param mcpServerName The name of the MCP server + * @param mcpServerConfig The MCP server configuration + * @param accessToken The OAuth access token + * @returns The transport with OAuth token, or null if creation fails + */ +async function createTransportWithOAuth( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, + accessToken: string, +): Promise { + try { + if (mcpServerConfig.httpUrl) { + // Create HTTP transport with OAuth token + const oauthTransportOptions: StreamableHTTPClientTransportOptions = { + requestInit: { + headers: { + ...mcpServerConfig.headers, + Authorization: `Bearer ${accessToken}`, + }, + }, + }; + + return new StreamableHTTPClientTransport( + new URL(mcpServerConfig.httpUrl), + oauthTransportOptions, + ); + } else if (mcpServerConfig.url) { + // Create SSE transport with OAuth token in Authorization header + return new SSEClientTransport(new URL(mcpServerConfig.url), { + requestInit: { + headers: { + ...mcpServerConfig.headers, + Authorization: `Bearer ${accessToken}`, + }, + }, + }); + } + + return null; + } catch (error) { + console.error( + `Failed to create OAuth transport for server '${mcpServerName}': ${getErrorMessage(error)}`, + ); + return null; + } +} + /** * Discovers tools from all configured MCP servers and registers them with the tool registry. * It orchestrates the connection and discovery process for each server defined in the @@ -334,7 +500,7 @@ export async function connectToMcpServer( } try { - const transport = createTransport( + const transport = await createTransport( mcpServerName, mcpServerConfig, debugMode, @@ -349,40 +515,396 @@ export async function connectToMcpServer( throw error; } } catch (error) { - // Create a safe config object that excludes sensitive information - const safeConfig = { - command: mcpServerConfig.command, - url: mcpServerConfig.url, - httpUrl: mcpServerConfig.httpUrl, - cwd: mcpServerConfig.cwd, - timeout: mcpServerConfig.timeout, - trust: mcpServerConfig.trust, - // Exclude args, env, and headers which may contain sensitive data - }; + // Check if this is a 401 error that might indicate OAuth is required + const errorString = String(error); + if ( + errorString.includes('401') && + (mcpServerConfig.httpUrl || mcpServerConfig.url) + ) { + mcpServerRequiresOAuth.set(mcpServerName, true); + // Only trigger automatic OAuth discovery for HTTP servers or when OAuth is explicitly configured + // For SSE servers, we should not trigger new OAuth flows automatically + const shouldTriggerOAuth = + mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled; - let errorString = - `failed to start or connect to MCP server '${mcpServerName}' ` + - `${JSON.stringify(safeConfig)}; \n${error}`; - if (process.env.SANDBOX) { - errorString += `\nMake sure it is available in the sandbox`; + if (!shouldTriggerOAuth) { + // For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately. + const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName); + if (credentials) { + const hasStoredTokens = await MCPOAuthProvider.getValidToken( + mcpServerName, + { + // Pass client ID if available + clientId: credentials.clientId, + }, + ); + if (hasStoredTokens) { + console.log( + `Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` + + `Please re-authenticate using: /mcp auth ${mcpServerName}`, + ); + } else { + console.log( + `401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` + + `Please authenticate using: /mcp auth ${mcpServerName}`, + ); + } + } + throw new Error( + `401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` + + `Please authenticate using: /mcp auth ${mcpServerName}`, + ); + } + + // Try to extract www-authenticate header from the error + let wwwAuthenticate = extractWWWAuthenticateHeader(errorString); + + // If we didn't get the header from the error string, try to get it from the server + if (!wwwAuthenticate && mcpServerConfig.url) { + console.log( + `No www-authenticate header in error, trying to fetch it from server...`, + ); + try { + const response = await fetch(mcpServerConfig.url, { + method: 'HEAD', + headers: { + Accept: 'text/event-stream', + }, + signal: AbortSignal.timeout(5000), + }); + + if (response.status === 401) { + wwwAuthenticate = response.headers.get('www-authenticate'); + if (wwwAuthenticate) { + console.log( + `Found www-authenticate header from server: ${wwwAuthenticate}`, + ); + } + } + } catch (fetchError) { + console.debug( + `Failed to fetch www-authenticate header: ${getErrorMessage(fetchError)}`, + ); + } + } + + if (wwwAuthenticate) { + console.log( + `Received 401 with www-authenticate header: ${wwwAuthenticate}`, + ); + + // Try automatic OAuth discovery and authentication + const oauthSuccess = await handleAutomaticOAuth( + mcpServerName, + mcpServerConfig, + wwwAuthenticate, + ); + if (oauthSuccess) { + // Retry connection with OAuth token + console.log( + `Retrying connection to '${mcpServerName}' with OAuth token...`, + ); + + // Get the valid token - we need to create a proper OAuth config + // The token should already be available from the authentication process + const credentials = + await MCPOAuthTokenStorage.getToken(mcpServerName); + if (credentials) { + const accessToken = await MCPOAuthProvider.getValidToken( + mcpServerName, + { + // Pass client ID if available + clientId: credentials.clientId, + }, + ); + + if (accessToken) { + // Create transport with OAuth token + const oauthTransport = await createTransportWithOAuth( + mcpServerName, + mcpServerConfig, + accessToken, + ); + if (oauthTransport) { + try { + await mcpClient.connect(oauthTransport, { + timeout: + mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + }); + // Connection successful with OAuth + return mcpClient; + } catch (retryError) { + console.error( + `Failed to connect with OAuth token: ${getErrorMessage( + retryError, + )}`, + ); + throw retryError; + } + } else { + console.error( + `Failed to create OAuth transport for server '${mcpServerName}'`, + ); + throw new Error( + `Failed to create OAuth transport for server '${mcpServerName}'`, + ); + } + } else { + console.error( + `Failed to get OAuth token for server '${mcpServerName}'`, + ); + throw new Error( + `Failed to get OAuth token for server '${mcpServerName}'`, + ); + } + } else { + console.error( + `Failed to get credentials for server '${mcpServerName}' after successful OAuth authentication`, + ); + throw new Error( + `Failed to get credentials for server '${mcpServerName}' after successful OAuth authentication`, + ); + } + } else { + console.error( + `Failed to handle automatic OAuth for server '${mcpServerName}'`, + ); + throw new Error( + `Failed to handle automatic OAuth for server '${mcpServerName}'`, + ); + } + } else { + // No www-authenticate header found, but we got a 401 + // Only try OAuth discovery for HTTP servers or when OAuth is explicitly configured + // For SSE servers, we should not trigger new OAuth flows automatically + const shouldTryDiscovery = + mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled; + + if (!shouldTryDiscovery) { + const credentials = + await MCPOAuthTokenStorage.getToken(mcpServerName); + if (credentials) { + const hasStoredTokens = await MCPOAuthProvider.getValidToken( + mcpServerName, + { + // Pass client ID if available + clientId: credentials.clientId, + }, + ); + if (hasStoredTokens) { + console.log( + `Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` + + `Please re-authenticate using: /mcp auth ${mcpServerName}`, + ); + } else { + console.log( + `401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` + + `Please authenticate using: /mcp auth ${mcpServerName}`, + ); + } + } + throw new Error( + `401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` + + `Please authenticate using: /mcp auth ${mcpServerName}`, + ); + } + + // For SSE servers, try to discover OAuth configuration from the base URL + console.log(`🔍 Attempting OAuth discovery for '${mcpServerName}'...`); + + if (mcpServerConfig.url) { + const sseUrl = new URL(mcpServerConfig.url); + const baseUrl = `${sseUrl.protocol}//${sseUrl.host}`; + + try { + // Try to discover OAuth configuration from the base URL + const oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl); + if (oauthConfig) { + console.log( + `Discovered OAuth configuration from base URL for server '${mcpServerName}'`, + ); + + // Create OAuth configuration for authentication + const oauthAuthConfig = { + enabled: true, + authorizationUrl: oauthConfig.authorizationUrl, + tokenUrl: oauthConfig.tokenUrl, + scopes: oauthConfig.scopes || [], + }; + + // Perform OAuth authentication + console.log( + `Starting OAuth authentication for server '${mcpServerName}'...`, + ); + await MCPOAuthProvider.authenticate( + mcpServerName, + oauthAuthConfig, + ); + + // Retry connection with OAuth token + const credentials = + await MCPOAuthTokenStorage.getToken(mcpServerName); + if (credentials) { + const accessToken = await MCPOAuthProvider.getValidToken( + mcpServerName, + { + // Pass client ID if available + clientId: credentials.clientId, + }, + ); + if (accessToken) { + // Create transport with OAuth token + const oauthTransport = await createTransportWithOAuth( + mcpServerName, + mcpServerConfig, + accessToken, + ); + if (oauthTransport) { + try { + await mcpClient.connect(oauthTransport, { + timeout: + mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + }); + // Connection successful with OAuth + return mcpClient; + } catch (retryError) { + console.error( + `Failed to connect with OAuth token: ${getErrorMessage( + retryError, + )}`, + ); + throw retryError; + } + } else { + console.error( + `Failed to create OAuth transport for server '${mcpServerName}'`, + ); + throw new Error( + `Failed to create OAuth transport for server '${mcpServerName}'`, + ); + } + } else { + console.error( + `Failed to get OAuth token for server '${mcpServerName}'`, + ); + throw new Error( + `Failed to get OAuth token for server '${mcpServerName}'`, + ); + } + } else { + console.error( + `Failed to get stored credentials for server '${mcpServerName}'`, + ); + throw new Error( + `Failed to get stored credentials for server '${mcpServerName}'`, + ); + } + } else { + console.error( + `❌ Could not configure OAuth for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`, + ); + throw new Error( + `OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`, + ); + } + } catch (discoveryError) { + console.error( + `❌ OAuth discovery failed for '${mcpServerName}' - please authenticate manually with /mcp auth ${mcpServerName}`, + ); + throw discoveryError; + } + } else { + console.error( + `❌ '${mcpServerName}' requires authentication but no OAuth configuration found`, + ); + throw new Error( + `MCP server '${mcpServerName}' requires authentication. Please configure OAuth or check server settings.`, + ); + } + } + } else { + // Handle other connection errors + // Create a concise error message + const errorMessage = (error as Error).message || String(error); + const isNetworkError = + errorMessage.includes('ENOTFOUND') || + errorMessage.includes('ECONNREFUSED'); + + let conciseError: string; + if (isNetworkError) { + conciseError = `Cannot connect to '${mcpServerName}' - server may be down or URL incorrect`; + } else { + conciseError = `Connection failed for '${mcpServerName}': ${errorMessage}`; + } + + if (process.env.SANDBOX) { + conciseError += ` (check sandbox availability)`; + } + + throw new Error(conciseError); } - throw new Error(errorString); } } /** Visible for Testing */ -export function createTransport( +export async function createTransport( mcpServerName: string, mcpServerConfig: MCPServerConfig, debugMode: boolean, -): Transport { +): Promise { + // Check if we have OAuth configuration or stored tokens + let accessToken: string | null = null; + let hasOAuthConfig = mcpServerConfig.oauth?.enabled; + + if (hasOAuthConfig && mcpServerConfig.oauth) { + accessToken = await MCPOAuthProvider.getValidToken( + mcpServerName, + mcpServerConfig.oauth, + ); + + if (!accessToken) { + console.error( + `MCP server '${mcpServerName}' requires OAuth authentication. ` + + `Please authenticate using the /mcp auth command.`, + ); + throw new Error( + `MCP server '${mcpServerName}' requires OAuth authentication. ` + + `Please authenticate using the /mcp auth command.`, + ); + } + } else { + // Check if we have stored OAuth tokens for this server (from previous authentication) + const credentials = await MCPOAuthTokenStorage.getToken(mcpServerName); + if (credentials) { + accessToken = await MCPOAuthProvider.getValidToken(mcpServerName, { + // Pass client ID if available + clientId: credentials.clientId, + }); + + if (accessToken) { + hasOAuthConfig = true; + console.log(`Found stored OAuth token for server '${mcpServerName}'`); + } + } + } + if (mcpServerConfig.httpUrl) { const transportOptions: StreamableHTTPClientTransportOptions = {}; - if (mcpServerConfig.headers) { + + // Set up headers with OAuth token if available + if (hasOAuthConfig && accessToken) { + transportOptions.requestInit = { + headers: { + ...mcpServerConfig.headers, + Authorization: `Bearer ${accessToken}`, + }, + }; + } else if (mcpServerConfig.headers) { transportOptions.requestInit = { headers: mcpServerConfig.headers, }; } + return new StreamableHTTPClientTransport( new URL(mcpServerConfig.httpUrl), transportOptions, @@ -391,11 +913,21 @@ export function createTransport( if (mcpServerConfig.url) { const transportOptions: SSEClientTransportOptions = {}; - if (mcpServerConfig.headers) { + + // Set up headers with OAuth token if available + if (hasOAuthConfig && accessToken) { + transportOptions.requestInit = { + headers: { + ...mcpServerConfig.headers, + Authorization: `Bearer ${accessToken}`, + }, + }; + } else if (mcpServerConfig.headers) { transportOptions.requestInit = { headers: mcpServerConfig.headers, }; } + return new SSEClientTransport( new URL(mcpServerConfig.url), transportOptions, diff --git a/packages/core/src/tools/mcp-tool.test.ts b/packages/core/src/tools/mcp-tool.test.ts index 2e700710..b5843b95 100644 --- a/packages/core/src/tools/mcp-tool.test.ts +++ b/packages/core/src/tools/mcp-tool.test.ts @@ -325,5 +325,77 @@ describe('DiscoveredMCPTool', () => { ); } }); + + it('should handle Cancel confirmation outcome', async () => { + const tool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + ); + const confirmation = await tool.shouldConfirmExecute( + {}, + new AbortController().signal, + ); + expect(confirmation).not.toBe(false); + if ( + confirmation && + typeof confirmation === 'object' && + 'onConfirm' in confirmation && + typeof confirmation.onConfirm === 'function' + ) { + // Cancel should not add anything to allowlist + await confirmation.onConfirm(ToolConfirmationOutcome.Cancel); + expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe( + false, + ); + expect( + (DiscoveredMCPTool as any).allowlist.has( + `${serverName}.${serverToolName}`, + ), + ).toBe(false); + } else { + throw new Error( + 'Confirmation details or onConfirm not in expected format', + ); + } + }); + + it('should handle ProceedOnce confirmation outcome', async () => { + const tool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + ); + const confirmation = await tool.shouldConfirmExecute( + {}, + new AbortController().signal, + ); + expect(confirmation).not.toBe(false); + if ( + confirmation && + typeof confirmation === 'object' && + 'onConfirm' in confirmation && + typeof confirmation.onConfirm === 'function' + ) { + // ProceedOnce should not add anything to allowlist + await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce); + expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe( + false, + ); + expect( + (DiscoveredMCPTool as any).allowlist.has( + `${serverName}.${serverToolName}`, + ), + ).toBe(false); + } else { + throw new Error( + 'Confirmation details or onConfirm not in expected format', + ); + } + }); }); }); diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index d6e84de3..a6742c06 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -173,6 +173,30 @@ export class ToolRegistry { ); } + /** + * Discover or re-discover tools for a single MCP server. + * @param serverName - The name of the server to discover tools from. + */ + async discoverToolsForServer(serverName: string): Promise { + // Remove any previously discovered tools from this server + for (const [name, tool] of this.tools.entries()) { + if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) { + this.tools.delete(name); + } + } + + const mcpServers = this.config.getMcpServers() ?? {}; + const serverConfig = mcpServers[serverName]; + if (serverConfig) { + await discoverMcpTools( + { [serverName]: serverConfig }, + undefined, + this, + this.config.getDebugMode(), + ); + } + } + private async discoverAndRegisterToolsFromCommand(): Promise { const discoveryCmd = this.config.getToolDiscoveryCommand(); if (!discoveryCmd) { @@ -386,6 +410,19 @@ function _sanitizeParameters(schema: Schema | undefined, visited: Set) { } } } + + // Handle enum values - Gemini API only allows enum for STRING type + if (schema.enum && Array.isArray(schema.enum)) { + if (schema.type !== Type.STRING) { + // If enum is present but type is not STRING, convert type to STRING + schema.type = Type.STRING; + } + // Filter out null and undefined values, then convert remaining values to strings for Gemini API compatibility + schema.enum = schema.enum + .filter((value: unknown) => value !== null && value !== undefined) + .map((value: unknown) => String(value)); + } + // Vertex AI only supports 'enum' and 'date-time' for STRING format. if (schema.type === Type.STRING) { if (