Refactor MCP code for reuse and testing (#3880)
This commit is contained in:
parent
9dc812dd4b
commit
5008aea90d
File diff suppressed because it is too large
Load Diff
|
@ -5,6 +5,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||||
|
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||||
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||||
import {
|
import {
|
||||||
SSEClientTransport,
|
SSEClientTransport,
|
||||||
|
@ -17,7 +18,7 @@ import {
|
||||||
import { parse } from 'shell-quote';
|
import { parse } from 'shell-quote';
|
||||||
import { MCPServerConfig } from '../config/config.js';
|
import { MCPServerConfig } from '../config/config.js';
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||||
import { Type, mcpToTool } from '@google/genai';
|
import { FunctionDeclaration, Type, mcpToTool } from '@google/genai';
|
||||||
import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
|
import { sanitizeParameters, ToolRegistry } from './tool-registry.js';
|
||||||
|
|
||||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||||
|
@ -123,16 +124,46 @@ export function getMCPDiscoveryState(): MCPDiscoveryState {
|
||||||
return mcpDiscoveryState;
|
return mcpDiscoveryState;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Discovers tools from all configured MCP servers and registers them with the tool registry.
|
||||||
|
* It orchestrates the connection and discovery process for each server defined in the
|
||||||
|
* configuration, as well as any server specified via a command-line argument.
|
||||||
|
*
|
||||||
|
* @param mcpServers A record of named MCP server configurations.
|
||||||
|
* @param mcpServerCommand An optional command string for a dynamically specified MCP server.
|
||||||
|
* @param toolRegistry The central registry where discovered tools will be registered.
|
||||||
|
* @returns A promise that resolves when the discovery process has been attempted for all servers.
|
||||||
|
*/
|
||||||
export async function discoverMcpTools(
|
export async function discoverMcpTools(
|
||||||
mcpServers: Record<string, MCPServerConfig>,
|
mcpServers: Record<string, MCPServerConfig>,
|
||||||
mcpServerCommand: string | undefined,
|
mcpServerCommand: string | undefined,
|
||||||
toolRegistry: ToolRegistry,
|
toolRegistry: ToolRegistry,
|
||||||
debugMode: boolean,
|
debugMode: boolean,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
// Set discovery state to in progress
|
|
||||||
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
mcpServers = populateMcpServerCommand(mcpServers, mcpServerCommand);
|
||||||
|
|
||||||
|
const discoveryPromises = Object.entries(mcpServers).map(
|
||||||
|
([mcpServerName, mcpServerConfig]) =>
|
||||||
|
connectAndDiscover(
|
||||||
|
mcpServerName,
|
||||||
|
mcpServerConfig,
|
||||||
|
toolRegistry,
|
||||||
|
debugMode,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
await Promise.all(discoveryPromises);
|
||||||
|
} finally {
|
||||||
|
mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Visible for Testing */
|
||||||
|
export function populateMcpServerCommand(
|
||||||
|
mcpServers: Record<string, MCPServerConfig>,
|
||||||
|
mcpServerCommand: string | undefined,
|
||||||
|
): Record<string, MCPServerConfig> {
|
||||||
if (mcpServerCommand) {
|
if (mcpServerCommand) {
|
||||||
const cmd = mcpServerCommand;
|
const cmd = mcpServerCommand;
|
||||||
const args = parse(cmd, process.env) as string[];
|
const args = parse(cmd, process.env) as string[];
|
||||||
|
@ -145,25 +176,7 @@ export async function discoverMcpTools(
|
||||||
args: args.slice(1),
|
args: args.slice(1),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
return mcpServers;
|
||||||
const discoveryPromises = Object.entries(mcpServers).map(
|
|
||||||
([mcpServerName, mcpServerConfig]) =>
|
|
||||||
connectAndDiscover(
|
|
||||||
mcpServerName,
|
|
||||||
mcpServerConfig,
|
|
||||||
toolRegistry,
|
|
||||||
debugMode,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
await Promise.all(discoveryPromises);
|
|
||||||
|
|
||||||
// Mark discovery as completed
|
|
||||||
mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
|
|
||||||
} catch (error) {
|
|
||||||
// Still mark as completed even with errors
|
|
||||||
mcpDiscoveryState = MCPDiscoveryState.COMPLETED;
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -176,71 +189,117 @@ export async function discoverMcpTools(
|
||||||
* @param toolRegistry The registry to register discovered tools with
|
* @param toolRegistry The registry to register discovered tools with
|
||||||
* @returns Promise that resolves when discovery is complete
|
* @returns Promise that resolves when discovery is complete
|
||||||
*/
|
*/
|
||||||
async function connectAndDiscover(
|
export async function connectAndDiscover(
|
||||||
mcpServerName: string,
|
mcpServerName: string,
|
||||||
mcpServerConfig: MCPServerConfig,
|
mcpServerConfig: MCPServerConfig,
|
||||||
toolRegistry: ToolRegistry,
|
toolRegistry: ToolRegistry,
|
||||||
debugMode: boolean,
|
debugMode: boolean,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
// Initialize the server status as connecting
|
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||||
|
|
||||||
let transport;
|
try {
|
||||||
if (mcpServerConfig.httpUrl) {
|
const mcpClient = await connectToMcpServer(
|
||||||
const transportOptions: StreamableHTTPClientTransportOptions = {};
|
mcpServerName,
|
||||||
|
mcpServerConfig,
|
||||||
|
debugMode,
|
||||||
|
);
|
||||||
|
try {
|
||||||
|
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
|
||||||
|
|
||||||
if (mcpServerConfig.headers) {
|
mcpClient.onerror = (error) => {
|
||||||
transportOptions.requestInit = {
|
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
|
||||||
headers: mcpServerConfig.headers,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
transport = new StreamableHTTPClientTransport(
|
|
||||||
new URL(mcpServerConfig.httpUrl),
|
|
||||||
transportOptions,
|
|
||||||
);
|
|
||||||
} else if (mcpServerConfig.url) {
|
|
||||||
const transportOptions: SSEClientTransportOptions = {};
|
|
||||||
if (mcpServerConfig.headers) {
|
|
||||||
transportOptions.requestInit = {
|
|
||||||
headers: mcpServerConfig.headers,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
transport = new SSEClientTransport(
|
|
||||||
new URL(mcpServerConfig.url),
|
|
||||||
transportOptions,
|
|
||||||
);
|
|
||||||
} else if (mcpServerConfig.command) {
|
|
||||||
transport = new StdioClientTransport({
|
|
||||||
command: mcpServerConfig.command,
|
|
||||||
args: mcpServerConfig.args || [],
|
|
||||||
env: {
|
|
||||||
...process.env,
|
|
||||||
...(mcpServerConfig.env || {}),
|
|
||||||
} as Record<string, string>,
|
|
||||||
cwd: mcpServerConfig.cwd,
|
|
||||||
stderr: 'pipe',
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
console.error(
|
|
||||||
`MCP server '${mcpServerName}' has invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio). Skipping.`,
|
|
||||||
);
|
|
||||||
// Update status to disconnected
|
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||||
return;
|
};
|
||||||
|
|
||||||
|
const tools = await discoverTools(
|
||||||
|
mcpServerName,
|
||||||
|
mcpServerConfig,
|
||||||
|
mcpClient,
|
||||||
|
);
|
||||||
|
for (const tool of tools) {
|
||||||
|
toolRegistry.registerTool(tool);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
mcpClient.close();
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Error connecting to MCP server '${mcpServerName}':`, error);
|
||||||
|
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Discovers and sanitizes tools from a connected MCP client.
|
||||||
|
* It retrieves function declarations from the client, filters out disabled tools,
|
||||||
|
* generates valid names for them, and wraps them in `DiscoveredMCPTool` instances.
|
||||||
|
*
|
||||||
|
* @param mcpServerName The name of the MCP server.
|
||||||
|
* @param mcpServerConfig The configuration for the MCP server.
|
||||||
|
* @param mcpClient The active MCP client instance.
|
||||||
|
* @returns A promise that resolves to an array of discovered and enabled tools.
|
||||||
|
* @throws An error if no enabled tools are found or if the server provides invalid function declarations.
|
||||||
|
*/
|
||||||
|
export async function discoverTools(
|
||||||
|
mcpServerName: string,
|
||||||
|
mcpServerConfig: MCPServerConfig,
|
||||||
|
mcpClient: Client,
|
||||||
|
): Promise<DiscoveredMCPTool[]> {
|
||||||
|
try {
|
||||||
|
const mcpCallableTool = mcpToTool(mcpClient);
|
||||||
|
const tool = await mcpCallableTool.tool();
|
||||||
|
|
||||||
|
if (!Array.isArray(tool.functionDeclarations)) {
|
||||||
|
throw new Error(`Server did not return valid function declarations.`);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (
|
const discoveredTools: DiscoveredMCPTool[] = [];
|
||||||
debugMode &&
|
for (const funcDecl of tool.functionDeclarations) {
|
||||||
transport instanceof StdioClientTransport &&
|
if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) {
|
||||||
transport.stderr
|
continue;
|
||||||
) {
|
|
||||||
transport.stderr.on('data', (data) => {
|
|
||||||
const stderrStr = data.toString().trim();
|
|
||||||
console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr);
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const toolNameForModel = generateValidName(funcDecl, mcpServerName);
|
||||||
|
|
||||||
|
sanitizeParameters(funcDecl.parameters);
|
||||||
|
|
||||||
|
discoveredTools.push(
|
||||||
|
new DiscoveredMCPTool(
|
||||||
|
mcpCallableTool,
|
||||||
|
mcpServerName,
|
||||||
|
toolNameForModel,
|
||||||
|
funcDecl.description ?? '',
|
||||||
|
funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
|
||||||
|
funcDecl.name!,
|
||||||
|
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||||
|
mcpServerConfig.trust,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (discoveredTools.length === 0) {
|
||||||
|
throw Error('No enabled tools found');
|
||||||
|
}
|
||||||
|
return discoveredTools;
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(`Error discovering tools: ${error}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates and connects an MCP client to a server based on the provided configuration.
|
||||||
|
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
|
||||||
|
* establishes a connection. It also applies a patch to handle request timeouts.
|
||||||
|
*
|
||||||
|
* @param mcpServerName The name of the MCP server, used for logging and identification.
|
||||||
|
* @param mcpServerConfig The configuration specifying how to connect to the server.
|
||||||
|
* @returns A promise that resolves to a connected MCP `Client` instance.
|
||||||
|
* @throws An error if the connection fails or the configuration is invalid.
|
||||||
|
*/
|
||||||
|
export async function connectToMcpServer(
|
||||||
|
mcpServerName: string,
|
||||||
|
mcpServerConfig: MCPServerConfig,
|
||||||
|
debugMode: boolean,
|
||||||
|
): Promise<Client> {
|
||||||
const mcpClient = new Client({
|
const mcpClient = new Client({
|
||||||
name: 'gemini-cli-mcp-client',
|
name: 'gemini-cli-mcp-client',
|
||||||
version: '0.0.1',
|
version: '0.0.1',
|
||||||
|
@ -258,12 +317,21 @@ async function connectAndDiscover(
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const transport = createTransport(
|
||||||
|
mcpServerName,
|
||||||
|
mcpServerConfig,
|
||||||
|
debugMode,
|
||||||
|
);
|
||||||
try {
|
try {
|
||||||
await mcpClient.connect(transport, {
|
await mcpClient.connect(transport, {
|
||||||
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||||
});
|
});
|
||||||
// Connection successful
|
return mcpClient;
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
|
} catch (error) {
|
||||||
|
await transport.close();
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// Create a safe config object that excludes sensitive information
|
// Create a safe config object that excludes sensitive information
|
||||||
const safeConfig = {
|
const safeConfig = {
|
||||||
|
@ -282,131 +350,110 @@ async function connectAndDiscover(
|
||||||
if (process.env.SANDBOX) {
|
if (process.env.SANDBOX) {
|
||||||
errorString += `\nMake sure it is available in the sandbox`;
|
errorString += `\nMake sure it is available in the sandbox`;
|
||||||
}
|
}
|
||||||
console.error(errorString);
|
throw new Error(errorString);
|
||||||
// Update status to disconnected
|
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
mcpClient.onerror = (error) => {
|
/** Visible for Testing */
|
||||||
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
|
export function createTransport(
|
||||||
// Update status to disconnected on error
|
mcpServerName: string,
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
mcpServerConfig: MCPServerConfig,
|
||||||
|
debugMode: boolean,
|
||||||
|
): Transport {
|
||||||
|
if (mcpServerConfig.httpUrl) {
|
||||||
|
const transportOptions: StreamableHTTPClientTransportOptions = {};
|
||||||
|
if (mcpServerConfig.headers) {
|
||||||
|
transportOptions.requestInit = {
|
||||||
|
headers: mcpServerConfig.headers,
|
||||||
};
|
};
|
||||||
|
}
|
||||||
try {
|
return new StreamableHTTPClientTransport(
|
||||||
const mcpCallableTool = mcpToTool(mcpClient);
|
new URL(mcpServerConfig.httpUrl),
|
||||||
const tool = await mcpCallableTool.tool();
|
transportOptions,
|
||||||
|
|
||||||
if (!tool || !Array.isArray(tool.functionDeclarations)) {
|
|
||||||
console.error(
|
|
||||||
`MCP server '${mcpServerName}' did not return valid tool function declarations. Skipping.`,
|
|
||||||
);
|
);
|
||||||
if (
|
|
||||||
transport instanceof StdioClientTransport ||
|
|
||||||
transport instanceof SSEClientTransport ||
|
|
||||||
transport instanceof StreamableHTTPClientTransport
|
|
||||||
) {
|
|
||||||
await transport.close();
|
|
||||||
}
|
|
||||||
// Update status to disconnected
|
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const funcDecl of tool.functionDeclarations) {
|
if (mcpServerConfig.url) {
|
||||||
|
const transportOptions: SSEClientTransportOptions = {};
|
||||||
|
if (mcpServerConfig.headers) {
|
||||||
|
transportOptions.requestInit = {
|
||||||
|
headers: mcpServerConfig.headers,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return new SSEClientTransport(
|
||||||
|
new URL(mcpServerConfig.url),
|
||||||
|
transportOptions,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mcpServerConfig.command) {
|
||||||
|
const transport = new StdioClientTransport({
|
||||||
|
command: mcpServerConfig.command,
|
||||||
|
args: mcpServerConfig.args || [],
|
||||||
|
env: {
|
||||||
|
...process.env,
|
||||||
|
...(mcpServerConfig.env || {}),
|
||||||
|
} as Record<string, string>,
|
||||||
|
cwd: mcpServerConfig.cwd,
|
||||||
|
stderr: 'pipe',
|
||||||
|
});
|
||||||
|
if (debugMode) {
|
||||||
|
transport.stderr!.on('data', (data) => {
|
||||||
|
const stderrStr = data.toString().trim();
|
||||||
|
console.debug(`[DEBUG] [MCP STDERR (${mcpServerName})]: `, stderrStr);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return transport;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error(
|
||||||
|
`Invalid configuration: missing httpUrl (for Streamable HTTP), url (for SSE), and command (for stdio).`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Visible for testing */
|
||||||
|
export function generateValidName(
|
||||||
|
funcDecl: FunctionDeclaration,
|
||||||
|
mcpServerName: string,
|
||||||
|
) {
|
||||||
|
// Replace invalid characters (based on 400 error message from Gemini API) with underscores
|
||||||
|
let validToolname = funcDecl.name!.replace(/[^a-zA-Z0-9_.-]/g, '_');
|
||||||
|
|
||||||
|
// Prepend MCP server name to avoid conflicts with other tools
|
||||||
|
validToolname = mcpServerName + '__' + validToolname;
|
||||||
|
|
||||||
|
// If longer than 63 characters, replace middle with '___'
|
||||||
|
// (Gemini API says max length 64, but actual limit seems to be 63)
|
||||||
|
if (validToolname.length > 63) {
|
||||||
|
validToolname =
|
||||||
|
validToolname.slice(0, 28) + '___' + validToolname.slice(-32);
|
||||||
|
}
|
||||||
|
return validToolname;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Visible for testing */
|
||||||
|
export function isEnabled(
|
||||||
|
funcDecl: FunctionDeclaration,
|
||||||
|
mcpServerName: string,
|
||||||
|
mcpServerConfig: MCPServerConfig,
|
||||||
|
): boolean {
|
||||||
if (!funcDecl.name) {
|
if (!funcDecl.name) {
|
||||||
console.warn(
|
console.warn(
|
||||||
`Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`,
|
`Discovered a function declaration without a name from MCP server '${mcpServerName}'. Skipping.`,
|
||||||
);
|
);
|
||||||
continue;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { includeTools, excludeTools } = mcpServerConfig;
|
const { includeTools, excludeTools } = mcpServerConfig;
|
||||||
const toolName = funcDecl.name;
|
|
||||||
|
|
||||||
let isEnabled = false;
|
// excludeTools takes precedence over includeTools
|
||||||
if (includeTools === undefined) {
|
if (excludeTools && excludeTools.includes(funcDecl.name)) {
|
||||||
isEnabled = true;
|
return false;
|
||||||
} else {
|
}
|
||||||
isEnabled = includeTools.some(
|
|
||||||
(tool) => tool === toolName || tool.startsWith(`${toolName}(`),
|
return (
|
||||||
|
!includeTools ||
|
||||||
|
includeTools.some(
|
||||||
|
(tool) => tool === funcDecl.name || tool.startsWith(`${funcDecl.name}(`),
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
|
||||||
|
|
||||||
if (excludeTools?.includes(toolName)) {
|
|
||||||
isEnabled = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isEnabled) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let toolNameForModel = funcDecl.name;
|
|
||||||
|
|
||||||
// Replace invalid characters (based on 400 error message from Gemini API) with underscores
|
|
||||||
toolNameForModel = toolNameForModel.replace(/[^a-zA-Z0-9_.-]/g, '_');
|
|
||||||
|
|
||||||
const existingTool = toolRegistry.getTool(toolNameForModel);
|
|
||||||
if (existingTool) {
|
|
||||||
toolNameForModel = mcpServerName + '__' + toolNameForModel;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If longer than 63 characters, replace middle with '___'
|
|
||||||
// (Gemini API says max length 64, but actual limit seems to be 63)
|
|
||||||
if (toolNameForModel.length > 63) {
|
|
||||||
toolNameForModel =
|
|
||||||
toolNameForModel.slice(0, 28) + '___' + toolNameForModel.slice(-32);
|
|
||||||
}
|
|
||||||
|
|
||||||
sanitizeParameters(funcDecl.parameters);
|
|
||||||
|
|
||||||
toolRegistry.registerTool(
|
|
||||||
new DiscoveredMCPTool(
|
|
||||||
mcpCallableTool,
|
|
||||||
mcpServerName,
|
|
||||||
toolNameForModel,
|
|
||||||
funcDecl.description ?? '',
|
|
||||||
funcDecl.parameters ?? { type: Type.OBJECT, properties: {} },
|
|
||||||
funcDecl.name,
|
|
||||||
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
|
||||||
mcpServerConfig.trust,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error(
|
|
||||||
`Failed to list or register tools for MCP server '${mcpServerName}': ${error}`,
|
|
||||||
);
|
|
||||||
// Ensure transport is cleaned up on error too
|
|
||||||
if (
|
|
||||||
transport instanceof StdioClientTransport ||
|
|
||||||
transport instanceof SSEClientTransport ||
|
|
||||||
transport instanceof StreamableHTTPClientTransport
|
|
||||||
) {
|
|
||||||
await transport.close();
|
|
||||||
}
|
|
||||||
// Update status to disconnected
|
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no tools were registered from this MCP server, the following 'if' block
|
|
||||||
// will close the connection. This is done to conserve resources and prevent
|
|
||||||
// an orphaned connection to a server that isn't providing any usable
|
|
||||||
// functionality. Connections to servers that did provide tools are kept
|
|
||||||
// open, as those tools will require the connection to function.
|
|
||||||
if (toolRegistry.getToolsByServer(mcpServerName).length === 0) {
|
|
||||||
console.log(
|
|
||||||
`No tools registered from MCP server '${mcpServerName}'. Closing connection.`,
|
|
||||||
);
|
|
||||||
if (
|
|
||||||
transport instanceof StdioClientTransport ||
|
|
||||||
transport instanceof SSEClientTransport ||
|
|
||||||
transport instanceof StreamableHTTPClientTransport
|
|
||||||
) {
|
|
||||||
await transport.close();
|
|
||||||
// Update status to disconnected
|
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -326,6 +326,83 @@ describe('ToolRegistry', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('sanitizeParameters', () => {
|
describe('sanitizeParameters', () => {
|
||||||
|
it('should remove default when anyOf is present', () => {
|
||||||
|
const schema: Schema = {
|
||||||
|
anyOf: [{ type: Type.STRING }, { type: Type.NUMBER }],
|
||||||
|
default: 'hello',
|
||||||
|
};
|
||||||
|
sanitizeParameters(schema);
|
||||||
|
expect(schema.default).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should recursively sanitize items in anyOf', () => {
|
||||||
|
const schema: Schema = {
|
||||||
|
anyOf: [
|
||||||
|
{
|
||||||
|
anyOf: [{ type: Type.STRING }],
|
||||||
|
default: 'world',
|
||||||
|
},
|
||||||
|
{ type: Type.NUMBER },
|
||||||
|
],
|
||||||
|
};
|
||||||
|
sanitizeParameters(schema);
|
||||||
|
expect(schema.anyOf![0].default).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should recursively sanitize items in items', () => {
|
||||||
|
const schema: Schema = {
|
||||||
|
items: {
|
||||||
|
anyOf: [{ type: Type.STRING }],
|
||||||
|
default: 'world',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
sanitizeParameters(schema);
|
||||||
|
expect(schema.items!.default).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should recursively sanitize items in properties', () => {
|
||||||
|
const schema: Schema = {
|
||||||
|
properties: {
|
||||||
|
prop1: {
|
||||||
|
anyOf: [{ type: Type.STRING }],
|
||||||
|
default: 'world',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
sanitizeParameters(schema);
|
||||||
|
expect(schema.properties!.prop1.default).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle complex nested schemas', () => {
|
||||||
|
const schema: Schema = {
|
||||||
|
properties: {
|
||||||
|
prop1: {
|
||||||
|
items: {
|
||||||
|
anyOf: [{ type: Type.STRING }],
|
||||||
|
default: 'world',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
prop2: {
|
||||||
|
anyOf: [
|
||||||
|
{
|
||||||
|
properties: {
|
||||||
|
nestedProp: {
|
||||||
|
anyOf: [{ type: Type.NUMBER }],
|
||||||
|
default: 123,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
sanitizeParameters(schema);
|
||||||
|
expect(schema.properties!.prop1.items!.default).toBeUndefined();
|
||||||
|
const nestedProp =
|
||||||
|
schema.properties!.prop2.anyOf![0].properties!.nestedProp;
|
||||||
|
expect(nestedProp?.default).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
it('should remove unsupported format from a simple string property', () => {
|
it('should remove unsupported format from a simple string property', () => {
|
||||||
const schema: Schema = {
|
const schema: Schema = {
|
||||||
type: Type.OBJECT,
|
type: Type.OBJECT,
|
||||||
|
@ -356,25 +433,6 @@ describe('sanitizeParameters', () => {
|
||||||
expect(schema).toEqual(originalSchema);
|
expect(schema).toEqual(originalSchema);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle nested objects recursively', () => {
|
|
||||||
const schema: Schema = {
|
|
||||||
type: Type.OBJECT,
|
|
||||||
properties: {
|
|
||||||
user: {
|
|
||||||
type: Type.OBJECT,
|
|
||||||
properties: {
|
|
||||||
email: { type: Type.STRING, format: 'email' },
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
sanitizeParameters(schema);
|
|
||||||
expect(schema.properties?.['user']?.properties?.['email']).toHaveProperty(
|
|
||||||
'format',
|
|
||||||
undefined,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle arrays of objects', () => {
|
it('should handle arrays of objects', () => {
|
||||||
const schema: Schema = {
|
const schema: Schema = {
|
||||||
type: Type.OBJECT,
|
type: Type.OBJECT,
|
||||||
|
@ -414,19 +472,6 @@ describe('sanitizeParameters', () => {
|
||||||
expect(() => sanitizeParameters(undefined)).not.toThrow();
|
expect(() => sanitizeParameters(undefined)).not.toThrow();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should handle cyclic schemas without crashing', () => {
|
|
||||||
const schema: any = {
|
|
||||||
type: Type.OBJECT,
|
|
||||||
properties: {
|
|
||||||
name: { type: Type.STRING, format: 'hostname' },
|
|
||||||
},
|
|
||||||
};
|
|
||||||
schema.properties.self = schema;
|
|
||||||
|
|
||||||
expect(() => sanitizeParameters(schema)).not.toThrow();
|
|
||||||
expect(schema.properties.name).toHaveProperty('format', undefined);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should handle complex nested schemas with cycles', () => {
|
it('should handle complex nested schemas with cycles', () => {
|
||||||
const userNode: any = {
|
const userNode: any = {
|
||||||
type: Type.OBJECT,
|
type: Type.OBJECT,
|
||||||
|
|
Loading…
Reference in New Issue