diff --git a/packages/cli/src/services/BuiltinCommandLoader.test.ts b/packages/cli/src/services/BuiltinCommandLoader.test.ts index 0e64b1ac..cd449dd8 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.test.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.test.ts @@ -40,13 +40,19 @@ vi.mock('../ui/commands/extensionsCommand.js', () => ({ extensionsCommand: {}, })); vi.mock('../ui/commands/helpCommand.js', () => ({ helpCommand: {} })); -vi.mock('../ui/commands/mcpCommand.js', () => ({ mcpCommand: {} })); vi.mock('../ui/commands/memoryCommand.js', () => ({ memoryCommand: {} })); vi.mock('../ui/commands/privacyCommand.js', () => ({ privacyCommand: {} })); vi.mock('../ui/commands/quitCommand.js', () => ({ quitCommand: {} })); vi.mock('../ui/commands/statsCommand.js', () => ({ statsCommand: {} })); vi.mock('../ui/commands/themeCommand.js', () => ({ themeCommand: {} })); vi.mock('../ui/commands/toolsCommand.js', () => ({ toolsCommand: {} })); +vi.mock('../ui/commands/mcpCommand.js', () => ({ + mcpCommand: { + name: 'mcp', + description: 'MCP command', + kind: 'BUILT_IN', + }, +})); describe('BuiltinCommandLoader', () => { let mockConfig: Config; @@ -114,5 +120,8 @@ describe('BuiltinCommandLoader', () => { const ideCmd = commands.find((c) => c.name === 'ide'); expect(ideCmd).toBeDefined(); + + const mcpCmd = commands.find((c) => c.name === 'mcp'); + expect(mcpCmd).toBeDefined(); }); }); diff --git a/packages/cli/src/services/BuiltinCommandLoader.ts b/packages/cli/src/services/BuiltinCommandLoader.ts index 259c6013..58adf5cb 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.ts @@ -58,9 +58,9 @@ export class BuiltinCommandLoader implements ICommandLoader { extensionsCommand, helpCommand, ideCommand(this.config), - mcpCommand, memoryCommand, privacyCommand, + mcpCommand, quitCommand, restoreCommand(this.config), statsCommand, diff --git a/packages/cli/src/services/McpPromptLoader.ts b/packages/cli/src/services/McpPromptLoader.ts new file mode 100644 index 00000000..e912fb3e --- /dev/null +++ b/packages/cli/src/services/McpPromptLoader.ts @@ -0,0 +1,231 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Config, + getErrorMessage, + getMCPServerPrompts, +} from '@google/gemini-cli-core'; +import { + CommandContext, + CommandKind, + SlashCommand, + SlashCommandActionReturn, +} from '../ui/commands/types.js'; +import { ICommandLoader } from './types.js'; +import { PromptArgument } from '@modelcontextprotocol/sdk/types.js'; + +/** + * Discovers and loads executable slash commands from prompts exposed by + * Model-Context-Protocol (MCP) servers. + */ +export class McpPromptLoader implements ICommandLoader { + constructor(private readonly config: Config | null) {} + + /** + * Loads all available prompts from all configured MCP servers and adapts + * them into executable SlashCommand objects. + * + * @param _signal An AbortSignal (unused for this synchronous loader). + * @returns A promise that resolves to an array of loaded SlashCommands. + */ + loadCommands(_signal: AbortSignal): Promise { + const promptCommands: SlashCommand[] = []; + if (!this.config) { + return Promise.resolve([]); + } + const mcpServers = this.config.getMcpServers() || {}; + for (const serverName in mcpServers) { + const prompts = getMCPServerPrompts(this.config, serverName) || []; + for (const prompt of prompts) { + const commandName = `${prompt.name}`; + const newPromptCommand: SlashCommand = { + name: commandName, + description: prompt.description || `Invoke prompt ${prompt.name}`, + kind: CommandKind.MCP_PROMPT, + subCommands: [ + { + name: 'help', + description: 'Show help for this prompt', + kind: CommandKind.MCP_PROMPT, + action: async (): Promise => { + if (!prompt.arguments || prompt.arguments.length === 0) { + return { + type: 'message', + messageType: 'info', + content: `Prompt "${prompt.name}" has no arguments.`, + }; + } + + let helpMessage = `Arguments for "${prompt.name}":\n\n`; + if (prompt.arguments && prompt.arguments.length > 0) { + helpMessage += `You can provide arguments by name (e.g., --argName="value") or by position.\n\n`; + helpMessage += `e.g., ${prompt.name} ${prompt.arguments?.map((_) => `"foo"`)} is equivalent to ${prompt.name} ${prompt.arguments?.map((arg) => `--${arg.name}="foo"`)}\n\n`; + } + for (const arg of prompt.arguments) { + helpMessage += ` --${arg.name}\n`; + if (arg.description) { + helpMessage += ` ${arg.description}\n`; + } + helpMessage += ` (required: ${ + arg.required ? 'yes' : 'no' + })\n\n`; + } + return { + type: 'message', + messageType: 'info', + content: helpMessage, + }; + }, + }, + ], + action: async ( + context: CommandContext, + args: string, + ): Promise => { + if (!this.config) { + return { + type: 'message', + messageType: 'error', + content: 'Config not loaded.', + }; + } + + const promptInputs = this.parseArgs(args, prompt.arguments); + if (promptInputs instanceof Error) { + return { + type: 'message', + messageType: 'error', + content: promptInputs.message, + }; + } + + try { + const mcpServers = this.config.getMcpServers() || {}; + const mcpServerConfig = mcpServers[serverName]; + if (!mcpServerConfig) { + return { + type: 'message', + messageType: 'error', + content: `MCP server config not found for '${serverName}'.`, + }; + } + const result = await prompt.invoke(promptInputs); + + if (result.error) { + return { + type: 'message', + messageType: 'error', + content: `Error invoking prompt: ${result.error}`, + }; + } + + if (!result.messages?.[0]?.content?.text) { + return { + type: 'message', + messageType: 'error', + content: + 'Received an empty or invalid prompt response from the server.', + }; + } + + return { + type: 'submit_prompt', + content: JSON.stringify(result.messages[0].content.text), + }; + } catch (error) { + return { + type: 'message', + messageType: 'error', + content: `Error: ${getErrorMessage(error)}`, + }; + } + }, + completion: async (_: CommandContext, partialArg: string) => { + if (!prompt || !prompt.arguments) { + return []; + } + + const suggestions: string[] = []; + const usedArgNames = new Set( + (partialArg.match(/--([^=]+)/g) || []).map((s) => s.substring(2)), + ); + + for (const arg of prompt.arguments) { + if (!usedArgNames.has(arg.name)) { + suggestions.push(`--${arg.name}=""`); + } + } + + return suggestions; + }, + }; + promptCommands.push(newPromptCommand); + } + } + return Promise.resolve(promptCommands); + } + + private parseArgs( + userArgs: string, + promptArgs: PromptArgument[] | undefined, + ): Record | Error { + const argValues: { [key: string]: string } = {}; + const promptInputs: Record = {}; + + // arg parsing: --key="value" or --key=value + const namedArgRegex = /--([^=]+)=(?:"((?:\\.|[^"\\])*)"|([^ ]*))/g; + let match; + const remainingArgs: string[] = []; + let lastIndex = 0; + + while ((match = namedArgRegex.exec(userArgs)) !== null) { + const key = match[1]; + const value = match[2] ?? match[3]; // Quoted or unquoted value + argValues[key] = value; + // Capture text between matches as potential positional args + if (match.index > lastIndex) { + remainingArgs.push(userArgs.substring(lastIndex, match.index).trim()); + } + lastIndex = namedArgRegex.lastIndex; + } + + // Capture any remaining text after the last named arg + if (lastIndex < userArgs.length) { + remainingArgs.push(userArgs.substring(lastIndex).trim()); + } + + const positionalArgs = remainingArgs.join(' ').split(/ +/); + + if (!promptArgs) { + return promptInputs; + } + for (const arg of promptArgs) { + if (argValues[arg.name]) { + promptInputs[arg.name] = argValues[arg.name]; + } + } + + const unfilledArgs = promptArgs.filter( + (arg) => arg.required && !promptInputs[arg.name], + ); + + const missingArgs: string[] = []; + for (let i = 0; i < unfilledArgs.length; i++) { + if (positionalArgs.length > i && positionalArgs[i]) { + promptInputs[unfilledArgs[i].name] = positionalArgs[i]; + } else { + missingArgs.push(unfilledArgs[i].name); + } + } + + if (missingArgs.length > 0) { + const missingArgNames = missingArgs.map((name) => `--${name}`).join(', '); + return new Error(`Missing required argument(s): ${missingArgNames}`); + } + return promptInputs; + } +} diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx index 56093562..903f4b66 100644 --- a/packages/cli/src/ui/App.test.tsx +++ b/packages/cli/src/ui/App.test.tsx @@ -125,6 +125,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { getToolCallCommand: vi.fn(() => opts.toolCallCommand), getMcpServerCommand: vi.fn(() => opts.mcpServerCommand), getMcpServers: vi.fn(() => opts.mcpServers), + getPromptRegistry: vi.fn(), getExtensions: vi.fn(() => []), getBlockedMcpServers: vi.fn(() => []), getUserAgent: vi.fn(() => opts.userAgent || 'test-agent'), diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index 2b8753a0..afa71ba5 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -71,6 +71,7 @@ describe('mcpCommand', () => { getToolRegistry: ReturnType; getMcpServers: ReturnType; getBlockedMcpServers: ReturnType; + getPromptRegistry: ReturnType; }; beforeEach(() => { @@ -92,6 +93,10 @@ describe('mcpCommand', () => { }), getMcpServers: vi.fn().mockReturnValue({}), getBlockedMcpServers: vi.fn().mockReturnValue([]), + getPromptRegistry: vi.fn().mockResolvedValue({ + getAllPrompts: vi.fn().mockReturnValue([]), + getPromptsByServer: vi.fn().mockReturnValue([]), + }), }; mockContext = createMockCommandContext({ @@ -223,7 +228,7 @@ describe('mcpCommand', () => { // Server 2 - Connected expect(message).toContain( - '🟢 \u001b[1mserver2\u001b[0m - Ready (1 tools)', + '🟢 \u001b[1mserver2\u001b[0m - Ready (1 tool)', ); expect(message).toContain('server2_tool1'); @@ -365,13 +370,13 @@ describe('mcpCommand', () => { if (isMessageAction(result)) { const message = result.content; expect(message).toContain( - '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tools)', + '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tool)', ); expect(message).toContain('\u001b[36mserver1_tool1\u001b[0m'); expect(message).toContain( '🔴 \u001b[1mserver2\u001b[0m - Disconnected (0 tools cached)', ); - expect(message).toContain('No tools available'); + expect(message).toContain('No tools or prompts available'); } }); @@ -421,10 +426,10 @@ describe('mcpCommand', () => { // Check server statuses expect(message).toContain( - '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tools)', + '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tool)', ); expect(message).toContain( - '🔄 \u001b[1mserver2\u001b[0m - Starting... (first startup may take longer) (tools will appear when ready)', + '🔄 \u001b[1mserver2\u001b[0m - Starting... (first startup may take longer) (tools and prompts will appear when ready)', ); } }); @@ -994,6 +999,9 @@ describe('mcpCommand', () => { getBlockedMcpServers: vi.fn().mockReturnValue([]), getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry), getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), + getPromptRegistry: vi.fn().mockResolvedValue({ + getPromptsByServer: vi.fn().mockReturnValue([]), + }), }, }, }); diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index 5467b994..709053b6 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -12,6 +12,7 @@ import { MessageActionReturn, } from './types.js'; import { + DiscoveredMCPPrompt, DiscoveredMCPTool, getMCPDiscoveryState, getMCPServerStatus, @@ -101,6 +102,8 @@ const getMcpStatus = async ( (tool) => tool instanceof DiscoveredMCPTool && tool.serverName === serverName, ) as DiscoveredMCPTool[]; + const promptRegistry = await config.getPromptRegistry(); + const serverPrompts = promptRegistry.getPromptsByServer(serverName) || []; const status = getMCPServerStatus(serverName); @@ -160,9 +163,26 @@ const getMcpStatus = async ( // Add tool count with conditional messaging if (status === MCPServerStatus.CONNECTED) { - message += ` (${serverTools.length} tools)`; + const parts = []; + if (serverTools.length > 0) { + parts.push( + `${serverTools.length} ${serverTools.length === 1 ? 'tool' : 'tools'}`, + ); + } + if (serverPrompts.length > 0) { + parts.push( + `${serverPrompts.length} ${ + serverPrompts.length === 1 ? 'prompt' : 'prompts' + }`, + ); + } + if (parts.length > 0) { + message += ` (${parts.join(', ')})`; + } else { + message += ` (0 tools)`; + } } else if (status === MCPServerStatus.CONNECTING) { - message += ` (tools will appear when ready)`; + message += ` (tools and prompts will appear when ready)`; } else { message += ` (${serverTools.length} tools cached)`; } @@ -186,6 +206,7 @@ const getMcpStatus = async ( message += RESET_COLOR; if (serverTools.length > 0) { + message += ` ${COLOR_CYAN}Tools:${RESET_COLOR}\n`; serverTools.forEach((tool) => { if (showDescriptions && tool.description) { // Format tool name in cyan using simple ANSI cyan color @@ -222,12 +243,41 @@ const getMcpStatus = async ( } } }); - } else { + } + if (serverPrompts.length > 0) { + if (serverTools.length > 0) { + message += '\n'; + } + message += ` ${COLOR_CYAN}Prompts:${RESET_COLOR}\n`; + serverPrompts.forEach((prompt: DiscoveredMCPPrompt) => { + if (showDescriptions && prompt.description) { + message += ` - ${COLOR_CYAN}${prompt.name}${RESET_COLOR}`; + const descLines = prompt.description.trim().split('\n'); + if (descLines) { + message += ':\n'; + for (const descLine of descLines) { + message += ` ${COLOR_GREEN}${descLine}${RESET_COLOR}\n`; + } + } else { + message += '\n'; + } + } else { + message += ` - ${COLOR_CYAN}${prompt.name}${RESET_COLOR}\n`; + } + }); + } + + if (serverTools.length === 0 && serverPrompts.length === 0) { + message += ' No tools or prompts available\n'; + } else if (serverTools.length === 0) { message += ' No tools available'; if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) { message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}`; } message += '\n'; + } else if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) { + // This case is for when serverTools.length > 0 + message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}\n`; } message += '\n'; } @@ -328,11 +378,10 @@ const authCommand: SlashCommand = { // Import dynamically to avoid circular dependencies const { MCPOAuthProvider } = await import('@google/gemini-cli-core'); - // Create OAuth config for authentication (will be discovered automatically) - const oauthConfig = server.oauth || { - authorizationUrl: '', // Will be discovered automatically - tokenUrl: '', // Will be discovered automatically - }; + let oauthConfig = server.oauth; + if (!oauthConfig) { + oauthConfig = { enabled: false }; + } // Pass the MCP server URL for OAuth discovery const mcpServerUrl = server.httpUrl || server.url; diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts index 9a1088fd..1684677c 100644 --- a/packages/cli/src/ui/commands/types.ts +++ b/packages/cli/src/ui/commands/types.ts @@ -128,6 +128,7 @@ export type SlashCommandActionReturn = export enum CommandKind { BUILT_IN = 'built-in', FILE = 'file', + MCP_PROMPT = 'mcp-prompt', } // The standardized contract for any command in the system. diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 84eeb033..d308af46 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -28,6 +28,13 @@ vi.mock('../../services/FileCommandLoader.js', () => ({ })), })); +const mockMcpLoadCommands = vi.fn(); +vi.mock('../../services/McpPromptLoader.js', () => ({ + McpPromptLoader: vi.fn().mockImplementation(() => ({ + loadCommands: mockMcpLoadCommands, + })), +})); + vi.mock('../contexts/SessionContext.js', () => ({ useSessionStats: vi.fn(() => ({ stats: {} })), })); @@ -41,6 +48,7 @@ import { LoadedSettings } from '../../config/settings.js'; import { MessageType } from '../types.js'; import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; import { FileCommandLoader } from '../../services/FileCommandLoader.js'; +import { McpPromptLoader } from '../../services/McpPromptLoader.js'; const createTestCommand = ( overrides: Partial, @@ -75,14 +83,17 @@ describe('useSlashCommandProcessor', () => { (vi.mocked(BuiltinCommandLoader) as Mock).mockClear(); mockBuiltinLoadCommands.mockResolvedValue([]); mockFileLoadCommands.mockResolvedValue([]); + mockMcpLoadCommands.mockResolvedValue([]); }); const setupProcessorHook = ( builtinCommands: SlashCommand[] = [], fileCommands: SlashCommand[] = [], + mcpCommands: SlashCommand[] = [], ) => { mockBuiltinLoadCommands.mockResolvedValue(Object.freeze(builtinCommands)); mockFileLoadCommands.mockResolvedValue(Object.freeze(fileCommands)); + mockMcpLoadCommands.mockResolvedValue(Object.freeze(mcpCommands)); const { result } = renderHook(() => useSlashCommandProcessor( @@ -111,6 +122,7 @@ describe('useSlashCommandProcessor', () => { setupProcessorHook(); expect(BuiltinCommandLoader).toHaveBeenCalledWith(mockConfig); expect(FileCommandLoader).toHaveBeenCalledWith(mockConfig); + expect(McpPromptLoader).toHaveBeenCalledWith(mockConfig); }); it('should call loadCommands and populate state after mounting', async () => { @@ -124,6 +136,7 @@ describe('useSlashCommandProcessor', () => { expect(result.current.slashCommands[0]?.name).toBe('test'); expect(mockBuiltinLoadCommands).toHaveBeenCalledTimes(1); expect(mockFileLoadCommands).toHaveBeenCalledTimes(1); + expect(mockMcpLoadCommands).toHaveBeenCalledTimes(1); }); it('should provide an immutable array of commands to consumers', async () => { @@ -369,6 +382,38 @@ describe('useSlashCommandProcessor', () => { expect.any(Number), ); }); + + it('should handle "submit_prompt" action returned from a mcp-based command', async () => { + const mcpCommand = createTestCommand( + { + name: 'mcpcmd', + description: 'A command from mcp', + action: async () => ({ + type: 'submit_prompt', + content: 'The actual prompt from the mcp command.', + }), + }, + CommandKind.MCP_PROMPT, + ); + + const result = setupProcessorHook([], [], [mcpCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + let actionResult; + await act(async () => { + actionResult = await result.current.handleSlashCommand('/mcpcmd'); + }); + + expect(actionResult).toEqual({ + type: 'submit_prompt', + content: 'The actual prompt from the mcp command.', + }); + + expect(mockAddItem).toHaveBeenCalledWith( + { type: MessageType.USER, text: '/mcpcmd' }, + expect.any(Number), + ); + }); }); describe('Command Parsing and Matching', () => { @@ -441,6 +486,39 @@ describe('useSlashCommandProcessor', () => { }); describe('Command Precedence', () => { + it('should override mcp-based commands with file-based commands of the same name', async () => { + const mcpAction = vi.fn(); + const fileAction = vi.fn(); + + const mcpCommand = createTestCommand( + { + name: 'override', + description: 'mcp', + action: mcpAction, + }, + CommandKind.MCP_PROMPT, + ); + const fileCommand = createTestCommand( + { name: 'override', description: 'file', action: fileAction }, + CommandKind.FILE, + ); + + const result = setupProcessorHook([], [fileCommand], [mcpCommand]); + + await waitFor(() => { + // The service should only return one command with the name 'override' + expect(result.current.slashCommands).toHaveLength(1); + }); + + await act(async () => { + await result.current.handleSlashCommand('/override'); + }); + + // Only the file-based command's action should be called. + expect(fileAction).toHaveBeenCalledTimes(1); + expect(mcpAction).not.toHaveBeenCalled(); + }); + it('should prioritize a command with a primary name over a command with a matching alias', async () => { const quitAction = vi.fn(); const exitAction = vi.fn(); diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index fa2b0b12..9e9dc21c 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -23,6 +23,7 @@ import { type CommandContext, type SlashCommand } from '../commands/types.js'; import { CommandService } from '../../services/CommandService.js'; import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; import { FileCommandLoader } from '../../services/FileCommandLoader.js'; +import { McpPromptLoader } from '../../services/McpPromptLoader.js'; /** * Hook to define and process slash commands (e.g., /help, /clear). @@ -164,6 +165,7 @@ export const useSlashCommandProcessor = ( const controller = new AbortController(); const load = async () => { const loaders = [ + new McpPromptLoader(config), new BuiltinCommandLoader(config), new FileCommandLoader(config), ]; @@ -246,82 +248,95 @@ export const useSlashCommandProcessor = ( args, }, }; - const result = await commandToExecute.action( - fullCommandContext, - args, - ); + try { + const result = await commandToExecute.action( + fullCommandContext, + args, + ); - if (result) { - switch (result.type) { - case 'tool': - return { - type: 'schedule_tool', - toolName: result.toolName, - toolArgs: result.toolArgs, - }; - case 'message': - addItem( - { - type: - result.messageType === 'error' - ? MessageType.ERROR - : MessageType.INFO, - text: result.content, - }, - Date.now(), - ); - return { type: 'handled' }; - case 'dialog': - switch (result.dialog) { - case 'help': - setShowHelp(true); - return { type: 'handled' }; - case 'auth': - openAuthDialog(); - return { type: 'handled' }; - case 'theme': - openThemeDialog(); - return { type: 'handled' }; - case 'editor': - openEditorDialog(); - return { type: 'handled' }; - case 'privacy': - openPrivacyNotice(); - return { type: 'handled' }; - default: { - const unhandled: never = result.dialog; - throw new Error( - `Unhandled slash command result: ${unhandled}`, - ); + if (result) { + switch (result.type) { + case 'tool': + return { + type: 'schedule_tool', + toolName: result.toolName, + toolArgs: result.toolArgs, + }; + case 'message': + addItem( + { + type: + result.messageType === 'error' + ? MessageType.ERROR + : MessageType.INFO, + text: result.content, + }, + Date.now(), + ); + return { type: 'handled' }; + case 'dialog': + switch (result.dialog) { + case 'help': + setShowHelp(true); + return { type: 'handled' }; + case 'auth': + openAuthDialog(); + return { type: 'handled' }; + case 'theme': + openThemeDialog(); + return { type: 'handled' }; + case 'editor': + openEditorDialog(); + return { type: 'handled' }; + case 'privacy': + openPrivacyNotice(); + return { type: 'handled' }; + default: { + const unhandled: never = result.dialog; + throw new Error( + `Unhandled slash command result: ${unhandled}`, + ); + } } + case 'load_history': { + await config + ?.getGeminiClient() + ?.setHistory(result.clientHistory); + fullCommandContext.ui.clear(); + result.history.forEach((item, index) => { + fullCommandContext.ui.addItem(item, index); + }); + return { type: 'handled' }; } - case 'load_history': { - await config - ?.getGeminiClient() - ?.setHistory(result.clientHistory); - fullCommandContext.ui.clear(); - result.history.forEach((item, index) => { - fullCommandContext.ui.addItem(item, index); - }); - return { type: 'handled' }; - } - case 'quit': - setQuittingMessages(result.messages); - setTimeout(() => { - process.exit(0); - }, 100); - return { type: 'handled' }; + case 'quit': + setQuittingMessages(result.messages); + setTimeout(() => { + process.exit(0); + }, 100); + return { type: 'handled' }; - case 'submit_prompt': - return { - type: 'submit_prompt', - content: result.content, - }; - default: { - const unhandled: never = result; - throw new Error(`Unhandled slash command result: ${unhandled}`); + case 'submit_prompt': + return { + type: 'submit_prompt', + content: result.content, + }; + default: { + const unhandled: never = result; + throw new Error( + `Unhandled slash command result: ${unhandled}`, + ); + } } } + } catch (e) { + addItem( + { + type: MessageType.ERROR, + text: e instanceof Error ? e.message : String(e), + }, + Date.now(), + ); + return { type: 'handled' }; } return { type: 'handled' }; diff --git a/packages/cli/src/ui/hooks/useCompletion.test.ts b/packages/cli/src/ui/hooks/useCompletion.test.ts index cd525435..da6a7ab3 100644 --- a/packages/cli/src/ui/hooks/useCompletion.test.ts +++ b/packages/cli/src/ui/hooks/useCompletion.test.ts @@ -1100,7 +1100,7 @@ describe('useCompletion', () => { result.current.handleAutocomplete(0); }); - expect(mockBuffer.setText).toHaveBeenCalledWith('/memory'); + expect(mockBuffer.setText).toHaveBeenCalledWith('/memory '); }); it('should append a sub-command when the parent is complete', () => { @@ -1145,7 +1145,7 @@ describe('useCompletion', () => { result.current.handleAutocomplete(1); // index 1 is 'add' }); - expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add'); + expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add '); }); it('should complete a command with an alternative name', () => { @@ -1190,7 +1190,7 @@ describe('useCompletion', () => { result.current.handleAutocomplete(0); }); - expect(mockBuffer.setText).toHaveBeenCalledWith('/help'); + expect(mockBuffer.setText).toHaveBeenCalledWith('/help '); }); it('should complete a file path', async () => { diff --git a/packages/cli/src/ui/hooks/useCompletion.ts b/packages/cli/src/ui/hooks/useCompletion.ts index dc45222d..10724c21 100644 --- a/packages/cli/src/ui/hooks/useCompletion.ts +++ b/packages/cli/src/ui/hooks/useCompletion.ts @@ -638,10 +638,17 @@ export function useCompletion( // Determine the base path of the command. // - If there's a trailing space, the whole command is the base. // - If it's a known parent path, the whole command is the base. + // - If the last part is a complete argument, the whole command is the base. // - Otherwise, the base is everything EXCEPT the last partial part. + const lastPart = parts.length > 0 ? parts[parts.length - 1] : ''; + const isLastPartACompleteArg = + lastPart.startsWith('--') && lastPart.includes('='); + const basePath = - hasTrailingSpace || isParentPath ? parts : parts.slice(0, -1); - const newValue = `/${[...basePath, suggestion].join(' ')}`; + hasTrailingSpace || isParentPath || isLastPartACompleteArg + ? parts + : parts.slice(0, -1); + const newValue = `/${[...basePath, suggestion].join(' ')} `; buffer.setText(newValue); } else { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 96b6f2cb..7ccfdbc8 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -11,6 +11,7 @@ import { ContentGeneratorConfig, createContentGeneratorConfig, } from '../core/contentGenerator.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; import { ToolRegistry } from '../tools/tool-registry.js'; import { LSTool } from '../tools/ls.js'; import { ReadFileTool } from '../tools/read-file.js'; @@ -186,6 +187,7 @@ export interface ConfigParameters { export class Config { private toolRegistry!: ToolRegistry; + private promptRegistry!: PromptRegistry; private readonly sessionId: string; private contentGeneratorConfig!: ContentGeneratorConfig; private readonly embeddingModel: string; @@ -314,6 +316,7 @@ export class Config { if (this.getCheckpointingEnabled()) { await this.getGitService(); } + this.promptRegistry = new PromptRegistry(); this.toolRegistry = await this.createToolRegistry(); } @@ -396,6 +399,10 @@ export class Config { return Promise.resolve(this.toolRegistry); } + getPromptRegistry(): PromptRegistry { + return this.promptRegistry; + } + getDebugMode(): boolean { return this.debugMode; } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 9d87ce32..829de544 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -49,6 +49,9 @@ export * from './ide/ideContext.js'; export * from './tools/tools.js'; export * from './tools/tool-registry.js'; +// Export prompt logic +export * from './prompts/mcp-prompts.js'; + // Export specific tool logic export * from './tools/read-file.js'; export * from './tools/ls.js'; diff --git a/packages/core/src/prompts/mcp-prompts.ts b/packages/core/src/prompts/mcp-prompts.ts new file mode 100644 index 00000000..7265a023 --- /dev/null +++ b/packages/core/src/prompts/mcp-prompts.ts @@ -0,0 +1,19 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Config } from '../config/config.js'; +import { DiscoveredMCPPrompt } from '../tools/mcp-client.js'; + +export function getMCPServerPrompts( + config: Config, + serverName: string, +): DiscoveredMCPPrompt[] { + const promptRegistry = config.getPromptRegistry(); + if (!promptRegistry) { + return []; + } + return promptRegistry.getPromptsByServer(serverName); +} diff --git a/packages/core/src/prompts/prompt-registry.ts b/packages/core/src/prompts/prompt-registry.ts new file mode 100644 index 00000000..56699130 --- /dev/null +++ b/packages/core/src/prompts/prompt-registry.ts @@ -0,0 +1,56 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { DiscoveredMCPPrompt } from '../tools/mcp-client.js'; + +export class PromptRegistry { + private prompts: Map = new Map(); + + /** + * Registers a prompt definition. + * @param prompt - The prompt object containing schema and execution logic. + */ + registerPrompt(prompt: DiscoveredMCPPrompt): void { + if (this.prompts.has(prompt.name)) { + const newName = `${prompt.serverName}_${prompt.name}`; + console.warn( + `Prompt with name "${prompt.name}" is already registered. Renaming to "${newName}".`, + ); + this.prompts.set(newName, { ...prompt, name: newName }); + } else { + this.prompts.set(prompt.name, prompt); + } + } + + /** + * Returns an array of all registered and discovered prompt instances. + */ + getAllPrompts(): DiscoveredMCPPrompt[] { + return Array.from(this.prompts.values()).sort((a, b) => + a.name.localeCompare(b.name), + ); + } + + /** + * Get the definition of a specific prompt. + */ + getPrompt(name: string): DiscoveredMCPPrompt | undefined { + return this.prompts.get(name); + } + + /** + * Returns an array of prompts registered from a specific MCP server. + */ + getPromptsByServer(serverName: string): DiscoveredMCPPrompt[] { + const serverPrompts: DiscoveredMCPPrompt[] = []; + for (const prompt of this.prompts.values()) { + if (prompt.serverName === serverName) { + serverPrompts.push(prompt); + } + } + return serverPrompts.sort((a, b) => a.name.localeCompare(b.name)); + } +} diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 428c9d2d..4560982c 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -11,6 +11,7 @@ import { createTransport, isEnabled, discoverTools, + discoverPrompts, } from './mcp-client.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; @@ -18,6 +19,7 @@ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; import * as GenAiLib from '@google/genai'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { AuthProviderType } from '../config/config.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js'); @@ -50,6 +52,77 @@ describe('mcp-client', () => { }); }); + describe('discoverPrompts', () => { + const mockedPromptRegistry = { + registerPrompt: vi.fn(), + } as unknown as PromptRegistry; + + it('should discover and log prompts', async () => { + const mockRequest = vi.fn().mockResolvedValue({ + prompts: [ + { name: 'prompt1', description: 'desc1' }, + { name: 'prompt2' }, + ], + }); + const mockedClient = { + request: mockRequest, + } as unknown as ClientLib.Client; + + await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); + + expect(mockRequest).toHaveBeenCalledWith( + { method: 'prompts/list', params: {} }, + expect.anything(), + ); + }); + + it('should do nothing if no prompts are discovered', async () => { + const mockRequest = vi.fn().mockResolvedValue({ + prompts: [], + }); + const mockedClient = { + request: mockRequest, + } as unknown as ClientLib.Client; + + const consoleLogSpy = vi + .spyOn(console, 'debug') + .mockImplementation(() => { + // no-op + }); + + await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); + + expect(mockRequest).toHaveBeenCalledOnce(); + expect(consoleLogSpy).not.toHaveBeenCalled(); + + consoleLogSpy.mockRestore(); + }); + + it('should log an error if discovery fails', async () => { + const testError = new Error('test error'); + testError.message = 'test error'; + const mockRequest = vi.fn().mockRejectedValue(testError); + const mockedClient = { + request: mockRequest, + } as unknown as ClientLib.Client; + + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => { + // no-op + }); + + await discoverPrompts('test-server', mockedClient, mockedPromptRegistry); + + expect(mockRequest).toHaveBeenCalledOnce(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + `Error discovering prompts from test-server: ${testError.message}`, + ); + + consoleErrorSpy.mockRestore(); + }); + }); + describe('appendMcpServerCommand', () => { it('should do nothing if no MCP servers or command are configured', () => { const out = populateMcpServerCommand({}, undefined); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index c59b1592..d175af1f 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -15,12 +15,20 @@ import { StreamableHTTPClientTransport, StreamableHTTPClientTransportOptions, } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { + Prompt, + ListPromptsResultSchema, + GetPromptResult, + GetPromptResultSchema, +} from '@modelcontextprotocol/sdk/types.js'; import { parse } from 'shell-quote'; import { AuthProviderType, MCPServerConfig } from '../config/config.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; + import { FunctionDeclaration, mcpToTool } from '@google/genai'; import { ToolRegistry } from './tool-registry.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; @@ -28,6 +36,11 @@ import { getErrorMessage } from '../utils/errors.js'; export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes +export type DiscoveredMCPPrompt = Prompt & { + serverName: string; + invoke: (params: Record) => Promise; +}; + /** * Enum representing the connection status of an MCP server */ @@ -55,7 +68,7 @@ export enum MCPDiscoveryState { /** * Map to track the status of each MCP server within the core package */ -const mcpServerStatusesInternal: Map = new Map(); +const serverStatuses: Map = new Map(); /** * Track the overall MCP discovery state @@ -104,7 +117,7 @@ function updateMCPServerStatus( serverName: string, status: MCPServerStatus, ): void { - mcpServerStatusesInternal.set(serverName, status); + serverStatuses.set(serverName, status); // Notify all listeners for (const listener of statusChangeListeners) { listener(serverName, status); @@ -115,16 +128,14 @@ function updateMCPServerStatus( * Get the current status of an MCP server */ export function getMCPServerStatus(serverName: string): MCPServerStatus { - return ( - mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED - ); + return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED; } /** * Get all MCP server statuses */ export function getAllMCPServerStatuses(): Map { - return new Map(mcpServerStatusesInternal); + return new Map(serverStatuses); } /** @@ -307,6 +318,7 @@ export async function discoverMcpTools( mcpServers: Record, mcpServerCommand: string | undefined, toolRegistry: ToolRegistry, + promptRegistry: PromptRegistry, debugMode: boolean, ): Promise { mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS; @@ -319,6 +331,7 @@ export async function discoverMcpTools( mcpServerName, mcpServerConfig, toolRegistry, + promptRegistry, debugMode, ), ); @@ -362,6 +375,7 @@ export async function connectAndDiscover( mcpServerName: string, mcpServerConfig: MCPServerConfig, toolRegistry: ToolRegistry, + promptRegistry: PromptRegistry, debugMode: boolean, ): Promise { updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); @@ -378,6 +392,7 @@ export async function connectAndDiscover( console.error(`MCP ERROR (${mcpServerName}):`, error.toString()); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); }; + await discoverPrompts(mcpServerName, mcpClient, promptRegistry); const tools = await discoverTools( mcpServerName, @@ -393,7 +408,9 @@ export async function connectAndDiscover( } } catch (error) { console.error( - `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(error)}`, + `Error connecting to MCP server '${mcpServerName}': ${getErrorMessage( + error, + )}`, ); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); } @@ -441,15 +458,97 @@ export async function discoverTools( ), ); } - if (discoveredTools.length === 0) { - throw Error('No enabled tools found'); - } return discoveredTools; } catch (error) { throw new Error(`Error discovering tools: ${error}`); } } +/** + * Discovers and logs prompts from a connected MCP client. + * It retrieves prompt declarations from the client and logs their names. + * + * @param mcpServerName The name of the MCP server. + * @param mcpClient The active MCP client instance. + */ +export async function discoverPrompts( + mcpServerName: string, + mcpClient: Client, + promptRegistry: PromptRegistry, +): Promise { + try { + const response = await mcpClient.request( + { method: 'prompts/list', params: {} }, + ListPromptsResultSchema, + ); + + for (const prompt of response.prompts) { + promptRegistry.registerPrompt({ + ...prompt, + serverName: mcpServerName, + invoke: (params: Record) => + invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params), + }); + } + } catch (error) { + // It's okay if this fails, not all servers will have prompts. + // Don't log an error if the method is not found, which is a common case. + if ( + error instanceof Error && + !error.message?.includes('Method not found') + ) { + console.error( + `Error discovering prompts from ${mcpServerName}: ${getErrorMessage( + error, + )}`, + ); + } + } +} + +/** + * Invokes a prompt on a connected MCP client. + * + * @param mcpServerName The name of the MCP server. + * @param mcpClient The active MCP client instance. + * @param promptName The name of the prompt to invoke. + * @param promptParams The parameters to pass to the prompt. + * @returns A promise that resolves to the result of the prompt invocation. + */ +export async function invokeMcpPrompt( + mcpServerName: string, + mcpClient: Client, + promptName: string, + promptParams: Record, +): Promise { + try { + const response = await mcpClient.request( + { + method: 'prompts/get', + params: { + name: promptName, + arguments: promptParams, + }, + }, + GetPromptResultSchema, + ); + + return response; + } catch (error) { + if ( + error instanceof Error && + !error.message?.includes('Method not found') + ) { + console.error( + `Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage( + error, + )}`, + ); + } + throw error; + } +} + /** * Creates and connects an MCP client to a server based on the provided configuration. * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index de355a98..b3fdd7a3 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -344,6 +344,7 @@ describe('ToolRegistry', () => { mcpServerConfigVal, undefined, toolRegistry, + undefined, false, ); }); @@ -366,6 +367,7 @@ describe('ToolRegistry', () => { mcpServerConfigVal, undefined, toolRegistry, + undefined, false, ); }); diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index b72ed9a5..57627ee0 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -170,6 +170,7 @@ export class ToolRegistry { this.config.getMcpServers() ?? {}, this.config.getMcpServerCommand(), this, + this.config.getPromptRegistry(), this.config.getDebugMode(), ); } @@ -192,6 +193,7 @@ export class ToolRegistry { this.config.getMcpServers() ?? {}, this.config.getMcpServerCommand(), this, + this.config.getPromptRegistry(), this.config.getDebugMode(), ); } @@ -215,6 +217,7 @@ export class ToolRegistry { { [serverName]: serverConfig }, undefined, this, + this.config.getPromptRegistry(), this.config.getDebugMode(), ); }