diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index e52cb9df..2b8753a0 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -976,4 +976,84 @@ describe('mcpCommand', () => { } }); }); + + describe('refresh subcommand', () => { + it('should refresh the list of tools and display the status', async () => { + const mockToolRegistry = { + discoverMcpTools: vi.fn(), + getAllTools: vi.fn().mockReturnValue([]), + }; + const mockGeminiClient = { + setTools: vi.fn(), + }; + + const context = createMockCommandContext({ + services: { + config: { + getMcpServers: vi.fn().mockReturnValue({ server1: {} }), + getBlockedMcpServers: vi.fn().mockReturnValue([]), + getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry), + getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), + }, + }, + }); + + const refreshCommand = mcpCommand.subCommands?.find( + (cmd) => cmd.name === 'refresh', + ); + expect(refreshCommand).toBeDefined(); + + const result = await refreshCommand!.action!(context, ''); + + expect(context.ui.addItem).toHaveBeenCalledWith( + { + type: 'info', + text: 'Refreshing MCP servers and tools...', + }, + expect.any(Number), + ); + expect(mockToolRegistry.discoverMcpTools).toHaveBeenCalled(); + expect(mockGeminiClient.setTools).toHaveBeenCalled(); + + expect(isMessageAction(result)).toBe(true); + if (isMessageAction(result)) { + expect(result.messageType).toBe('info'); + expect(result.content).toContain('Configured MCP servers:'); + } + }); + + it('should show an error if config is not available', async () => { + const contextWithoutConfig = createMockCommandContext({ + services: { + config: null, + }, + }); + + const refreshCommand = mcpCommand.subCommands?.find( + (cmd) => cmd.name === 'refresh', + ); + const result = await refreshCommand!.action!(contextWithoutConfig, ''); + + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Config not loaded.', + }); + }); + + it('should show an error if tool registry is not available', async () => { + mockConfig.getToolRegistry = vi.fn().mockResolvedValue(undefined); + + const refreshCommand = mcpCommand.subCommands?.find( + (cmd) => cmd.name === 'refresh', + ); + const result = await refreshCommand!.action!(mockContext, ''); + + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Could not retrieve tool registry.', + }); + }); + }); }); diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index c33a25d1..5467b994 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -417,12 +417,57 @@ const listCommand: SlashCommand = { }, }; +const refreshCommand: SlashCommand = { + name: 'refresh', + description: 'Refresh the list of MCP servers and tools', + kind: CommandKind.BUILT_IN, + action: async ( + context: CommandContext, + ): Promise => { + const { config } = context.services; + if (!config) { + return { + type: 'message', + messageType: 'error', + content: 'Config not loaded.', + }; + } + + const toolRegistry = await config.getToolRegistry(); + if (!toolRegistry) { + return { + type: 'message', + messageType: 'error', + content: 'Could not retrieve tool registry.', + }; + } + + context.ui.addItem( + { + type: 'info', + text: 'Refreshing MCP servers and tools...', + }, + Date.now(), + ); + + await toolRegistry.discoverMcpTools(); + + // Update the client with the new tools + const geminiClient = config.getGeminiClient(); + if (geminiClient) { + await geminiClient.setTools(); + } + + return getMcpStatus(context, false, false, false); + }, +}; + export const mcpCommand: SlashCommand = { name: 'mcp', description: 'list configured MCP servers and tools, or authenticate with OAuth-enabled servers', kind: CommandKind.BUILT_IN, - subCommands: [listCommand, authCommand], + subCommands: [listCommand, authCommand, refreshCommand], // Default action when no subcommand is provided action: async (context: CommandContext, args: string) => // If no subcommand, run the list command diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 9fec505f..3f0b3db5 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -23,7 +23,7 @@ import { GitService } from '../services/gitService.js'; vi.mock('../tools/tool-registry', () => { const ToolRegistryMock = vi.fn(); ToolRegistryMock.prototype.registerTool = vi.fn(); - ToolRegistryMock.prototype.discoverTools = vi.fn(); + ToolRegistryMock.prototype.discoverAllTools = vi.fn(); ToolRegistryMock.prototype.getAllTools = vi.fn(() => []); // Mock methods if needed ToolRegistryMock.prototype.getTool = vi.fn(); ToolRegistryMock.prototype.getFunctionDeclarations = vi.fn(() => []); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 231bbcd5..485a56c4 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -630,7 +630,7 @@ export class Config { registerCoreTool(MemoryTool); registerCoreTool(WebSearchTool, this); - await registry.discoverTools(); + await registry.discoverAllTools(); return registry; } } diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index ab337252..de355a98 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -312,7 +312,7 @@ describe('ToolRegistry', () => { return mockChildProcess as any; }); - await toolRegistry.discoverTools(); + await toolRegistry.discoverAllTools(); const discoveredTool = toolRegistry.getTool('tool-with-bad-format'); expect(discoveredTool).toBeDefined(); @@ -338,7 +338,7 @@ describe('ToolRegistry', () => { }; vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal); - await toolRegistry.discoverTools(); + await toolRegistry.discoverAllTools(); expect(mockDiscoverMcpTools).toHaveBeenCalledWith( mcpServerConfigVal, @@ -360,7 +360,7 @@ describe('ToolRegistry', () => { }; vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal); - await toolRegistry.discoverTools(); + await toolRegistry.discoverAllTools(); expect(mockDiscoverMcpTools).toHaveBeenCalledWith( mcpServerConfigVal, diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index a6742c06..b72ed9a5 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -153,8 +153,9 @@ export class ToolRegistry { /** * Discovers tools from project (if available and configured). * Can be called multiple times to update discovered tools. + * This will discover tools from the command line and from MCP servers. */ - async discoverTools(): Promise { + async discoverAllTools(): Promise { // remove any previously discovered tools for (const tool of this.tools.values()) { if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) { @@ -173,6 +174,28 @@ export class ToolRegistry { ); } + /** + * Discovers tools from project (if available and configured). + * Can be called multiple times to update discovered tools. + * This will NOT discover tools from the command line, only from MCP servers. + */ + async discoverMcpTools(): Promise { + // remove any previously discovered tools + for (const tool of this.tools.values()) { + if (tool instanceof DiscoveredMCPTool) { + this.tools.delete(tool.name); + } + } + + // discover tools using MCP servers, if configured + await discoverMcpTools( + this.config.getMcpServers() ?? {}, + this.config.getMcpServerCommand(), + this, + this.config.getDebugMode(), + ); + } + /** * Discover or re-discover tools for a single MCP server. * @param serverName - The name of the server to discover tools from.