feat: restart MCP servers on /mcp refresh (#5479)
Co-authored-by: Brian Ray <62354532+emeryray2002@users.noreply.github.com> Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
parent
4828e4daf1
commit
b24c5887c4
|
@ -972,6 +972,7 @@ describe('mcpCommand', () => {
|
||||||
it('should refresh the list of tools and display the status', async () => {
|
it('should refresh the list of tools and display the status', async () => {
|
||||||
const mockToolRegistry = {
|
const mockToolRegistry = {
|
||||||
discoverMcpTools: vi.fn(),
|
discoverMcpTools: vi.fn(),
|
||||||
|
restartMcpServers: vi.fn(),
|
||||||
getAllTools: vi.fn().mockReturnValue([]),
|
getAllTools: vi.fn().mockReturnValue([]),
|
||||||
};
|
};
|
||||||
const mockGeminiClient = {
|
const mockGeminiClient = {
|
||||||
|
@ -1004,11 +1005,11 @@ describe('mcpCommand', () => {
|
||||||
expect(context.ui.addItem).toHaveBeenCalledWith(
|
expect(context.ui.addItem).toHaveBeenCalledWith(
|
||||||
{
|
{
|
||||||
type: 'info',
|
type: 'info',
|
||||||
text: 'Refreshing MCP servers and tools...',
|
text: 'Restarting MCP servers...',
|
||||||
},
|
},
|
||||||
expect.any(Number),
|
expect.any(Number),
|
||||||
);
|
);
|
||||||
expect(mockToolRegistry.discoverMcpTools).toHaveBeenCalled();
|
expect(mockToolRegistry.restartMcpServers).toHaveBeenCalled();
|
||||||
expect(mockGeminiClient.setTools).toHaveBeenCalled();
|
expect(mockGeminiClient.setTools).toHaveBeenCalled();
|
||||||
expect(context.ui.reloadCommands).toHaveBeenCalledTimes(1);
|
expect(context.ui.reloadCommands).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
|
|
@ -471,7 +471,7 @@ const listCommand: SlashCommand = {
|
||||||
|
|
||||||
const refreshCommand: SlashCommand = {
|
const refreshCommand: SlashCommand = {
|
||||||
name: 'refresh',
|
name: 'refresh',
|
||||||
description: 'Refresh the list of MCP servers and tools',
|
description: 'Restarts MCP servers.',
|
||||||
kind: CommandKind.BUILT_IN,
|
kind: CommandKind.BUILT_IN,
|
||||||
action: async (
|
action: async (
|
||||||
context: CommandContext,
|
context: CommandContext,
|
||||||
|
@ -497,12 +497,12 @@ const refreshCommand: SlashCommand = {
|
||||||
context.ui.addItem(
|
context.ui.addItem(
|
||||||
{
|
{
|
||||||
type: 'info',
|
type: 'info',
|
||||||
text: 'Refreshing MCP servers and tools...',
|
text: 'Restarting MCP servers...',
|
||||||
},
|
},
|
||||||
Date.now(),
|
Date.now(),
|
||||||
);
|
);
|
||||||
|
|
||||||
await toolRegistry.discoverMcpTools();
|
await toolRegistry.restartMcpServers();
|
||||||
|
|
||||||
// Update the client with the new tools
|
// Update the client with the new tools
|
||||||
const geminiClient = config.getGeminiClient();
|
const geminiClient = config.getGeminiClient();
|
||||||
|
|
|
@ -63,6 +63,12 @@ describe('handleAtCommand', () => {
|
||||||
isPathWithinWorkspace: () => true,
|
isPathWithinWorkspace: () => true,
|
||||||
getDirectories: () => [testRootDir],
|
getDirectories: () => [testRootDir],
|
||||||
}),
|
}),
|
||||||
|
getMcpServers: () => ({}),
|
||||||
|
getMcpServerCommand: () => undefined,
|
||||||
|
getPromptRegistry: () => ({
|
||||||
|
getPromptsByServer: () => [],
|
||||||
|
}),
|
||||||
|
getDebugMode: () => false,
|
||||||
} as unknown as Config;
|
} as unknown as Config;
|
||||||
|
|
||||||
const registry = new ToolRegistry(mockConfig);
|
const registry = new ToolRegistry(mockConfig);
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||||
|
import { McpClientManager } from './mcp-client-manager.js';
|
||||||
|
import { McpClient } from './mcp-client.js';
|
||||||
|
import { ToolRegistry } from './tool-registry.js';
|
||||||
|
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||||
|
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||||
|
|
||||||
|
vi.mock('./mcp-client.js', async () => {
|
||||||
|
const originalModule = await vi.importActual('./mcp-client.js');
|
||||||
|
return {
|
||||||
|
...originalModule,
|
||||||
|
McpClient: vi.fn(),
|
||||||
|
populateMcpServerCommand: vi.fn(() => ({
|
||||||
|
'test-server': {},
|
||||||
|
})),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('McpClientManager', () => {
|
||||||
|
afterEach(() => {
|
||||||
|
vi.restoreAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should discover tools from all servers', async () => {
|
||||||
|
const mockedMcpClient = {
|
||||||
|
connect: vi.fn(),
|
||||||
|
discover: vi.fn(),
|
||||||
|
disconnect: vi.fn(),
|
||||||
|
getStatus: vi.fn(),
|
||||||
|
};
|
||||||
|
vi.mocked(McpClient).mockReturnValue(
|
||||||
|
mockedMcpClient as unknown as McpClient,
|
||||||
|
);
|
||||||
|
const manager = new McpClientManager(
|
||||||
|
{
|
||||||
|
'test-server': {},
|
||||||
|
},
|
||||||
|
'',
|
||||||
|
{} as ToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
false,
|
||||||
|
{} as WorkspaceContext,
|
||||||
|
);
|
||||||
|
await manager.discoverAllMcpTools();
|
||||||
|
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||||
|
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||||
|
});
|
||||||
|
});
|
|
@ -0,0 +1,115 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { MCPServerConfig } from '../config/config.js';
|
||||||
|
import { ToolRegistry } from './tool-registry.js';
|
||||||
|
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||||
|
import {
|
||||||
|
McpClient,
|
||||||
|
MCPDiscoveryState,
|
||||||
|
populateMcpServerCommand,
|
||||||
|
} from './mcp-client.js';
|
||||||
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
|
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Manages the lifecycle of multiple MCP clients, including local child processes.
|
||||||
|
* This class is responsible for starting, stopping, and discovering tools from
|
||||||
|
* a collection of MCP servers defined in the configuration.
|
||||||
|
*/
|
||||||
|
export class McpClientManager {
|
||||||
|
private clients: Map<string, McpClient> = new Map();
|
||||||
|
private readonly mcpServers: Record<string, MCPServerConfig>;
|
||||||
|
private readonly mcpServerCommand: string | undefined;
|
||||||
|
private readonly toolRegistry: ToolRegistry;
|
||||||
|
private readonly promptRegistry: PromptRegistry;
|
||||||
|
private readonly debugMode: boolean;
|
||||||
|
private readonly workspaceContext: WorkspaceContext;
|
||||||
|
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
mcpServers: Record<string, MCPServerConfig>,
|
||||||
|
mcpServerCommand: string | undefined,
|
||||||
|
toolRegistry: ToolRegistry,
|
||||||
|
promptRegistry: PromptRegistry,
|
||||||
|
debugMode: boolean,
|
||||||
|
workspaceContext: WorkspaceContext,
|
||||||
|
) {
|
||||||
|
this.mcpServers = mcpServers;
|
||||||
|
this.mcpServerCommand = mcpServerCommand;
|
||||||
|
this.toolRegistry = toolRegistry;
|
||||||
|
this.promptRegistry = promptRegistry;
|
||||||
|
this.debugMode = debugMode;
|
||||||
|
this.workspaceContext = workspaceContext;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initiates the tool discovery process for all configured MCP servers.
|
||||||
|
* It connects to each server, discovers its available tools, and registers
|
||||||
|
* them with the `ToolRegistry`.
|
||||||
|
*/
|
||||||
|
async discoverAllMcpTools(): Promise<void> {
|
||||||
|
await this.stop();
|
||||||
|
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||||
|
const servers = populateMcpServerCommand(
|
||||||
|
this.mcpServers,
|
||||||
|
this.mcpServerCommand,
|
||||||
|
);
|
||||||
|
|
||||||
|
const discoveryPromises = Object.entries(servers).map(
|
||||||
|
async ([name, config]) => {
|
||||||
|
const client = new McpClient(
|
||||||
|
name,
|
||||||
|
config,
|
||||||
|
this.toolRegistry,
|
||||||
|
this.promptRegistry,
|
||||||
|
this.workspaceContext,
|
||||||
|
this.debugMode,
|
||||||
|
);
|
||||||
|
this.clients.set(name, client);
|
||||||
|
try {
|
||||||
|
await client.connect();
|
||||||
|
await client.discover();
|
||||||
|
} catch (error) {
|
||||||
|
// Log the error but don't let a single failed server stop the others
|
||||||
|
console.error(
|
||||||
|
`Error during discovery for server '${name}': ${getErrorMessage(
|
||||||
|
error,
|
||||||
|
)}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
await Promise.all(discoveryPromises);
|
||||||
|
this.discoveryState = MCPDiscoveryState.COMPLETED;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stops all running local MCP servers and closes all client connections.
|
||||||
|
* This is the cleanup method to be called on application exit.
|
||||||
|
*/
|
||||||
|
async stop(): Promise<void> {
|
||||||
|
const disconnectionPromises = Array.from(this.clients.entries()).map(
|
||||||
|
async ([name, client]) => {
|
||||||
|
try {
|
||||||
|
await client.disconnect();
|
||||||
|
} catch (error) {
|
||||||
|
console.error(
|
||||||
|
`Error stopping client '${name}': ${getErrorMessage(error)}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
await Promise.all(disconnectionPromises);
|
||||||
|
this.clients.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
getDiscoveryState(): MCPDiscoveryState {
|
||||||
|
return this.discoveryState;
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,16 +4,14 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { afterEach, describe, expect, it, vi, beforeEach } from 'vitest';
|
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||||
import {
|
import {
|
||||||
populateMcpServerCommand,
|
populateMcpServerCommand,
|
||||||
createTransport,
|
createTransport,
|
||||||
isEnabled,
|
isEnabled,
|
||||||
discoverTools,
|
|
||||||
discoverPrompts,
|
|
||||||
hasValidTypes,
|
hasValidTypes,
|
||||||
connectToMcpServer,
|
McpClient,
|
||||||
} from './mcp-client.js';
|
} from './mcp-client.js';
|
||||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||||
|
@ -22,26 +20,36 @@ import * as GenAiLib from '@google/genai';
|
||||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||||
import { AuthProviderType } from '../config/config.js';
|
import { AuthProviderType } from '../config/config.js';
|
||||||
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||||
|
import { ToolRegistry } from './tool-registry.js';
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
|
||||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||||
import { pathToFileURL } from 'node:url';
|
|
||||||
|
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
||||||
vi.mock('@google/genai');
|
vi.mock('@google/genai');
|
||||||
vi.mock('../mcp/oauth-provider.js');
|
vi.mock('../mcp/oauth-provider.js');
|
||||||
vi.mock('../mcp/oauth-token-storage.js');
|
vi.mock('../mcp/oauth-token-storage.js');
|
||||||
vi.mock('./mcp-tool.js');
|
|
||||||
|
|
||||||
describe('mcp-client', () => {
|
describe('mcp-client', () => {
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
vi.restoreAllMocks();
|
vi.restoreAllMocks();
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('discoverTools', () => {
|
describe('McpClient', () => {
|
||||||
it('should discover tools', async () => {
|
it('should discover tools', async () => {
|
||||||
const mockedClient = {} as unknown as ClientLib.Client;
|
const mockedClient = {
|
||||||
|
connect: vi.fn(),
|
||||||
|
discover: vi.fn(),
|
||||||
|
disconnect: vi.fn(),
|
||||||
|
getStatus: vi.fn(),
|
||||||
|
registerCapabilities: vi.fn(),
|
||||||
|
setRequestHandler: vi.fn(),
|
||||||
|
};
|
||||||
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
|
mockedClient as unknown as ClientLib.Client,
|
||||||
|
);
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||||
tool: () => ({
|
tool: () => ({
|
||||||
functionDeclarations: [
|
functionDeclarations: [
|
||||||
|
@ -51,62 +59,43 @@ describe('mcp-client', () => {
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
} as unknown as GenAiLib.CallableTool);
|
} as unknown as GenAiLib.CallableTool);
|
||||||
|
const mockedToolRegistry = {
|
||||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
registerTool: vi.fn(),
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
expect(tools.length).toBe(1);
|
const client = new McpClient(
|
||||||
|
'test-server',
|
||||||
|
{
|
||||||
|
command: 'test-command',
|
||||||
|
},
|
||||||
|
mockedToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
{} as WorkspaceContext,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
await client.connect();
|
||||||
|
await client.discover();
|
||||||
expect(mockedMcpToTool).toHaveBeenCalledOnce();
|
expect(mockedMcpToTool).toHaveBeenCalledOnce();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should log an error if there is an error discovering a tool', async () => {
|
|
||||||
const mockedClient = {} as unknown as ClientLib.Client;
|
|
||||||
const consoleErrorSpy = vi
|
|
||||||
.spyOn(console, 'error')
|
|
||||||
.mockImplementation(() => {});
|
|
||||||
|
|
||||||
const testError = new Error('Invalid tool name');
|
|
||||||
vi.mocked(DiscoveredMCPTool).mockImplementation(
|
|
||||||
(
|
|
||||||
_mcpCallableTool: GenAiLib.CallableTool,
|
|
||||||
_serverName: string,
|
|
||||||
name: string,
|
|
||||||
) => {
|
|
||||||
if (name === 'invalid tool name') {
|
|
||||||
throw testError;
|
|
||||||
}
|
|
||||||
return { name: 'validTool' } as DiscoveredMCPTool;
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
|
||||||
tool: () =>
|
|
||||||
Promise.resolve({
|
|
||||||
functionDeclarations: [
|
|
||||||
{
|
|
||||||
name: 'validTool',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: 'invalid tool name', // this will fail validation
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}),
|
|
||||||
} as unknown as GenAiLib.CallableTool);
|
|
||||||
|
|
||||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
|
||||||
|
|
||||||
expect(tools.length).toBe(1);
|
|
||||||
expect(tools[0].name).toBe('validTool');
|
|
||||||
expect(consoleErrorSpy).toHaveBeenCalledOnce();
|
|
||||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
|
||||||
`Error discovering tool: 'invalid tool name' from MCP server 'test-server': ${testError.message}`,
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should skip tools if a parameter is missing a type', async () => {
|
it('should skip tools if a parameter is missing a type', async () => {
|
||||||
const mockedClient = {} as unknown as ClientLib.Client;
|
|
||||||
const consoleWarnSpy = vi
|
const consoleWarnSpy = vi
|
||||||
.spyOn(console, 'warn')
|
.spyOn(console, 'warn')
|
||||||
.mockImplementation(() => {});
|
.mockImplementation(() => {});
|
||||||
|
const mockedClient = {
|
||||||
|
connect: vi.fn(),
|
||||||
|
discover: vi.fn(),
|
||||||
|
disconnect: vi.fn(),
|
||||||
|
getStatus: vi.fn(),
|
||||||
|
registerCapabilities: vi.fn(),
|
||||||
|
setRequestHandler: vi.fn(),
|
||||||
|
tool: vi.fn(),
|
||||||
|
};
|
||||||
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
|
mockedClient as unknown as ClientLib.Client,
|
||||||
|
);
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||||
tool: () =>
|
tool: () =>
|
||||||
Promise.resolve({
|
Promise.resolve({
|
||||||
|
@ -132,352 +121,73 @@ describe('mcp-client', () => {
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
} as unknown as GenAiLib.CallableTool);
|
} as unknown as GenAiLib.CallableTool);
|
||||||
|
const mockedToolRegistry = {
|
||||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
registerTool: vi.fn(),
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
expect(tools.length).toBe(1);
|
const client = new McpClient(
|
||||||
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
|
|
||||||
expect(consoleWarnSpy).toHaveBeenCalledOnce();
|
|
||||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
|
||||||
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
|
|
||||||
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
|
|
||||||
);
|
|
||||||
consoleWarnSpy.mockRestore();
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should skip tools if a nested parameter is missing a type', async () => {
|
|
||||||
const mockedClient = {} as unknown as ClientLib.Client;
|
|
||||||
const consoleWarnSpy = vi
|
|
||||||
.spyOn(console, 'warn')
|
|
||||||
.mockImplementation(() => {});
|
|
||||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
|
||||||
tool: () =>
|
|
||||||
Promise.resolve({
|
|
||||||
functionDeclarations: [
|
|
||||||
{
|
|
||||||
name: 'invalidTool',
|
|
||||||
parametersJsonSchema: {
|
|
||||||
type: 'object',
|
|
||||||
properties: {
|
|
||||||
param1: {
|
|
||||||
type: 'object',
|
|
||||||
properties: {
|
|
||||||
nestedParam: {
|
|
||||||
description: 'a nested param with no type',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}),
|
|
||||||
} as unknown as GenAiLib.CallableTool);
|
|
||||||
|
|
||||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
|
||||||
|
|
||||||
expect(tools.length).toBe(0);
|
|
||||||
expect(consoleWarnSpy).toHaveBeenCalledOnce();
|
|
||||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
|
||||||
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
|
|
||||||
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
|
|
||||||
);
|
|
||||||
consoleWarnSpy.mockRestore();
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should skip tool if an array item is missing a type', async () => {
|
|
||||||
const mockedClient = {} as unknown as ClientLib.Client;
|
|
||||||
const consoleWarnSpy = vi
|
|
||||||
.spyOn(console, 'warn')
|
|
||||||
.mockImplementation(() => {});
|
|
||||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
|
||||||
tool: () =>
|
|
||||||
Promise.resolve({
|
|
||||||
functionDeclarations: [
|
|
||||||
{
|
|
||||||
name: 'invalidTool',
|
|
||||||
parametersJsonSchema: {
|
|
||||||
type: 'object',
|
|
||||||
properties: {
|
|
||||||
param1: {
|
|
||||||
type: 'array',
|
|
||||||
items: {
|
|
||||||
description: 'an array item with no type',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}),
|
|
||||||
} as unknown as GenAiLib.CallableTool);
|
|
||||||
|
|
||||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
|
||||||
|
|
||||||
expect(tools.length).toBe(0);
|
|
||||||
expect(consoleWarnSpy).toHaveBeenCalledOnce();
|
|
||||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
|
||||||
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
|
|
||||||
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
|
|
||||||
);
|
|
||||||
consoleWarnSpy.mockRestore();
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should discover tool with no properties in schema', async () => {
|
|
||||||
const mockedClient = {} as unknown as ClientLib.Client;
|
|
||||||
const consoleWarnSpy = vi
|
|
||||||
.spyOn(console, 'warn')
|
|
||||||
.mockImplementation(() => {});
|
|
||||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
|
||||||
tool: () =>
|
|
||||||
Promise.resolve({
|
|
||||||
functionDeclarations: [
|
|
||||||
{
|
|
||||||
name: 'validTool',
|
|
||||||
parametersJsonSchema: {
|
|
||||||
type: 'object',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}),
|
|
||||||
} as unknown as GenAiLib.CallableTool);
|
|
||||||
|
|
||||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
|
||||||
|
|
||||||
expect(tools.length).toBe(1);
|
|
||||||
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
|
|
||||||
expect(consoleWarnSpy).not.toHaveBeenCalled();
|
|
||||||
consoleWarnSpy.mockRestore();
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should discover tool with empty properties object in schema', async () => {
|
|
||||||
const mockedClient = {} as unknown as ClientLib.Client;
|
|
||||||
const consoleWarnSpy = vi
|
|
||||||
.spyOn(console, 'warn')
|
|
||||||
.mockImplementation(() => {});
|
|
||||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
|
||||||
tool: () =>
|
|
||||||
Promise.resolve({
|
|
||||||
functionDeclarations: [
|
|
||||||
{
|
|
||||||
name: 'validTool',
|
|
||||||
parametersJsonSchema: {
|
|
||||||
type: 'object',
|
|
||||||
properties: {},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}),
|
|
||||||
} as unknown as GenAiLib.CallableTool);
|
|
||||||
|
|
||||||
const tools = await discoverTools('test-server', {}, mockedClient);
|
|
||||||
|
|
||||||
expect(tools.length).toBe(1);
|
|
||||||
expect(vi.mocked(DiscoveredMCPTool).mock.calls[0][2]).toBe('validTool');
|
|
||||||
expect(consoleWarnSpy).not.toHaveBeenCalled();
|
|
||||||
consoleWarnSpy.mockRestore();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('connectToMcpServer', () => {
|
|
||||||
it('should send a notification when directories change', async () => {
|
|
||||||
const mockedClient = {
|
|
||||||
registerCapabilities: vi.fn(),
|
|
||||||
setRequestHandler: vi.fn(),
|
|
||||||
notification: vi.fn(),
|
|
||||||
callTool: vi.fn(),
|
|
||||||
connect: vi.fn(),
|
|
||||||
};
|
|
||||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
||||||
mockedClient as unknown as ClientLib.Client,
|
|
||||||
);
|
|
||||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
||||||
{} as SdkClientStdioLib.StdioClientTransport,
|
|
||||||
);
|
|
||||||
let onDirectoriesChangedCallback: () => void = () => {};
|
|
||||||
const mockWorkspaceContext = {
|
|
||||||
getDirectories: vi
|
|
||||||
.fn()
|
|
||||||
.mockReturnValue(['/test/dir', '/another/project']),
|
|
||||||
onDirectoriesChanged: vi.fn().mockImplementation((callback) => {
|
|
||||||
onDirectoriesChangedCallback = callback;
|
|
||||||
}),
|
|
||||||
} as unknown as WorkspaceContext;
|
|
||||||
|
|
||||||
await connectToMcpServer(
|
|
||||||
'test-server',
|
'test-server',
|
||||||
{
|
{
|
||||||
command: 'test-command',
|
command: 'test-command',
|
||||||
},
|
},
|
||||||
|
mockedToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
{} as WorkspaceContext,
|
||||||
false,
|
false,
|
||||||
mockWorkspaceContext,
|
|
||||||
);
|
);
|
||||||
|
await client.connect();
|
||||||
onDirectoriesChangedCallback();
|
await client.discover();
|
||||||
|
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||||
expect(mockedClient.notification).toHaveBeenCalledWith({
|
expect(consoleWarnSpy).toHaveBeenCalledOnce();
|
||||||
method: 'notifications/roots/list_changed',
|
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||||
});
|
`Skipping tool 'invalidTool' from MCP server 'test-server' because it has ` +
|
||||||
|
`missing types in its parameter schema. Please file an issue with the owner of the MCP server.`,
|
||||||
|
);
|
||||||
|
consoleWarnSpy.mockRestore();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should register a roots/list handler', async () => {
|
it('should handle errors when discovering prompts', async () => {
|
||||||
const mockedClient = {
|
|
||||||
registerCapabilities: vi.fn(),
|
|
||||||
setRequestHandler: vi.fn(),
|
|
||||||
callTool: vi.fn(),
|
|
||||||
connect: vi.fn(),
|
|
||||||
};
|
|
||||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
|
||||||
mockedClient as unknown as ClientLib.Client,
|
|
||||||
);
|
|
||||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
|
||||||
{} as SdkClientStdioLib.StdioClientTransport,
|
|
||||||
);
|
|
||||||
const mockWorkspaceContext = {
|
|
||||||
getDirectories: vi
|
|
||||||
.fn()
|
|
||||||
.mockReturnValue(['/test/dir', '/another/project']),
|
|
||||||
onDirectoriesChanged: vi.fn(),
|
|
||||||
} as unknown as WorkspaceContext;
|
|
||||||
|
|
||||||
await connectToMcpServer(
|
|
||||||
'test-server',
|
|
||||||
{
|
|
||||||
command: 'test-command',
|
|
||||||
},
|
|
||||||
false,
|
|
||||||
mockWorkspaceContext,
|
|
||||||
);
|
|
||||||
|
|
||||||
expect(mockedClient.registerCapabilities).toHaveBeenCalledWith({
|
|
||||||
roots: {
|
|
||||||
listChanged: true,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
expect(mockedClient.setRequestHandler).toHaveBeenCalledOnce();
|
|
||||||
const handler = mockedClient.setRequestHandler.mock.calls[0][1];
|
|
||||||
const roots = await handler();
|
|
||||||
expect(roots).toEqual({
|
|
||||||
roots: [
|
|
||||||
{
|
|
||||||
uri: pathToFileURL('/test/dir').toString(),
|
|
||||||
name: 'dir',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
uri: pathToFileURL('/another/project').toString(),
|
|
||||||
name: 'project',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('discoverPrompts', () => {
|
|
||||||
const mockedPromptRegistry = {
|
|
||||||
registerPrompt: vi.fn(),
|
|
||||||
} as unknown as PromptRegistry;
|
|
||||||
|
|
||||||
it('should discover and log prompts', async () => {
|
|
||||||
const mockRequest = vi.fn().mockResolvedValue({
|
|
||||||
prompts: [
|
|
||||||
{ name: 'prompt1', description: 'desc1' },
|
|
||||||
{ name: 'prompt2' },
|
|
||||||
],
|
|
||||||
});
|
|
||||||
const mockGetServerCapabilities = vi.fn().mockReturnValue({
|
|
||||||
prompts: {},
|
|
||||||
});
|
|
||||||
const mockedClient = {
|
|
||||||
getServerCapabilities: mockGetServerCapabilities,
|
|
||||||
request: mockRequest,
|
|
||||||
} as unknown as ClientLib.Client;
|
|
||||||
|
|
||||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
|
||||||
|
|
||||||
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
|
|
||||||
expect(mockRequest).toHaveBeenCalledWith(
|
|
||||||
{ method: 'prompts/list', params: {} },
|
|
||||||
expect.anything(),
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should do nothing if no prompts are discovered', async () => {
|
|
||||||
const mockRequest = vi.fn().mockResolvedValue({
|
|
||||||
prompts: [],
|
|
||||||
});
|
|
||||||
const mockGetServerCapabilities = vi.fn().mockReturnValue({
|
|
||||||
prompts: {},
|
|
||||||
});
|
|
||||||
|
|
||||||
const mockedClient = {
|
|
||||||
getServerCapabilities: mockGetServerCapabilities,
|
|
||||||
request: mockRequest,
|
|
||||||
} as unknown as ClientLib.Client;
|
|
||||||
|
|
||||||
const consoleLogSpy = vi
|
|
||||||
.spyOn(console, 'debug')
|
|
||||||
.mockImplementation(() => {});
|
|
||||||
|
|
||||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
|
||||||
|
|
||||||
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
|
|
||||||
expect(mockRequest).toHaveBeenCalledOnce();
|
|
||||||
expect(consoleLogSpy).not.toHaveBeenCalled();
|
|
||||||
|
|
||||||
consoleLogSpy.mockRestore();
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should do nothing if the server has no prompt support', async () => {
|
|
||||||
const mockRequest = vi.fn().mockResolvedValue({
|
|
||||||
prompts: [],
|
|
||||||
});
|
|
||||||
const mockGetServerCapabilities = vi.fn().mockReturnValue({});
|
|
||||||
|
|
||||||
const mockedClient = {
|
|
||||||
getServerCapabilities: mockGetServerCapabilities,
|
|
||||||
request: mockRequest,
|
|
||||||
} as unknown as ClientLib.Client;
|
|
||||||
|
|
||||||
const consoleLogSpy = vi
|
|
||||||
.spyOn(console, 'debug')
|
|
||||||
.mockImplementation(() => {});
|
|
||||||
|
|
||||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
|
||||||
|
|
||||||
expect(mockGetServerCapabilities).toHaveBeenCalledOnce();
|
|
||||||
expect(mockRequest).not.toHaveBeenCalled();
|
|
||||||
expect(consoleLogSpy).not.toHaveBeenCalled();
|
|
||||||
|
|
||||||
consoleLogSpy.mockRestore();
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should log an error if discovery fails', async () => {
|
|
||||||
const testError = new Error('test error');
|
|
||||||
testError.message = 'test error';
|
|
||||||
const mockRequest = vi.fn().mockRejectedValue(testError);
|
|
||||||
const mockGetServerCapabilities = vi.fn().mockReturnValue({
|
|
||||||
prompts: {},
|
|
||||||
});
|
|
||||||
const mockedClient = {
|
|
||||||
getServerCapabilities: mockGetServerCapabilities,
|
|
||||||
request: mockRequest,
|
|
||||||
} as unknown as ClientLib.Client;
|
|
||||||
|
|
||||||
const consoleErrorSpy = vi
|
const consoleErrorSpy = vi
|
||||||
.spyOn(console, 'error')
|
.spyOn(console, 'error')
|
||||||
.mockImplementation(() => {});
|
.mockImplementation(() => {});
|
||||||
|
const mockedClient = {
|
||||||
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
connect: vi.fn(),
|
||||||
|
discover: vi.fn(),
|
||||||
expect(mockRequest).toHaveBeenCalledOnce();
|
disconnect: vi.fn(),
|
||||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
getStatus: vi.fn(),
|
||||||
`Error discovering prompts from test-server: ${testError.message}`,
|
registerCapabilities: vi.fn(),
|
||||||
|
setRequestHandler: vi.fn(),
|
||||||
|
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
|
||||||
|
request: vi.fn().mockRejectedValue(new Error('Test error')),
|
||||||
|
};
|
||||||
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
|
mockedClient as unknown as ClientLib.Client,
|
||||||
|
);
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
|
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||||
|
tool: () => Promise.resolve({ functionDeclarations: [] }),
|
||||||
|
} as unknown as GenAiLib.CallableTool);
|
||||||
|
const client = new McpClient(
|
||||||
|
'test-server',
|
||||||
|
{
|
||||||
|
command: 'test-command',
|
||||||
|
},
|
||||||
|
{} as ToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
{} as WorkspaceContext,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
await client.connect();
|
||||||
|
await expect(client.discover()).rejects.toThrow(
|
||||||
|
'No prompts or tools found on the server.',
|
||||||
|
);
|
||||||
|
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||||
|
`Error discovering prompts from test-server: Test error`,
|
||||||
);
|
);
|
||||||
|
|
||||||
consoleErrorSpy.mockRestore();
|
consoleErrorSpy.mockRestore();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('appendMcpServerCommand', () => {
|
describe('appendMcpServerCommand', () => {
|
||||||
it('should do nothing if no MCP servers or command are configured', () => {
|
it('should do nothing if no MCP servers or command are configured', () => {
|
||||||
const out = populateMcpServerCommand({}, undefined);
|
const out = populateMcpServerCommand({}, undefined);
|
||||||
|
@ -501,17 +211,6 @@ describe('mcp-client', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('createTransport', () => {
|
describe('createTransport', () => {
|
||||||
const originalEnv = process.env;
|
|
||||||
|
|
||||||
beforeEach(() => {
|
|
||||||
vi.resetModules();
|
|
||||||
process.env = {};
|
|
||||||
});
|
|
||||||
|
|
||||||
afterEach(() => {
|
|
||||||
process.env = originalEnv;
|
|
||||||
});
|
|
||||||
|
|
||||||
describe('should connect via httpUrl', () => {
|
describe('should connect via httpUrl', () => {
|
||||||
it('without headers', async () => {
|
it('without headers', async () => {
|
||||||
const transport = await createTransport(
|
const transport = await createTransport(
|
||||||
|
@ -601,7 +300,7 @@ describe('mcp-client', () => {
|
||||||
command: 'test-command',
|
command: 'test-command',
|
||||||
args: ['--foo', 'bar'],
|
args: ['--foo', 'bar'],
|
||||||
cwd: 'test/cwd',
|
cwd: 'test/cwd',
|
||||||
env: { FOO: 'bar' },
|
env: { ...process.env, FOO: 'bar' },
|
||||||
stderr: 'pipe',
|
stderr: 'pipe',
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -69,6 +69,134 @@ export enum MCPDiscoveryState {
|
||||||
COMPLETED = 'completed',
|
COMPLETED = 'completed',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A client for a single MCP server.
|
||||||
|
*
|
||||||
|
* This class is responsible for connecting to, discovering tools from, and
|
||||||
|
* managing the state of a single MCP server.
|
||||||
|
*/
|
||||||
|
export class McpClient {
|
||||||
|
private client: Client;
|
||||||
|
private transport: Transport | undefined;
|
||||||
|
private status: MCPServerStatus = MCPServerStatus.DISCONNECTED;
|
||||||
|
private isDisconnecting = false;
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
private readonly serverName: string,
|
||||||
|
private readonly serverConfig: MCPServerConfig,
|
||||||
|
private readonly toolRegistry: ToolRegistry,
|
||||||
|
private readonly promptRegistry: PromptRegistry,
|
||||||
|
private readonly workspaceContext: WorkspaceContext,
|
||||||
|
private readonly debugMode: boolean,
|
||||||
|
) {
|
||||||
|
this.client = new Client({
|
||||||
|
name: `gemini-cli-mcp-client-${this.serverName}`,
|
||||||
|
version: '0.0.1',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Connects to the MCP server.
|
||||||
|
*/
|
||||||
|
async connect(): Promise<void> {
|
||||||
|
this.isDisconnecting = false;
|
||||||
|
this.updateStatus(MCPServerStatus.CONNECTING);
|
||||||
|
try {
|
||||||
|
this.transport = await this.createTransport();
|
||||||
|
|
||||||
|
this.client.onerror = (error) => {
|
||||||
|
if (this.isDisconnecting) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
console.error(`MCP ERROR (${this.serverName}):`, error.toString());
|
||||||
|
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||||
|
};
|
||||||
|
|
||||||
|
this.client.registerCapabilities({
|
||||||
|
roots: {},
|
||||||
|
});
|
||||||
|
|
||||||
|
this.client.setRequestHandler(ListRootsRequestSchema, async () => {
|
||||||
|
const roots = [];
|
||||||
|
for (const dir of this.workspaceContext.getDirectories()) {
|
||||||
|
roots.push({
|
||||||
|
uri: pathToFileURL(dir).toString(),
|
||||||
|
name: basename(dir),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
roots,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
await this.client.connect(this.transport, {
|
||||||
|
timeout: this.serverConfig.timeout,
|
||||||
|
});
|
||||||
|
|
||||||
|
this.updateStatus(MCPServerStatus.CONNECTED);
|
||||||
|
} catch (error) {
|
||||||
|
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Discovers tools and prompts from the MCP server.
|
||||||
|
*/
|
||||||
|
async discover(): Promise<void> {
|
||||||
|
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||||
|
throw new Error('Client is not connected.');
|
||||||
|
}
|
||||||
|
|
||||||
|
const prompts = await this.discoverPrompts();
|
||||||
|
const tools = await this.discoverTools();
|
||||||
|
|
||||||
|
if (prompts.length === 0 && tools.length === 0) {
|
||||||
|
throw new Error('No prompts or tools found on the server.');
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const tool of tools) {
|
||||||
|
this.toolRegistry.registerTool(tool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Disconnects from the MCP server.
|
||||||
|
*/
|
||||||
|
async disconnect(): Promise<void> {
|
||||||
|
this.isDisconnecting = true;
|
||||||
|
if (this.transport) {
|
||||||
|
await this.transport.close();
|
||||||
|
}
|
||||||
|
this.client.close();
|
||||||
|
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the current status of the client.
|
||||||
|
*/
|
||||||
|
getStatus(): MCPServerStatus {
|
||||||
|
return this.status;
|
||||||
|
}
|
||||||
|
|
||||||
|
private updateStatus(status: MCPServerStatus): void {
|
||||||
|
this.status = status;
|
||||||
|
updateMCPServerStatus(this.serverName, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async createTransport(): Promise<Transport> {
|
||||||
|
return createTransport(this.serverName, this.serverConfig, this.debugMode);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async discoverTools(): Promise<DiscoveredMCPTool[]> {
|
||||||
|
return discoverTools(this.serverName, this.serverConfig, this.client);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async discoverPrompts(): Promise<Prompt[]> {
|
||||||
|
return discoverPrompts(this.serverName, this.client, this.promptRegistry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Map to track the status of each MCP server within the core package
|
* Map to track the status of each MCP server within the core package
|
||||||
*/
|
*/
|
||||||
|
@ -117,7 +245,7 @@ export function removeMCPStatusChangeListener(
|
||||||
/**
|
/**
|
||||||
* Update the status of an MCP server
|
* Update the status of an MCP server
|
||||||
*/
|
*/
|
||||||
function updateMCPServerStatus(
|
export function updateMCPServerStatus(
|
||||||
serverName: string,
|
serverName: string,
|
||||||
status: MCPServerStatus,
|
status: MCPServerStatus,
|
||||||
): void {
|
): void {
|
||||||
|
|
|
@ -23,15 +23,17 @@ import { spawn } from 'node:child_process';
|
||||||
import fs from 'node:fs';
|
import fs from 'node:fs';
|
||||||
import { MockTool } from '../test-utils/tools.js';
|
import { MockTool } from '../test-utils/tools.js';
|
||||||
|
|
||||||
|
import { McpClientManager } from './mcp-client-manager.js';
|
||||||
|
|
||||||
vi.mock('node:fs');
|
vi.mock('node:fs');
|
||||||
|
|
||||||
// Use vi.hoisted to define the mock function so it can be used in the vi.mock factory
|
|
||||||
const mockDiscoverMcpTools = vi.hoisted(() => vi.fn());
|
|
||||||
|
|
||||||
// Mock ./mcp-client.js to control its behavior within tool-registry tests
|
// Mock ./mcp-client.js to control its behavior within tool-registry tests
|
||||||
vi.mock('./mcp-client.js', () => ({
|
vi.mock('./mcp-client.js', async () => {
|
||||||
discoverMcpTools: mockDiscoverMcpTools,
|
const originalModule = await vi.importActual('./mcp-client.js');
|
||||||
}));
|
return {
|
||||||
|
...originalModule,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
// Mock node:child_process
|
// Mock node:child_process
|
||||||
vi.mock('node:child_process', async () => {
|
vi.mock('node:child_process', async () => {
|
||||||
|
@ -143,7 +145,6 @@ describe('ToolRegistry', () => {
|
||||||
clear: vi.fn(),
|
clear: vi.fn(),
|
||||||
removePromptsByServer: vi.fn(),
|
removePromptsByServer: vi.fn(),
|
||||||
} as any);
|
} as any);
|
||||||
mockDiscoverMcpTools.mockReset().mockResolvedValue(undefined);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
@ -311,6 +312,10 @@ describe('ToolRegistry', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||||
|
const discoverSpy = vi.spyOn(
|
||||||
|
McpClientManager.prototype,
|
||||||
|
'discoverAllMcpTools',
|
||||||
|
);
|
||||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||||
const mcpServerConfigVal = {
|
const mcpServerConfigVal = {
|
||||||
|
@ -324,38 +329,7 @@ describe('ToolRegistry', () => {
|
||||||
|
|
||||||
await toolRegistry.discoverAllTools();
|
await toolRegistry.discoverAllTools();
|
||||||
|
|
||||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
expect(discoverSpy).toHaveBeenCalled();
|
||||||
mcpServerConfigVal,
|
|
||||||
undefined,
|
|
||||||
toolRegistry,
|
|
||||||
config.getPromptRegistry(),
|
|
||||||
false,
|
|
||||||
expect.any(Object),
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
|
||||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
|
||||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
|
||||||
const mcpServerConfigVal = {
|
|
||||||
'my-mcp-server': {
|
|
||||||
command: 'mcp-server-cmd',
|
|
||||||
args: ['--port', '1234'],
|
|
||||||
trust: true,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
|
||||||
|
|
||||||
await toolRegistry.discoverAllTools();
|
|
||||||
|
|
||||||
expect(mockDiscoverMcpTools).toHaveBeenCalledWith(
|
|
||||||
mcpServerConfigVal,
|
|
||||||
undefined,
|
|
||||||
toolRegistry,
|
|
||||||
config.getPromptRegistry(),
|
|
||||||
false,
|
|
||||||
expect.any(Object),
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -16,7 +16,8 @@ import {
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
import { spawn } from 'node:child_process';
|
import { spawn } from 'node:child_process';
|
||||||
import { StringDecoder } from 'node:string_decoder';
|
import { StringDecoder } from 'node:string_decoder';
|
||||||
import { discoverMcpTools } from './mcp-client.js';
|
import { connectAndDiscover } from './mcp-client.js';
|
||||||
|
import { McpClientManager } from './mcp-client-manager.js';
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||||
import { parse } from 'shell-quote';
|
import { parse } from 'shell-quote';
|
||||||
|
|
||||||
|
@ -163,9 +164,18 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
||||||
export class ToolRegistry {
|
export class ToolRegistry {
|
||||||
private tools: Map<string, AnyDeclarativeTool> = new Map();
|
private tools: Map<string, AnyDeclarativeTool> = new Map();
|
||||||
private config: Config;
|
private config: Config;
|
||||||
|
private mcpClientManager: McpClientManager;
|
||||||
|
|
||||||
constructor(config: Config) {
|
constructor(config: Config) {
|
||||||
this.config = config;
|
this.config = config;
|
||||||
|
this.mcpClientManager = new McpClientManager(
|
||||||
|
this.config.getMcpServers() ?? {},
|
||||||
|
this.config.getMcpServerCommand(),
|
||||||
|
this,
|
||||||
|
this.config.getPromptRegistry(),
|
||||||
|
this.config.getDebugMode(),
|
||||||
|
this.config.getWorkspaceContext(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -220,14 +230,7 @@ export class ToolRegistry {
|
||||||
await this.discoverAndRegisterToolsFromCommand();
|
await this.discoverAndRegisterToolsFromCommand();
|
||||||
|
|
||||||
// discover tools using MCP servers, if configured
|
// discover tools using MCP servers, if configured
|
||||||
await discoverMcpTools(
|
await this.mcpClientManager.discoverAllMcpTools();
|
||||||
this.config.getMcpServers() ?? {},
|
|
||||||
this.config.getMcpServerCommand(),
|
|
||||||
this,
|
|
||||||
this.config.getPromptRegistry(),
|
|
||||||
this.config.getDebugMode(),
|
|
||||||
this.config.getWorkspaceContext(),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -242,14 +245,14 @@ export class ToolRegistry {
|
||||||
this.config.getPromptRegistry().clear();
|
this.config.getPromptRegistry().clear();
|
||||||
|
|
||||||
// discover tools using MCP servers, if configured
|
// discover tools using MCP servers, if configured
|
||||||
await discoverMcpTools(
|
await this.mcpClientManager.discoverAllMcpTools();
|
||||||
this.config.getMcpServers() ?? {},
|
}
|
||||||
this.config.getMcpServerCommand(),
|
|
||||||
this,
|
/**
|
||||||
this.config.getPromptRegistry(),
|
* Restarts all MCP servers and re-discovers tools.
|
||||||
this.config.getDebugMode(),
|
*/
|
||||||
this.config.getWorkspaceContext(),
|
async restartMcpServers(): Promise<void> {
|
||||||
);
|
await this.discoverMcpTools();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -269,9 +272,9 @@ export class ToolRegistry {
|
||||||
const mcpServers = this.config.getMcpServers() ?? {};
|
const mcpServers = this.config.getMcpServers() ?? {};
|
||||||
const serverConfig = mcpServers[serverName];
|
const serverConfig = mcpServers[serverName];
|
||||||
if (serverConfig) {
|
if (serverConfig) {
|
||||||
await discoverMcpTools(
|
await connectAndDiscover(
|
||||||
{ [serverName]: serverConfig },
|
serverName,
|
||||||
undefined,
|
serverConfig,
|
||||||
this,
|
this,
|
||||||
this.config.getPromptRegistry(),
|
this.config.getPromptRegistry(),
|
||||||
this.config.getDebugMode(),
|
this.config.getDebugMode(),
|
||||||
|
|
Loading…
Reference in New Issue