Refactor MCP code for reuse and testing (#3880)

This commit is contained in:
Tommaso Sciortino 2025-07-14 11:19:33 -07:00 committed by GitHub
parent 9dc812dd4b
commit 5008aea90d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 613 additions and 1167 deletions

File diff suppressed because it is too large Load Diff

View File

@ -5,6 +5,7 @@
*/ */
import { Client } from '@modelcontextprotocol/sdk/client/index.js'; import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
import { import {
SSEClientTransport, SSEClientTransport,
@ -17,7 +18,7 @@ import {
import { parse } from 'shell-quote'; import { parse } from 'shell-quote';
import { MCPServerConfig } from '../config/config.js'; import { MCPServerConfig } from '../config/config.js';
import { DiscoveredMCPTool } from './mcp-tool.js'; import { DiscoveredMCPTool } from './mcp-tool.js';
import { Type, mcpToTool } from '@google/genai'; import { FunctionDeclaration, Type, mcpToTool } from '@google/genai';
import { sanitizeParameters, ToolRegistry } from './tool-registry.js'; import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
@ -123,16 +124,46 @@ export function getMCPDiscoveryState(): MCPDiscoveryState {
return mcpDiscoveryState; return mcpDiscoveryState;
} }
/**
* Discovers tools from all configured MCP servers and registers them with the tool registry.
* It orchestrates the connection and discovery process for each server defined in the
* configuration, as well as any server specified via a command-line argument.
*
* @param mcpServers A record of named MCP server configurations.
* @param mcpServerCommand An optional command string for a dynamically specified MCP server.
* @param toolRegistry The central registry where discovered tools will be registered.
* @returns A promise that resolves when the discovery process has been attempted for all servers.
*/
export async function discoverMcpTools( export async function discoverMcpTools(
mcpServers: Record<string, MCPServerConfig>, mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined, mcpServerCommand: string | undefined,
toolRegistry: ToolRegistry, toolRegistry: ToolRegistry,
debugMode: boolean, debugMode: boolean,
): Promise<void> { ): Promise<void> {
// Set discovery state to in progress
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS; mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
try { try {
mcpServers = populateMcpServerCommand(mcpServers, mcpServerCommand);
const discoveryPromises = Object.entries(mcpServers).map(
([mcpServerName, mcpServerConfig]) =>
connectAndDiscover(
mcpServerName,
mcpServerConfig,
toolRegistry,
debugMode,
),
);
await Promise.all(discoveryPromises);
} finally {
mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
}
}
/** Visible for Testing */
export function populateMcpServerCommand(
mcpServers: Record<string, MCPServerConfig>,
mcpServerCommand: string | undefined,
): Record<string, MCPServerConfig> {
if (mcpServerCommand) { if (mcpServerCommand) {
const cmd = mcpServerCommand; const cmd = mcpServerCommand;
const args = parse(cmd, process.env) as string[]; const args = parse(cmd, process.env) as string[];
@ -145,25 +176,7 @@ export async function discoverMcpTools(
args: args.slice(1), args: args.slice(1),
}; };
} }
return mcpServers;
const discoveryPromises = Object.entries(mcpServers).map(
([mcpServerName, mcpServerConfig]) =>
connectAndDiscover(
mcpServerName,
mcpServerConfig,
toolRegistry,
debugMode,
),
);
await Promise.all(discoveryPromises);
// Mark discovery as completed
mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
} catch (error) {
// Still mark as completed even with errors
mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
throw error;
}
} }
/** /**
@ -176,71 +189,117 @@ export async function discoverMcpTools(
* @param toolRegistry The registry to register discovered tools with * @param toolRegistry The registry to register discovered tools with
* @returns Promise that resolves when discovery is complete * @returns Promise that resolves when discovery is complete
*/ */
async function connectAndDiscover( export async function connectAndDiscover(
mcpServerName: string, mcpServerName: string,
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
toolRegistry: ToolRegistry, toolRegistry: ToolRegistry,
debugMode: boolean, debugMode: boolean,
): Promise<void> { ): Promise<void> {
// Initialize the server status as connecting
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
let transport; try {
if (mcpServerConfig.httpUrl) { const mcpClient = await connectToMcpServer(
const transportOptions: StreamableHTTPClientTransportOptions = {}; mcpServerName,
mcpServerConfig,
debugMode,
);
try {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
if (mcpServerConfig.headers) { mcpClient.onerror = (error) => {
transportOptions.requestInit = { console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
headers: mcpServerConfig.headers,
};
}
transport = new StreamableHTTPClientTransport(
new URL(mcpServerConfig.httpUrl),
transportOptions,
);
} else if (mcpServerConfig.url) {
const transportOptions: SSEClientTransportOptions = {};
if (mcpServerConfig.headers) {
transportOptions.requestInit = {
headers: mcpServerConfig.headers,
};
}
transport = new SSEClientTransport(
new URL(mcpServerConfig.url),
transportOptions,
);
} else if (mcpServerConfig.command) {
transport = new StdioClientTransport({
command: mcpServerConfig.command,
args: mcpServerConfig.args || [],
env: {
...process.env,
...(mcpServerConfig.env || {}),
} as Record<string, string>,
cwd: mcpServerConfig.cwd,
stderr: 'pipe',
});
} else {
console.error(
`MCP server '${mcpServerName}' has invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio). Skipping.`,
);
// Update status to disconnected
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
return; };
const tools = await discoverTools(
mcpServerName,
mcpServerConfig,
mcpClient,
);
for (const tool of tools) {
toolRegistry.registerTool(tool);
}
} catch (error) {
mcpClient.close();
throw error;
}
} catch (error) {
console.error(`Error connecting to MCP server '${mcpServerName}':`, error);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
}
} }
if ( /**
debugMode && * Discovers and sanitizes tools from a connected MCP client.
transport instanceof StdioClientTransport && * It retrieves function declarations from the client, filters out disabled tools,
transport.stderr * generates valid names for them, and wraps them in `DiscoveredMCPTool` instances.
) { *
transport.stderr.on('data', (data) => { * @param mcpServerName The name of the MCP server.
const stderrStr = data.toString().trim(); * @param mcpServerConfig The configuration for the MCP server.
console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr); * @param mcpClient The active MCP client instance.
}); * @returns A promise that resolves to an array of discovered and enabled tools.
* @throws An error if no enabled tools are found or if the server provides invalid function declarations.
*/
export async function discoverTools(
  mcpServerName: string,
  mcpServerConfig: MCPServerConfig,
  mcpClient: Client,
): Promise<DiscoveredMCPTool[]> {
  // Every failure below (bad server response, nothing enabled, helper errors)
  // is funnelled through the single catch so callers see one uniform message.
  try {
    const callableTool = mcpToTool(mcpClient);
    const toolDefinition = await callableTool.tool();
    const declarations = toolDefinition.functionDeclarations;

    if (!Array.isArray(declarations)) {
      throw new Error(`Server did not return valid function declarations.`);
    }

    // Single pass: each declaration is checked, named, sanitized and wrapped
    // before the next one is looked at; disabled declarations contribute nothing.
    const enabledTools = declarations.flatMap(
      (declaration): DiscoveredMCPTool[] => {
        if (!isEnabled(declaration, mcpServerName, mcpServerConfig)) {
          return [];
        }

        const modelFacingName = generateValidName(declaration, mcpServerName);

        // Sanitizes the declaration's parameter schema in place before it is
        // handed to the tool wrapper.
        sanitizeParameters(declaration.parameters);

        return [
          new DiscoveredMCPTool(
            callableTool,
            mcpServerName,
            modelFacingName,
            declaration.description ?? '',
            declaration.parameters ?? { type: Type.OBJECT, properties: {} },
            declaration.name!,
            mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
            mcpServerConfig.trust,
          ),
        ];
      },
    );

    // A server that yields no usable tools is reported as an error.
    if (enabledTools.length === 0) {
      throw Error('No enabled tools found');
    }

    return enabledTools;
  } catch (error) {
    throw new Error(`Error discovering tools: ${error}`);
  }
}
/**
* Creates and connects an MCP client to a server based on the provided configuration.
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
* establishes a connection. It also applies a patch to handle request timeouts.
*
* @param mcpServerName The name of the MCP server, used for logging and identification.
* @param mcpServerConfig The configuration specifying how to connect to the server.
* @returns A promise that resolves to a connected MCP `Client` instance.
* @throws An error if the connection fails or the configuration is invalid.
*/
export async function connectToMcpServer(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
debugMode: boolean,
): Promise<Client> {
const mcpClient = new Client({ const mcpClient = new Client({
name: 'gemini-cli-mcp-client', name: 'gemini-cli-mcp-client',
version: '0.0.1', version: '0.0.1',
@ -258,12 +317,21 @@ async function connectAndDiscover(
}; };
} }
try {
const transport = createTransport(
mcpServerName,
mcpServerConfig,
debugMode,
);
try { try {
await mcpClient.connect(transport, { await mcpClient.connect(transport, {
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
}); });
// Connection successful return mcpClient;
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED); } catch (error) {
await transport.close();
throw error;
}
} catch (error) { } catch (error) {
// Create a safe config object that excludes sensitive information // Create a safe config object that excludes sensitive information
const safeConfig = { const safeConfig = {
@ -282,131 +350,110 @@ async function connectAndDiscover(
if (process.env.SANDBOX) { if (process.env.SANDBOX) {
errorString += `\nMake sure it is available in the sandbox`; errorString += `\nMake sure it is available in the sandbox`;
} }
console.error(errorString); throw new Error(errorString);
// Update status to disconnected }
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
return;
} }
mcpClient.onerror = (error) => { /** Visible for Testing */
console.error(`MCP ERROR (${mcpServerName}):`, error.toString()); export function createTransport(
// Update status to disconnected on error mcpServerName: string,
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); mcpServerConfig: MCPServerConfig,
debugMode: boolean,
): Transport {
if (mcpServerConfig.httpUrl) {
const transportOptions: StreamableHTTPClientTransportOptions = {};
if (mcpServerConfig.headers) {
transportOptions.requestInit = {
headers: mcpServerConfig.headers,
}; };
}
try { return new StreamableHTTPClientTransport(
const mcpCallableTool = mcpToTool(mcpClient); new URL(mcpServerConfig.httpUrl),
const tool = await mcpCallableTool.tool(); transportOptions,
if (!tool || !Array.isArray(tool.functionDeclarations)) {
console.error(
`MCP server '${mcpServerName}' did not return valid tool function declarations. Skipping.`,
); );
if (
transport instanceof StdioClientTransport ||
transport instanceof SSEClientTransport ||
transport instanceof StreamableHTTPClientTransport
) {
await transport.close();
}
// Update status to disconnected
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
return;
} }
for (const funcDecl of tool.functionDeclarations) { if (mcpServerConfig.url) {
const transportOptions: SSEClientTransportOptions = {};
if (mcpServerConfig.headers) {
transportOptions.requestInit = {
headers: mcpServerConfig.headers,
};
}
return new SSEClientTransport(
new URL(mcpServerConfig.url),
transportOptions,
);
}
if (mcpServerConfig.command) {
const transport = new StdioClientTransport({
command: mcpServerConfig.command,
args: mcpServerConfig.args || [],
env: {
...process.env,
...(mcpServerConfig.env || {}),
} as Record<string, string>,
cwd: mcpServerConfig.cwd,
stderr: 'pipe',
});
if (debugMode) {
transport.stderr!.on('data', (data) => {
const stderrStr = data.toString().trim();
console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr);
});
}
return transport;
}
throw new Error(
`Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`,
);
}
/**
 * Builds the tool name that is exposed to the model for a function declaration
 * discovered on an MCP server.
 *
 * The result is `<server>__<tool>`, where every character outside
 * `[a-zA-Z0-9_.-]` is replaced by an underscore in BOTH parts. The server name
 * is a user-chosen configuration key, so it needs the same treatment as the
 * tool name or the combined name can still contain rejected characters.
 * Names longer than 63 characters are shortened by keeping the first 28 and the
 * last 32 characters around a `___` marker.
 *
 * Visible for testing.
 *
 * @param funcDecl The function declaration; its `name` must be set (callers are
 *   expected to have filtered out unnamed declarations beforehand).
 * @param mcpServerName The name of the MCP server the declaration came from.
 * @returns A name of at most 63 characters made of `[a-zA-Z0-9_.-]` only.
 */
export function generateValidName(
  funcDecl: FunctionDeclaration,
  mcpServerName: string,
): string {
  // Character class based on the 400 error message from the Gemini API.
  const invalidChars = /[^a-zA-Z0-9_.-]/g;

  const safeToolName = funcDecl.name!.replace(invalidChars, '_');
  const safeServerName = mcpServerName.replace(invalidChars, '_');

  // Prepend MCP server name to avoid conflicts with other tools
  let validToolname = safeServerName + '__' + safeToolName;

  // If longer than 63 characters, replace middle with '___'
  // (Gemini API says max length 64, but actual limit seems to be 63)
  if (validToolname.length > 63) {
    validToolname =
      validToolname.slice(0, 28) + '___' + validToolname.slice(-32);
  }
  return validToolname;
}
/** Visible for testing */
export function isEnabled(
funcDecl: FunctionDeclaration,
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
): boolean {
if (!funcDecl.name) { if (!funcDecl.name) {
console.warn( console.warn(
`Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`, `Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`,
); );
continue; return false;
} }
const { includeTools, excludeTools } = mcpServerConfig; const { includeTools, excludeTools } = mcpServerConfig;
const toolName = funcDecl.name;
let isEnabled = false; // excludeTools takes precedence over includeTools
if (includeTools === undefined) { if (excludeTools && excludeTools.includes(funcDecl.name)) {
isEnabled = true; return false;
} else { }
isEnabled = includeTools.some(
(tool) => tool === toolName || tool.startsWith(`${toolName}(`), return (
!includeTools ||
includeTools.some(
(tool) => tool === funcDecl.name || tool.startsWith(`${funcDecl.name}(`),
)
); );
} }
if (excludeTools?.includes(toolName)) {
isEnabled = false;
}
if (!isEnabled) {
continue;
}
let toolNameForModel = funcDecl.name;
// Replace invalid characters (based on 400 error message from Gemini API) with underscores
toolNameForModel = toolNameForModel.replace(/[^a-zA-Z0-9_.-]/g, '_');
const existingTool = toolRegistry.getTool(toolNameForModel);
if (existingTool) {
toolNameForModel = mcpServerName + '__' + toolNameForModel;
}
// If longer than 63 characters, replace middle with '___'
// (Gemini API says max length 64, but actual limit seems to be 63)
if (toolNameForModel.length > 63) {
toolNameForModel =
toolNameForModel.slice(0, 28) + '___' + toolNameForModel.slice(-32);
}
sanitizeParameters(funcDecl.parameters);
toolRegistry.registerTool(
new DiscoveredMCPTool(
mcpCallableTool,
mcpServerName,
toolNameForModel,
funcDecl.description ?? '',
funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
funcDecl.name,
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
mcpServerConfig.trust,
),
);
}
} catch (error) {
console.error(
`Failed to list or register tools for MCP server '${mcpServerName}': ${error}`,
);
// Ensure transport is cleaned up on error too
if (
transport instanceof StdioClientTransport ||
transport instanceof SSEClientTransport ||
transport instanceof StreamableHTTPClientTransport
) {
await transport.close();
}
// Update status to disconnected
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
}
// If no tools were registered from this MCP server, the following 'if' block
// will close the connection. This is done to conserve resources and prevent
// an orphaned connection to a server that isn't providing any usable
// functionality. Connections to servers that did provide tools are kept
// open, as those tools will require the connection to function.
if (toolRegistry.getToolsByServer(mcpServerName).length === 0) {
console.log(
`No tools registered from MCP server '${mcpServerName}'. Closing connection.`,
);
if (
transport instanceof StdioClientTransport ||
transport instanceof SSEClientTransport ||
transport instanceof StreamableHTTPClientTransport
) {
await transport.close();
// Update status to disconnected
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
}
}
}

View File

@ -326,6 +326,83 @@ describe('ToolRegistry', () => {
}); });
describe('sanitizeParameters', () => { describe('sanitizeParameters', () => {
it('should remove default when anyOf is present', () => {
const schema: Schema = {
anyOf: [{ type: Type.STRING }, { type: Type.NUMBER }],
default: 'hello',
};
sanitizeParameters(schema);
expect(schema.default).toBeUndefined();
});
it('should recursively sanitize items in anyOf', () => {
const schema: Schema = {
anyOf: [
{
anyOf: [{ type: Type.STRING }],
default: 'world',
},
{ type: Type.NUMBER },
],
};
sanitizeParameters(schema);
expect(schema.anyOf![0].default).toBeUndefined();
});
it('should recursively sanitize items in items', () => {
const schema: Schema = {
items: {
anyOf: [{ type: Type.STRING }],
default: 'world',
},
};
sanitizeParameters(schema);
expect(schema.items!.default).toBeUndefined();
});
it('should recursively sanitize items in properties', () => {
const schema: Schema = {
properties: {
prop1: {
anyOf: [{ type: Type.STRING }],
default: 'world',
},
},
};
sanitizeParameters(schema);
expect(schema.properties!.prop1.default).toBeUndefined();
});
it('should handle complex nested schemas', () => {
const schema: Schema = {
properties: {
prop1: {
items: {
anyOf: [{ type: Type.STRING }],
default: 'world',
},
},
prop2: {
anyOf: [
{
properties: {
nestedProp: {
anyOf: [{ type: Type.NUMBER }],
default: 123,
},
},
},
],
},
},
};
sanitizeParameters(schema);
expect(schema.properties!.prop1.items!.default).toBeUndefined();
const nestedProp =
schema.properties!.prop2.anyOf![0].properties!.nestedProp;
expect(nestedProp?.default).toBeUndefined();
});
it('should remove unsupported format from a simple string property', () => { it('should remove unsupported format from a simple string property', () => {
const schema: Schema = { const schema: Schema = {
type: Type.OBJECT, type: Type.OBJECT,
@ -356,25 +433,6 @@ describe('sanitizeParameters', () => {
expect(schema).toEqual(originalSchema); expect(schema).toEqual(originalSchema);
}); });
it('should handle nested objects recursively', () => {
const schema: Schema = {
type: Type.OBJECT,
properties: {
user: {
type: Type.OBJECT,
properties: {
email: { type: Type.STRING, format: 'email' },
},
},
},
};
sanitizeParameters(schema);
expect(schema.properties?.['user']?.properties?.['email']).toHaveProperty(
'format',
undefined,
);
});
it('should handle arrays of objects', () => { it('should handle arrays of objects', () => {
const schema: Schema = { const schema: Schema = {
type: Type.OBJECT, type: Type.OBJECT,
@ -414,19 +472,6 @@ describe('sanitizeParameters', () => {
expect(() => sanitizeParameters(undefined)).not.toThrow(); expect(() => sanitizeParameters(undefined)).not.toThrow();
}); });
it('should handle cyclic schemas without crashing', () => {
const schema: any = {
type: Type.OBJECT,
properties: {
name: { type: Type.STRING, format: 'hostname' },
},
};
schema.properties.self = schema;
expect(() => sanitizeParameters(schema)).not.toThrow();
expect(schema.properties.name).toHaveProperty('format', undefined);
});
it('should handle complex nested schemas with cycles', () => { it('should handle complex nested schemas with cycles', () => {
const userNode: any = { const userNode: any = {
type: Type.OBJECT, type: Type.OBJECT,