495 lines
18 KiB
TypeScript
495 lines
18 KiB
TypeScript
/**
|
|
* @license
|
|
* Copyright 2025 Google LLC
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
|
import {
|
|
describe,
|
|
it,
|
|
expect,
|
|
vi,
|
|
beforeEach,
|
|
afterEach,
|
|
Mocked,
|
|
} from 'vitest';
|
|
import { discoverMcpTools } from './mcp-client.js';
|
|
import { Config, MCPServerConfig } from '../config/config.js';
|
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
|
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
|
import { parse, ParseEntry } from 'shell-quote';
|
|
|
|
// Mock dependencies.
// shell-quote is fully auto-mocked so each test controls `parse` directly.
vi.mock('shell-quote');

// The SDK Client is replaced by a constructor mock. `connect` and `listTools`
// live on the mock's prototype so tests can configure them through
// `Client.prototype`; every constructed instance points at those shared spies
// and also carries its own `onerror` mock, which the code under test may
// overwrite.
vi.mock('@modelcontextprotocol/sdk/client/index.js', () => {
  const ClientMock = vi.fn();
  Object.assign(ClientMock.prototype, {
    connect: vi.fn(),
    listTools: vi.fn(),
  });
  ClientMock.mockImplementation(() => {
    // Looked up at construction time, mirroring how tests reach these spies.
    const { connect, listTools } = ClientMock.prototype;
    return { connect, listTools, onerror: vi.fn() };
  });
  return { Client: ClientMock };
});
|
|
|
|
// One `stderr.on` spy shared by every StdioClientTransport instance, so tests
// can clear and inspect it in a single place.
const mockGlobalStdioStderrOn = vi.fn();

vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => {
  // Constructor body for the mocked StdioClientTransport: remembers the
  // options it was built with, exposes a stderr stub wired to the shared spy,
  // and provides a `close` that resolves to undefined.
  function buildStdioTransport(this: any, options: any) {
    Object.assign(this, {
      options,
      stderr: { on: mockGlobalStdioStderrOn },
      close: vi.fn().mockResolvedValue(undefined),
    });
    return this;
  }

  return {
    StdioClientTransport: vi.fn().mockImplementation(buildStdioTransport),
  };
});
|
|
|
|
vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => ({
  // Constructor mock whose instances only need a `close` resolving to undefined.
  SSEClientTransport: vi.fn().mockImplementation(function (this: any) {
    Object.assign(this, { close: vi.fn().mockResolvedValue(undefined) });
    return this;
  }),
}));
|
|
|
|
// A single ToolRegistry double, shared by the module mock below and by the
// tests. Its spies are reset in beforeEach rather than recreated, so every
// test observes the same object the code under test receives.
const mockToolRegistryInstance = {
  // Spies whose results the tests inspect or override.
  registerTool: vi.fn(),
  getToolsByServer: vi.fn().mockReturnValue([]),
  getTool: vi.fn(),
  // Remaining registry surface, stubbed with neutral results.
  getAllTools: vi.fn().mockReturnValue([]),
  getFunctionDeclarations: vi.fn().mockReturnValue([]),
  discoverTools: vi.fn().mockResolvedValue(undefined),
};

vi.mock('./tool-registry.js', () => {
  // Every `new ToolRegistry(...)` hands back the shared double.
  const ToolRegistry = vi.fn(() => mockToolRegistryInstance);
  return { ToolRegistry };
});
|
|
|
|
describe('discoverMcpTools', () => {
|
|
let mockConfig: Mocked<Config>;
// Test-scoped alias for the shared registry double that the
// './tool-registry.js' module mock hands out.
let mockToolRegistry: typeof mockToolRegistryInstance;

beforeEach(() => {
  mockToolRegistry = mockToolRegistryInstance;

  // Reset the shared registry spies before each test. registerTool uses
  // mockReset (not mockClear) so that a per-test mockImplementation — the
  // prefixing test installs one — is discarded instead of leaking into later
  // tests. The other spies get their default results re-applied explicitly.
  mockToolRegistry.registerTool.mockReset();
  mockToolRegistry.getToolsByServer.mockClear().mockReturnValue([]); // Reset to default
  mockToolRegistry.getTool.mockClear().mockReturnValue(undefined); // Default to no existing tool
  mockToolRegistry.getAllTools.mockClear().mockReturnValue([]);
  mockToolRegistry.getFunctionDeclarations.mockClear().mockReturnValue([]);
  mockToolRegistry.discoverTools.mockClear().mockResolvedValue(undefined);

  // Default config: no MCP servers and no MCP server command.
  mockConfig = {
    getMcpServers: vi.fn().mockReturnValue({}),
    getMcpServerCommand: vi.fn().mockReturnValue(undefined),
    // Hand the code under test the same shared registry double.
    getToolRegistry: vi.fn(() => mockToolRegistry),
  } as any;

  vi.mocked(parse).mockClear();

  vi.mocked(Client).mockClear();
  vi.mocked(Client.prototype.connect)
    .mockClear()
    .mockResolvedValue(undefined);
  vi.mocked(Client.prototype.listTools)
    .mockClear()
    .mockResolvedValue({ tools: [] });

  vi.mocked(StdioClientTransport).mockClear();
  // Re-install the constructor body so every transport has options, a stderr
  // stub wired to the shared spy, and a resolving close method.
  vi.mocked(StdioClientTransport).mockImplementation(function (
    this: any,
    options: any,
  ) {
    this.options = options;
    this.stderr = { on: mockGlobalStdioStderrOn };
    this.close = vi.fn().mockResolvedValue(undefined);
    return this;
  });
  mockGlobalStdioStderrOn.mockClear(); // Clear the global mock in beforeEach

  vi.mocked(SSEClientTransport).mockClear();
  // Re-install the constructor body so every SSE transport has a resolving close.
  vi.mocked(SSEClientTransport).mockImplementation(function (this: any) {
    this.close = vi.fn().mockResolvedValue(undefined);
    return this;
  });
});
|
|
|
|
// Tear down mocks after every test (this undoes the console.error spies
// installed with vi.spyOn).
afterEach(function restoreMocks() {
  vi.restoreAllMocks();
});
|
|
|
|
it('should do nothing if no MCP servers or command are configured', async () => {
  // beforeEach leaves both getMcpServers ({}) and getMcpServerCommand
  // (undefined) empty.
  await discoverMcpTools(mockConfig);

  // Each configuration source is consulted exactly once...
  for (const getter of [
    mockConfig.getMcpServers,
    mockConfig.getMcpServerCommand,
  ]) {
    expect(getter).toHaveBeenCalledTimes(1);
  }
  // ...but no client is created and nothing is registered.
  expect(Client).not.toHaveBeenCalled();
  expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
});
|
|
|
|
it('should discover tools via mcpServerCommand', async () => {
  const commandString = 'my-mcp-server --start';
  const parsedCommand = ['my-mcp-server', '--start'] as ParseEntry[];
  mockConfig.getMcpServerCommand.mockReturnValue(commandString);
  vi.mocked(parse).mockReturnValue(parsedCommand);

  const mockTool = {
    name: 'tool1',
    description: 'desc1',
    inputSchema: { type: 'object' as const, properties: {} },
  };
  vi.mocked(Client.prototype.listTools).mockResolvedValue({
    tools: [mockTool],
  });

  // listTools succeeds with a single tool, so exactly one registration is
  // expected below. getToolsByServer keeps the default `[]` from beforeEach.

  await discoverMcpTools(mockConfig);

  // The command string is tokenized with process.env; the first token is the
  // executable and the remaining tokens are its arguments.
  expect(parse).toHaveBeenCalledWith(commandString, process.env);
  expect(StdioClientTransport).toHaveBeenCalledWith({
    command: parsedCommand[0],
    args: parsedCommand.slice(1),
    env: expect.any(Object),
    cwd: undefined,
    stderr: 'pipe',
  });
  expect(Client.prototype.connect).toHaveBeenCalledTimes(1);
  expect(Client.prototype.listTools).toHaveBeenCalledTimes(1);

  expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
  expect(mockToolRegistry.registerTool).toHaveBeenCalledWith(
    expect.any(DiscoveredMCPTool),
  );
  const registeredTool = mockToolRegistry.registerTool.mock
    .calls[0][0] as DiscoveredMCPTool;
  // With a single command-based server, the tool keeps its original name.
  expect(registeredTool.name).toBe('tool1');
  expect(registeredTool.serverToolName).toBe('tool1');
});
|
|
|
|
it('should discover tools via mcpServers config (stdio)', async () => {
  const serverConfig: MCPServerConfig = {
    command: './mcp-stdio',
    args: ['arg1'],
  };
  mockConfig.getMcpServers.mockReturnValue({ 'stdio-server': serverConfig });

  // The server advertises a single tool...
  vi.mocked(Client.prototype.listTools).mockResolvedValue({
    tools: [
      {
        name: 'tool-stdio',
        description: 'desc-stdio',
        inputSchema: { type: 'object' as const, properties: {} },
      },
    ],
  });
  // ...and the registry reports one tool for that server once discovery ends.
  mockToolRegistry.getToolsByServer.mockReturnValueOnce([
    expect.any(DiscoveredMCPTool),
  ]);

  await discoverMcpTools(mockConfig);

  const expectedTransportOptions = {
    command: serverConfig.command,
    args: serverConfig.args,
    env: expect.any(Object),
    cwd: undefined,
    stderr: 'pipe',
  };
  expect(StdioClientTransport).toHaveBeenCalledWith(expectedTransportOptions);

  const registerSpy = mockToolRegistry.registerTool;
  expect(registerSpy).toHaveBeenCalledWith(expect.any(DiscoveredMCPTool));
  const firstCallArgs = registerSpy.mock.calls[0];
  expect((firstCallArgs[0] as DiscoveredMCPTool).name).toBe('tool-stdio');
});
|
|
|
|
it('should discover tools via mcpServers config (sse)', async () => {
  const sseUrl = 'http://localhost:1234/sse';
  const serverConfig: MCPServerConfig = { url: sseUrl };
  mockConfig.getMcpServers.mockReturnValue({ 'sse-server': serverConfig });

  // The server advertises a single tool...
  vi.mocked(Client.prototype.listTools).mockResolvedValue({
    tools: [
      {
        name: 'tool-sse',
        description: 'desc-sse',
        inputSchema: { type: 'object' as const, properties: {} },
      },
    ],
  });
  // ...and the registry reports one tool for that server once discovery ends.
  mockToolRegistry.getToolsByServer.mockReturnValueOnce([
    expect.any(DiscoveredMCPTool),
  ]);

  await discoverMcpTools(mockConfig);

  // A URL-only config selects the SSE transport, built from that URL.
  expect(SSEClientTransport).toHaveBeenCalledWith(new URL(sseUrl));

  const registerSpy = mockToolRegistry.registerTool;
  expect(registerSpy).toHaveBeenCalledWith(expect.any(DiscoveredMCPTool));
  const firstCallArgs = registerSpy.mock.calls[0];
  expect((firstCallArgs[0] as DiscoveredMCPTool).name).toBe('tool-sse');
});
|
|
|
|
it('should prefix tool names if multiple MCP servers are configured', async () => {
  const serverConfig1: MCPServerConfig = { command: './mcp1' };
  const serverConfig2: MCPServerConfig = { url: 'http://mcp2/sse' };
  mockConfig.getMcpServers.mockReturnValue({
    server1: serverConfig1,
    server2: serverConfig2,
  });

  const mockTool1 = {
    name: 'toolA', // Same original name
    description: 'd1',
    inputSchema: { type: 'object' as const, properties: {} },
  };
  const mockTool2 = {
    name: 'toolA', // Same original name
    description: 'd2',
    inputSchema: { type: 'object' as const, properties: {} },
  };
  const mockToolB = {
    name: 'toolB',
    description: 'dB',
    inputSchema: { type: 'object' as const, properties: {} },
  };

  vi.mocked(Client.prototype.listTools)
    .mockResolvedValueOnce({ tools: [mockTool1, mockToolB] }) // Tools for server1
    .mockResolvedValueOnce({ tools: [mockTool2] }); // Tool for server2 (toolA)

  // Simulate a registry: getTool finds whatever registerTool stored under the
  // tool's final (possibly prefixed) name, so a second 'toolA' collides with
  // the first one registered. (beforeEach resets registerTool's
  // implementation, so this stub does not outlive the test.)
  const registeredByName = new Map<string, DiscoveredMCPTool>();
  mockToolRegistry.getTool.mockImplementation((toolName: string) =>
    registeredByName.get(toolName),
  );
  mockToolRegistry.registerTool.mockImplementation(
    (toolToRegister: DiscoveredMCPTool) => {
      registeredByName.set(toolToRegister.name, toolToRegister);
    },
  );

  // Per-server answers for the final getToolsByServer check, which asks
  // whether any tools were registered from that server.
  mockToolRegistry.getToolsByServer.mockImplementation(
    (serverName: string) => {
      if (serverName === 'server1')
        return [
          expect.objectContaining({ name: 'toolA' }),
          expect.objectContaining({ name: 'toolB' }),
        ];
      if (serverName === 'server2')
        return [expect.objectContaining({ name: 'server2__toolA' })];
      return [];
    },
  );

  await discoverMcpTools(mockConfig);

  expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(3);
  const registeredTools = mockToolRegistry.registerTool.mock.calls.map(
    (call) => call[0],
  ) as DiscoveredMCPTool[];
  const findRegistered = (serverName: string, serverToolName: string) =>
    registeredTools.find(
      (t) => t.serverName === serverName && t.serverToolName === serverToolName,
    );

  // The order of server processing by Promise.all is not guaranteed, so
  // either 'toolA' may be the unprefixed one.
  const toolAFromServer1 = findRegistered('server1', 'toolA');
  const toolAFromServer2 = findRegistered('server2', 'toolA');
  const toolBFromServer1 = findRegistered('server1', 'toolB');

  expect(toolAFromServer1).toBeDefined();
  expect(toolAFromServer2).toBeDefined();
  expect(toolBFromServer1).toBeDefined();

  expect(toolBFromServer1?.name).toBe('toolB'); // toolB is unique

  // Exactly one toolA is prefixed with its server name; the other is not.
  if (toolAFromServer1?.name === 'toolA') {
    expect(toolAFromServer2?.name).toBe('server2__toolA');
  } else {
    expect(toolAFromServer1?.name).toBe('server1__toolA');
    expect(toolAFromServer2?.name).toBe('toolA');
  }
});
|
|
|
|
it('should clean schema properties ($schema, additionalProperties)', async () => {
  const serverConfig: MCPServerConfig = { command: './mcp-clean' };
  mockConfig.getMcpServers.mockReturnValue({ 'clean-server': serverConfig });

  // Input schema containing the keys discovery is expected to strip.
  const rawSchema = {
    type: 'object' as const,
    $schema: 'http://json-schema.org/draft-07/schema#',
    additionalProperties: true,
    properties: {
      prop1: { type: 'string', $schema: 'remove-this' },
      prop2: {
        type: 'object' as const,
        additionalProperties: false,
        properties: { nested: { type: 'number' } },
      },
    },
  };
  vi.mocked(Client.prototype.listTools).mockResolvedValue({
    tools: [
      {
        name: 'cleanTool',
        description: 'd',
        // A deep copy, so any in-place cleaning cannot touch `rawSchema`.
        inputSchema: JSON.parse(JSON.stringify(rawSchema)),
      },
    ],
  });
  mockToolRegistry.getToolsByServer.mockReturnValueOnce([
    expect.any(DiscoveredMCPTool),
  ]);

  await discoverMcpTools(mockConfig);

  expect(mockToolRegistry.registerTool).toHaveBeenCalledTimes(1);
  const tool = mockToolRegistry.registerTool.mock
    .calls[0][0] as DiscoveredMCPTool;
  const params = tool.schema.parameters as any;
  const strippedKeys = ['$schema', 'additionalProperties'];

  // Top level: both keys removed.
  for (const key of strippedKeys) {
    expect(params).not.toHaveProperty(key);
  }
  // First-level properties.
  expect(params.properties.prop1).not.toHaveProperty('$schema');
  expect(params.properties.prop2).not.toHaveProperty('additionalProperties');
  // Second-level property.
  for (const key of strippedKeys) {
    expect(params.properties.prop2.properties.nested).not.toHaveProperty(key);
  }
});
|
|
|
|
it('should handle error if mcpServerCommand parsing fails', async () => {
  const consoleErrorSpy = vi
    .spyOn(console, 'error')
    .mockImplementation(() => {});
  mockConfig.getMcpServerCommand.mockReturnValue(
    'my-mcp-server "unterminated quote',
  );
  // Make tokenization blow up.
  vi.mocked(parse).mockImplementation(() => {
    throw new Error('Parsing failed');
  });

  // The parse error propagates to the caller rather than being logged.
  await expect(discoverMcpTools(mockConfig)).rejects.toThrow(
    'Parsing failed',
  );
  expect(consoleErrorSpy).not.toHaveBeenCalled();
  expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
});
|
|
|
|
it('should log error and skip server if config is invalid (missing url and command)', async () => {
  const consoleErrorSpy = vi
    .spyOn(console, 'error')
    .mockImplementation(() => {});
  // Neither `command` nor `url` is provided.
  mockConfig.getMcpServers.mockReturnValue({ 'bad-server': {} as any });

  await discoverMcpTools(mockConfig);

  const invalidConfigMessage = expect.stringContaining(
    "MCP server 'bad-server' has invalid configuration",
  );
  expect(consoleErrorSpy).toHaveBeenCalledWith(invalidConfigMessage);
  // The entry is rejected before any Client is constructed.
  expect(Client).not.toHaveBeenCalled();
});
|
|
|
|
it('should log error and skip server if mcpClient.connect fails', async () => {
  const consoleErrorSpy = vi
    .spyOn(console, 'error')
    .mockImplementation(() => {});
  const serverConfig: MCPServerConfig = { command: './mcp-fail-connect' };
  mockConfig.getMcpServers.mockReturnValue({
    'fail-connect-server': serverConfig,
  });
  // Every connection attempt is refused.
  vi.mocked(Client.prototype.connect).mockRejectedValue(
    new Error('Connection refused'),
  );

  await discoverMcpTools(mockConfig);

  // The failure is reported...
  expect(consoleErrorSpy).toHaveBeenCalledWith(
    expect.stringContaining(
      "failed to start or connect to MCP server 'fail-connect-server'",
    ),
  );
  // ...and nothing is listed or registered for that server.
  expect(Client.prototype.listTools).not.toHaveBeenCalled();
  expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
});
|
|
|
|
it('should log error and skip server if mcpClient.listTools fails', async () => {
  const consoleErrorSpy = vi
    .spyOn(console, 'error')
    .mockImplementation(() => {});
  const serverConfig: MCPServerConfig = { command: './mcp-fail-list' };
  mockConfig.getMcpServers.mockReturnValue({
    'fail-list-server': serverConfig,
  });
  // Connecting succeeds (beforeEach default), but listing tools rejects.
  vi.mocked(Client.prototype.listTools).mockRejectedValue(
    new Error('ListTools error'),
  );

  await discoverMcpTools(mockConfig);

  const listFailureMessage = expect.stringContaining(
    "Failed to list or register tools for MCP server 'fail-list-server'",
  );
  expect(consoleErrorSpy).toHaveBeenCalledWith(listFailureMessage);
  expect(mockToolRegistry.registerTool).not.toHaveBeenCalled();
});
|
|
|
|
it('should assign mcpClient.onerror handler', async () => {
  const serverConfig: MCPServerConfig = { command: './mcp-onerror' };
  mockConfig.getMcpServers.mockReturnValue({
    'onerror-server': serverConfig,
  });
  // PRE-MOCK getToolsByServer for the expected server name
  mockToolRegistry.getToolsByServer.mockReturnValueOnce([
    expect.any(DiscoveredMCPTool),
  ]);

  await discoverMcpTools(mockConfig);

  const clientResults = vi.mocked(Client).mock.results;
  expect(clientResults.length).toBeGreaterThan(0);
  const lastClientInstance = clientResults[clientResults.length - 1]?.value;

  expect(lastClientInstance?.onerror).toEqual(expect.any(Function));
  // The Client module mock pre-populates every instance with
  // `onerror: vi.fn()`, which alone satisfies the check above. Requiring a
  // non-mock function proves discoverMcpTools installed its own handler.
  expect(vi.isMockFunction(lastClientInstance?.onerror)).toBe(false);
});
|
|
});
|