Load and use MCP server prompts as slash commands in the CLI (#4828)
Co-authored-by: harold <haroldmciver@google.com> Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
parent
de96887789
commit
eb65034117
|
@ -40,13 +40,19 @@ vi.mock('../ui/commands/extensionsCommand.js', () => ({
|
||||||
extensionsCommand: {},
|
extensionsCommand: {},
|
||||||
}));
|
}));
|
||||||
vi.mock('../ui/commands/helpCommand.js', () => ({ helpCommand: {} }));
|
vi.mock('../ui/commands/helpCommand.js', () => ({ helpCommand: {} }));
|
||||||
vi.mock('../ui/commands/mcpCommand.js', () => ({ mcpCommand: {} }));
|
|
||||||
vi.mock('../ui/commands/memoryCommand.js', () => ({ memoryCommand: {} }));
|
vi.mock('../ui/commands/memoryCommand.js', () => ({ memoryCommand: {} }));
|
||||||
vi.mock('../ui/commands/privacyCommand.js', () => ({ privacyCommand: {} }));
|
vi.mock('../ui/commands/privacyCommand.js', () => ({ privacyCommand: {} }));
|
||||||
vi.mock('../ui/commands/quitCommand.js', () => ({ quitCommand: {} }));
|
vi.mock('../ui/commands/quitCommand.js', () => ({ quitCommand: {} }));
|
||||||
vi.mock('../ui/commands/statsCommand.js', () => ({ statsCommand: {} }));
|
vi.mock('../ui/commands/statsCommand.js', () => ({ statsCommand: {} }));
|
||||||
vi.mock('../ui/commands/themeCommand.js', () => ({ themeCommand: {} }));
|
vi.mock('../ui/commands/themeCommand.js', () => ({ themeCommand: {} }));
|
||||||
vi.mock('../ui/commands/toolsCommand.js', () => ({ toolsCommand: {} }));
|
vi.mock('../ui/commands/toolsCommand.js', () => ({ toolsCommand: {} }));
|
||||||
|
vi.mock('../ui/commands/mcpCommand.js', () => ({
|
||||||
|
mcpCommand: {
|
||||||
|
name: 'mcp',
|
||||||
|
description: 'MCP command',
|
||||||
|
kind: 'BUILT_IN',
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
describe('BuiltinCommandLoader', () => {
|
describe('BuiltinCommandLoader', () => {
|
||||||
let mockConfig: Config;
|
let mockConfig: Config;
|
||||||
|
@ -114,5 +120,8 @@ describe('BuiltinCommandLoader', () => {
|
||||||
|
|
||||||
const ideCmd = commands.find((c) => c.name === 'ide');
|
const ideCmd = commands.find((c) => c.name === 'ide');
|
||||||
expect(ideCmd).toBeDefined();
|
expect(ideCmd).toBeDefined();
|
||||||
|
|
||||||
|
const mcpCmd = commands.find((c) => c.name === 'mcp');
|
||||||
|
expect(mcpCmd).toBeDefined();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -58,9 +58,9 @@ export class BuiltinCommandLoader implements ICommandLoader {
|
||||||
extensionsCommand,
|
extensionsCommand,
|
||||||
helpCommand,
|
helpCommand,
|
||||||
ideCommand(this.config),
|
ideCommand(this.config),
|
||||||
mcpCommand,
|
|
||||||
memoryCommand,
|
memoryCommand,
|
||||||
privacyCommand,
|
privacyCommand,
|
||||||
|
mcpCommand,
|
||||||
quitCommand,
|
quitCommand,
|
||||||
restoreCommand(this.config),
|
restoreCommand(this.config),
|
||||||
statsCommand,
|
statsCommand,
|
||||||
|
|
|
@ -0,0 +1,231 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {
|
||||||
|
Config,
|
||||||
|
getErrorMessage,
|
||||||
|
getMCPServerPrompts,
|
||||||
|
} from '@google/gemini-cli-core';
|
||||||
|
import {
|
||||||
|
CommandContext,
|
||||||
|
CommandKind,
|
||||||
|
SlashCommand,
|
||||||
|
SlashCommandActionReturn,
|
||||||
|
} from '../ui/commands/types.js';
|
||||||
|
import { ICommandLoader } from './types.js';
|
||||||
|
import { PromptArgument } from '@modelcontextprotocol/sdk/types.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Discovers and loads executable slash commands from prompts exposed by
|
||||||
|
* Model-Context-Protocol (MCP) servers.
|
||||||
|
*/
|
||||||
|
export class McpPromptLoader implements ICommandLoader {
|
||||||
|
constructor(private readonly config: Config | null) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads all available prompts from all configured MCP servers and adapts
|
||||||
|
* them into executable SlashCommand objects.
|
||||||
|
*
|
||||||
|
* @param _signal An AbortSignal (unused for this synchronous loader).
|
||||||
|
* @returns A promise that resolves to an array of loaded SlashCommands.
|
||||||
|
*/
|
||||||
|
loadCommands(_signal: AbortSignal): Promise<SlashCommand[]> {
|
||||||
|
const promptCommands: SlashCommand[] = [];
|
||||||
|
if (!this.config) {
|
||||||
|
return Promise.resolve([]);
|
||||||
|
}
|
||||||
|
const mcpServers = this.config.getMcpServers() || {};
|
||||||
|
for (const serverName in mcpServers) {
|
||||||
|
const prompts = getMCPServerPrompts(this.config, serverName) || [];
|
||||||
|
for (const prompt of prompts) {
|
||||||
|
const commandName = `${prompt.name}`;
|
||||||
|
const newPromptCommand: SlashCommand = {
|
||||||
|
name: commandName,
|
||||||
|
description: prompt.description || `Invoke prompt ${prompt.name}`,
|
||||||
|
kind: CommandKind.MCP_PROMPT,
|
||||||
|
subCommands: [
|
||||||
|
{
|
||||||
|
name: 'help',
|
||||||
|
description: 'Show help for this prompt',
|
||||||
|
kind: CommandKind.MCP_PROMPT,
|
||||||
|
action: async (): Promise<SlashCommandActionReturn> => {
|
||||||
|
if (!prompt.arguments || prompt.arguments.length === 0) {
|
||||||
|
return {
|
||||||
|
type: 'message',
|
||||||
|
messageType: 'info',
|
||||||
|
content: `Prompt "${prompt.name}" has no arguments.`,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let helpMessage = `Arguments for "${prompt.name}":\n\n`;
|
||||||
|
if (prompt.arguments && prompt.arguments.length > 0) {
|
||||||
|
helpMessage += `You can provide arguments by name (e.g., --argName="value") or by position.\n\n`;
|
||||||
|
helpMessage += `e.g., ${prompt.name} ${prompt.arguments?.map((_) => `"foo"`)} is equivalent to ${prompt.name} ${prompt.arguments?.map((arg) => `--${arg.name}="foo"`)}\n\n`;
|
||||||
|
}
|
||||||
|
for (const arg of prompt.arguments) {
|
||||||
|
helpMessage += ` --${arg.name}\n`;
|
||||||
|
if (arg.description) {
|
||||||
|
helpMessage += ` ${arg.description}\n`;
|
||||||
|
}
|
||||||
|
helpMessage += ` (required: ${
|
||||||
|
arg.required ? 'yes' : 'no'
|
||||||
|
})\n\n`;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
type: 'message',
|
||||||
|
messageType: 'info',
|
||||||
|
content: helpMessage,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
action: async (
|
||||||
|
context: CommandContext,
|
||||||
|
args: string,
|
||||||
|
): Promise<SlashCommandActionReturn> => {
|
||||||
|
if (!this.config) {
|
||||||
|
return {
|
||||||
|
type: 'message',
|
||||||
|
messageType: 'error',
|
||||||
|
content: 'Config not loaded.',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const promptInputs = this.parseArgs(args, prompt.arguments);
|
||||||
|
if (promptInputs instanceof Error) {
|
||||||
|
return {
|
||||||
|
type: 'message',
|
||||||
|
messageType: 'error',
|
||||||
|
content: promptInputs.message,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const mcpServers = this.config.getMcpServers() || {};
|
||||||
|
const mcpServerConfig = mcpServers[serverName];
|
||||||
|
if (!mcpServerConfig) {
|
||||||
|
return {
|
||||||
|
type: 'message',
|
||||||
|
messageType: 'error',
|
||||||
|
content: `MCP server config not found for '${serverName}'.`,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
const result = await prompt.invoke(promptInputs);
|
||||||
|
|
||||||
|
if (result.error) {
|
||||||
|
return {
|
||||||
|
type: 'message',
|
||||||
|
messageType: 'error',
|
||||||
|
content: `Error invoking prompt: ${result.error}`,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!result.messages?.[0]?.content?.text) {
|
||||||
|
return {
|
||||||
|
type: 'message',
|
||||||
|
messageType: 'error',
|
||||||
|
content:
|
||||||
|
'Received an empty or invalid prompt response from the server.',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
type: 'submit_prompt',
|
||||||
|
content: JSON.stringify(result.messages[0].content.text),
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
return {
|
||||||
|
type: 'message',
|
||||||
|
messageType: 'error',
|
||||||
|
content: `Error: ${getErrorMessage(error)}`,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
},
|
||||||
|
completion: async (_: CommandContext, partialArg: string) => {
|
||||||
|
if (!prompt || !prompt.arguments) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const suggestions: string[] = [];
|
||||||
|
const usedArgNames = new Set(
|
||||||
|
(partialArg.match(/--([^=]+)/g) || []).map((s) => s.substring(2)),
|
||||||
|
);
|
||||||
|
|
||||||
|
for (const arg of prompt.arguments) {
|
||||||
|
if (!usedArgNames.has(arg.name)) {
|
||||||
|
suggestions.push(`--${arg.name}=""`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return suggestions;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
promptCommands.push(newPromptCommand);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Promise.resolve(promptCommands);
|
||||||
|
}
|
||||||
|
|
||||||
|
private parseArgs(
|
||||||
|
userArgs: string,
|
||||||
|
promptArgs: PromptArgument[] | undefined,
|
||||||
|
): Record<string, unknown> | Error {
|
||||||
|
const argValues: { [key: string]: string } = {};
|
||||||
|
const promptInputs: Record<string, unknown> = {};
|
||||||
|
|
||||||
|
// arg parsing: --key="value" or --key=value
|
||||||
|
const namedArgRegex = /--([^=]+)=(?:"((?:\\.|[^"\\])*)"|([^ ]*))/g;
|
||||||
|
let match;
|
||||||
|
const remainingArgs: string[] = [];
|
||||||
|
let lastIndex = 0;
|
||||||
|
|
||||||
|
while ((match = namedArgRegex.exec(userArgs)) !== null) {
|
||||||
|
const key = match[1];
|
||||||
|
const value = match[2] ?? match[3]; // Quoted or unquoted value
|
||||||
|
argValues[key] = value;
|
||||||
|
// Capture text between matches as potential positional args
|
||||||
|
if (match.index > lastIndex) {
|
||||||
|
remainingArgs.push(userArgs.substring(lastIndex, match.index).trim());
|
||||||
|
}
|
||||||
|
lastIndex = namedArgRegex.lastIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture any remaining text after the last named arg
|
||||||
|
if (lastIndex < userArgs.length) {
|
||||||
|
remainingArgs.push(userArgs.substring(lastIndex).trim());
|
||||||
|
}
|
||||||
|
|
||||||
|
const positionalArgs = remainingArgs.join(' ').split(/ +/);
|
||||||
|
|
||||||
|
if (!promptArgs) {
|
||||||
|
return promptInputs;
|
||||||
|
}
|
||||||
|
for (const arg of promptArgs) {
|
||||||
|
if (argValues[arg.name]) {
|
||||||
|
promptInputs[arg.name] = argValues[arg.name];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const unfilledArgs = promptArgs.filter(
|
||||||
|
(arg) => arg.required && !promptInputs[arg.name],
|
||||||
|
);
|
||||||
|
|
||||||
|
const missingArgs: string[] = [];
|
||||||
|
for (let i = 0; i < unfilledArgs.length; i++) {
|
||||||
|
if (positionalArgs.length > i && positionalArgs[i]) {
|
||||||
|
promptInputs[unfilledArgs[i].name] = positionalArgs[i];
|
||||||
|
} else {
|
||||||
|
missingArgs.push(unfilledArgs[i].name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (missingArgs.length > 0) {
|
||||||
|
const missingArgNames = missingArgs.map((name) => `--${name}`).join(', ');
|
||||||
|
return new Error(`Missing required argument(s): ${missingArgNames}`);
|
||||||
|
}
|
||||||
|
return promptInputs;
|
||||||
|
}
|
||||||
|
}
|
|
@ -125,6 +125,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||||
getToolCallCommand: vi.fn(() => opts.toolCallCommand),
|
getToolCallCommand: vi.fn(() => opts.toolCallCommand),
|
||||||
getMcpServerCommand: vi.fn(() => opts.mcpServerCommand),
|
getMcpServerCommand: vi.fn(() => opts.mcpServerCommand),
|
||||||
getMcpServers: vi.fn(() => opts.mcpServers),
|
getMcpServers: vi.fn(() => opts.mcpServers),
|
||||||
|
getPromptRegistry: vi.fn(),
|
||||||
getExtensions: vi.fn(() => []),
|
getExtensions: vi.fn(() => []),
|
||||||
getBlockedMcpServers: vi.fn(() => []),
|
getBlockedMcpServers: vi.fn(() => []),
|
||||||
getUserAgent: vi.fn(() => opts.userAgent || 'test-agent'),
|
getUserAgent: vi.fn(() => opts.userAgent || 'test-agent'),
|
||||||
|
|
|
@ -71,6 +71,7 @@ describe('mcpCommand', () => {
|
||||||
getToolRegistry: ReturnType<typeof vi.fn>;
|
getToolRegistry: ReturnType<typeof vi.fn>;
|
||||||
getMcpServers: ReturnType<typeof vi.fn>;
|
getMcpServers: ReturnType<typeof vi.fn>;
|
||||||
getBlockedMcpServers: ReturnType<typeof vi.fn>;
|
getBlockedMcpServers: ReturnType<typeof vi.fn>;
|
||||||
|
getPromptRegistry: ReturnType<typeof vi.fn>;
|
||||||
};
|
};
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
|
@ -92,6 +93,10 @@ describe('mcpCommand', () => {
|
||||||
}),
|
}),
|
||||||
getMcpServers: vi.fn().mockReturnValue({}),
|
getMcpServers: vi.fn().mockReturnValue({}),
|
||||||
getBlockedMcpServers: vi.fn().mockReturnValue([]),
|
getBlockedMcpServers: vi.fn().mockReturnValue([]),
|
||||||
|
getPromptRegistry: vi.fn().mockResolvedValue({
|
||||||
|
getAllPrompts: vi.fn().mockReturnValue([]),
|
||||||
|
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||||
|
}),
|
||||||
};
|
};
|
||||||
|
|
||||||
mockContext = createMockCommandContext({
|
mockContext = createMockCommandContext({
|
||||||
|
@ -223,7 +228,7 @@ describe('mcpCommand', () => {
|
||||||
|
|
||||||
// Server 2 - Connected
|
// Server 2 - Connected
|
||||||
expect(message).toContain(
|
expect(message).toContain(
|
||||||
'🟢 \u001b[1mserver2\u001b[0m - Ready (1 tools)',
|
'🟢 \u001b[1mserver2\u001b[0m - Ready (1 tool)',
|
||||||
);
|
);
|
||||||
expect(message).toContain('server2_tool1');
|
expect(message).toContain('server2_tool1');
|
||||||
|
|
||||||
|
@ -365,13 +370,13 @@ describe('mcpCommand', () => {
|
||||||
if (isMessageAction(result)) {
|
if (isMessageAction(result)) {
|
||||||
const message = result.content;
|
const message = result.content;
|
||||||
expect(message).toContain(
|
expect(message).toContain(
|
||||||
'🟢 \u001b[1mserver1\u001b[0m - Ready (1 tools)',
|
'🟢 \u001b[1mserver1\u001b[0m - Ready (1 tool)',
|
||||||
);
|
);
|
||||||
expect(message).toContain('\u001b[36mserver1_tool1\u001b[0m');
|
expect(message).toContain('\u001b[36mserver1_tool1\u001b[0m');
|
||||||
expect(message).toContain(
|
expect(message).toContain(
|
||||||
'🔴 \u001b[1mserver2\u001b[0m - Disconnected (0 tools cached)',
|
'🔴 \u001b[1mserver2\u001b[0m - Disconnected (0 tools cached)',
|
||||||
);
|
);
|
||||||
expect(message).toContain('No tools available');
|
expect(message).toContain('No tools or prompts available');
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -421,10 +426,10 @@ describe('mcpCommand', () => {
|
||||||
|
|
||||||
// Check server statuses
|
// Check server statuses
|
||||||
expect(message).toContain(
|
expect(message).toContain(
|
||||||
'🟢 \u001b[1mserver1\u001b[0m - Ready (1 tools)',
|
'🟢 \u001b[1mserver1\u001b[0m - Ready (1 tool)',
|
||||||
);
|
);
|
||||||
expect(message).toContain(
|
expect(message).toContain(
|
||||||
'🔄 \u001b[1mserver2\u001b[0m - Starting... (first startup may take longer) (tools will appear when ready)',
|
'🔄 \u001b[1mserver2\u001b[0m - Starting... (first startup may take longer) (tools and prompts will appear when ready)',
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -994,6 +999,9 @@ describe('mcpCommand', () => {
|
||||||
getBlockedMcpServers: vi.fn().mockReturnValue([]),
|
getBlockedMcpServers: vi.fn().mockReturnValue([]),
|
||||||
getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
|
getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
|
||||||
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
|
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
|
||||||
|
getPromptRegistry: vi.fn().mockResolvedValue({
|
||||||
|
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
|
@ -12,6 +12,7 @@ import {
|
||||||
MessageActionReturn,
|
MessageActionReturn,
|
||||||
} from './types.js';
|
} from './types.js';
|
||||||
import {
|
import {
|
||||||
|
DiscoveredMCPPrompt,
|
||||||
DiscoveredMCPTool,
|
DiscoveredMCPTool,
|
||||||
getMCPDiscoveryState,
|
getMCPDiscoveryState,
|
||||||
getMCPServerStatus,
|
getMCPServerStatus,
|
||||||
|
@ -101,6 +102,8 @@ const getMcpStatus = async (
|
||||||
(tool) =>
|
(tool) =>
|
||||||
tool instanceof DiscoveredMCPTool && tool.serverName === serverName,
|
tool instanceof DiscoveredMCPTool && tool.serverName === serverName,
|
||||||
) as DiscoveredMCPTool[];
|
) as DiscoveredMCPTool[];
|
||||||
|
const promptRegistry = await config.getPromptRegistry();
|
||||||
|
const serverPrompts = promptRegistry.getPromptsByServer(serverName) || [];
|
||||||
|
|
||||||
const status = getMCPServerStatus(serverName);
|
const status = getMCPServerStatus(serverName);
|
||||||
|
|
||||||
|
@ -160,9 +163,26 @@ const getMcpStatus = async (
|
||||||
|
|
||||||
// Add tool count with conditional messaging
|
// Add tool count with conditional messaging
|
||||||
if (status === MCPServerStatus.CONNECTED) {
|
if (status === MCPServerStatus.CONNECTED) {
|
||||||
message += ` (${serverTools.length} tools)`;
|
const parts = [];
|
||||||
|
if (serverTools.length > 0) {
|
||||||
|
parts.push(
|
||||||
|
`${serverTools.length} ${serverTools.length === 1 ? 'tool' : 'tools'}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (serverPrompts.length > 0) {
|
||||||
|
parts.push(
|
||||||
|
`${serverPrompts.length} ${
|
||||||
|
serverPrompts.length === 1 ? 'prompt' : 'prompts'
|
||||||
|
}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (parts.length > 0) {
|
||||||
|
message += ` (${parts.join(', ')})`;
|
||||||
|
} else {
|
||||||
|
message += ` (0 tools)`;
|
||||||
|
}
|
||||||
} else if (status === MCPServerStatus.CONNECTING) {
|
} else if (status === MCPServerStatus.CONNECTING) {
|
||||||
message += ` (tools will appear when ready)`;
|
message += ` (tools and prompts will appear when ready)`;
|
||||||
} else {
|
} else {
|
||||||
message += ` (${serverTools.length} tools cached)`;
|
message += ` (${serverTools.length} tools cached)`;
|
||||||
}
|
}
|
||||||
|
@ -186,6 +206,7 @@ const getMcpStatus = async (
|
||||||
message += RESET_COLOR;
|
message += RESET_COLOR;
|
||||||
|
|
||||||
if (serverTools.length > 0) {
|
if (serverTools.length > 0) {
|
||||||
|
message += ` ${COLOR_CYAN}Tools:${RESET_COLOR}\n`;
|
||||||
serverTools.forEach((tool) => {
|
serverTools.forEach((tool) => {
|
||||||
if (showDescriptions && tool.description) {
|
if (showDescriptions && tool.description) {
|
||||||
// Format tool name in cyan using simple ANSI cyan color
|
// Format tool name in cyan using simple ANSI cyan color
|
||||||
|
@ -222,12 +243,41 @@ const getMcpStatus = async (
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
if (serverPrompts.length > 0) {
|
||||||
|
if (serverTools.length > 0) {
|
||||||
|
message += '\n';
|
||||||
|
}
|
||||||
|
message += ` ${COLOR_CYAN}Prompts:${RESET_COLOR}\n`;
|
||||||
|
serverPrompts.forEach((prompt: DiscoveredMCPPrompt) => {
|
||||||
|
if (showDescriptions && prompt.description) {
|
||||||
|
message += ` - ${COLOR_CYAN}${prompt.name}${RESET_COLOR}`;
|
||||||
|
const descLines = prompt.description.trim().split('\n');
|
||||||
|
if (descLines) {
|
||||||
|
message += ':\n';
|
||||||
|
for (const descLine of descLines) {
|
||||||
|
message += ` ${COLOR_GREEN}${descLine}${RESET_COLOR}\n`;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
|
message += '\n';
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
message += ` - ${COLOR_CYAN}${prompt.name}${RESET_COLOR}\n`;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (serverTools.length === 0 && serverPrompts.length === 0) {
|
||||||
|
message += ' No tools or prompts available\n';
|
||||||
|
} else if (serverTools.length === 0) {
|
||||||
message += ' No tools available';
|
message += ' No tools available';
|
||||||
if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) {
|
if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) {
|
||||||
message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}`;
|
message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}`;
|
||||||
}
|
}
|
||||||
message += '\n';
|
message += '\n';
|
||||||
|
} else if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) {
|
||||||
|
// This case is for when serverTools.length > 0
|
||||||
|
message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}\n`;
|
||||||
}
|
}
|
||||||
message += '\n';
|
message += '\n';
|
||||||
}
|
}
|
||||||
|
@ -328,11 +378,10 @@ const authCommand: SlashCommand = {
|
||||||
// Import dynamically to avoid circular dependencies
|
// Import dynamically to avoid circular dependencies
|
||||||
const { MCPOAuthProvider } = await import('@google/gemini-cli-core');
|
const { MCPOAuthProvider } = await import('@google/gemini-cli-core');
|
||||||
|
|
||||||
// Create OAuth config for authentication (will be discovered automatically)
|
let oauthConfig = server.oauth;
|
||||||
const oauthConfig = server.oauth || {
|
if (!oauthConfig) {
|
||||||
authorizationUrl: '', // Will be discovered automatically
|
oauthConfig = { enabled: false };
|
||||||
tokenUrl: '', // Will be discovered automatically
|
}
|
||||||
};
|
|
||||||
|
|
||||||
// Pass the MCP server URL for OAuth discovery
|
// Pass the MCP server URL for OAuth discovery
|
||||||
const mcpServerUrl = server.httpUrl || server.url;
|
const mcpServerUrl = server.httpUrl || server.url;
|
||||||
|
|
|
@ -128,6 +128,7 @@ export type SlashCommandActionReturn =
|
||||||
export enum CommandKind {
|
export enum CommandKind {
|
||||||
BUILT_IN = 'built-in',
|
BUILT_IN = 'built-in',
|
||||||
FILE = 'file',
|
FILE = 'file',
|
||||||
|
MCP_PROMPT = 'mcp-prompt',
|
||||||
}
|
}
|
||||||
|
|
||||||
// The standardized contract for any command in the system.
|
// The standardized contract for any command in the system.
|
||||||
|
|
|
@ -28,6 +28,13 @@ vi.mock('../../services/FileCommandLoader.js', () => ({
|
||||||
})),
|
})),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
const mockMcpLoadCommands = vi.fn();
|
||||||
|
vi.mock('../../services/McpPromptLoader.js', () => ({
|
||||||
|
McpPromptLoader: vi.fn().mockImplementation(() => ({
|
||||||
|
loadCommands: mockMcpLoadCommands,
|
||||||
|
})),
|
||||||
|
}));
|
||||||
|
|
||||||
vi.mock('../contexts/SessionContext.js', () => ({
|
vi.mock('../contexts/SessionContext.js', () => ({
|
||||||
useSessionStats: vi.fn(() => ({ stats: {} })),
|
useSessionStats: vi.fn(() => ({ stats: {} })),
|
||||||
}));
|
}));
|
||||||
|
@ -41,6 +48,7 @@ import { LoadedSettings } from '../../config/settings.js';
|
||||||
import { MessageType } from '../types.js';
|
import { MessageType } from '../types.js';
|
||||||
import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js';
|
import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js';
|
||||||
import { FileCommandLoader } from '../../services/FileCommandLoader.js';
|
import { FileCommandLoader } from '../../services/FileCommandLoader.js';
|
||||||
|
import { McpPromptLoader } from '../../services/McpPromptLoader.js';
|
||||||
|
|
||||||
const createTestCommand = (
|
const createTestCommand = (
|
||||||
overrides: Partial<SlashCommand>,
|
overrides: Partial<SlashCommand>,
|
||||||
|
@ -75,14 +83,17 @@ describe('useSlashCommandProcessor', () => {
|
||||||
(vi.mocked(BuiltinCommandLoader) as Mock).mockClear();
|
(vi.mocked(BuiltinCommandLoader) as Mock).mockClear();
|
||||||
mockBuiltinLoadCommands.mockResolvedValue([]);
|
mockBuiltinLoadCommands.mockResolvedValue([]);
|
||||||
mockFileLoadCommands.mockResolvedValue([]);
|
mockFileLoadCommands.mockResolvedValue([]);
|
||||||
|
mockMcpLoadCommands.mockResolvedValue([]);
|
||||||
});
|
});
|
||||||
|
|
||||||
const setupProcessorHook = (
|
const setupProcessorHook = (
|
||||||
builtinCommands: SlashCommand[] = [],
|
builtinCommands: SlashCommand[] = [],
|
||||||
fileCommands: SlashCommand[] = [],
|
fileCommands: SlashCommand[] = [],
|
||||||
|
mcpCommands: SlashCommand[] = [],
|
||||||
) => {
|
) => {
|
||||||
mockBuiltinLoadCommands.mockResolvedValue(Object.freeze(builtinCommands));
|
mockBuiltinLoadCommands.mockResolvedValue(Object.freeze(builtinCommands));
|
||||||
mockFileLoadCommands.mockResolvedValue(Object.freeze(fileCommands));
|
mockFileLoadCommands.mockResolvedValue(Object.freeze(fileCommands));
|
||||||
|
mockMcpLoadCommands.mockResolvedValue(Object.freeze(mcpCommands));
|
||||||
|
|
||||||
const { result } = renderHook(() =>
|
const { result } = renderHook(() =>
|
||||||
useSlashCommandProcessor(
|
useSlashCommandProcessor(
|
||||||
|
@ -111,6 +122,7 @@ describe('useSlashCommandProcessor', () => {
|
||||||
setupProcessorHook();
|
setupProcessorHook();
|
||||||
expect(BuiltinCommandLoader).toHaveBeenCalledWith(mockConfig);
|
expect(BuiltinCommandLoader).toHaveBeenCalledWith(mockConfig);
|
||||||
expect(FileCommandLoader).toHaveBeenCalledWith(mockConfig);
|
expect(FileCommandLoader).toHaveBeenCalledWith(mockConfig);
|
||||||
|
expect(McpPromptLoader).toHaveBeenCalledWith(mockConfig);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should call loadCommands and populate state after mounting', async () => {
|
it('should call loadCommands and populate state after mounting', async () => {
|
||||||
|
@ -124,6 +136,7 @@ describe('useSlashCommandProcessor', () => {
|
||||||
expect(result.current.slashCommands[0]?.name).toBe('test');
|
expect(result.current.slashCommands[0]?.name).toBe('test');
|
||||||
expect(mockBuiltinLoadCommands).toHaveBeenCalledTimes(1);
|
expect(mockBuiltinLoadCommands).toHaveBeenCalledTimes(1);
|
||||||
expect(mockFileLoadCommands).toHaveBeenCalledTimes(1);
|
expect(mockFileLoadCommands).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mockMcpLoadCommands).toHaveBeenCalledTimes(1);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should provide an immutable array of commands to consumers', async () => {
|
it('should provide an immutable array of commands to consumers', async () => {
|
||||||
|
@ -369,6 +382,38 @@ describe('useSlashCommandProcessor', () => {
|
||||||
expect.any(Number),
|
expect.any(Number),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should handle "submit_prompt" action returned from a mcp-based command', async () => {
|
||||||
|
const mcpCommand = createTestCommand(
|
||||||
|
{
|
||||||
|
name: 'mcpcmd',
|
||||||
|
description: 'A command from mcp',
|
||||||
|
action: async () => ({
|
||||||
|
type: 'submit_prompt',
|
||||||
|
content: 'The actual prompt from the mcp command.',
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
CommandKind.MCP_PROMPT,
|
||||||
|
);
|
||||||
|
|
||||||
|
const result = setupProcessorHook([], [], [mcpCommand]);
|
||||||
|
await waitFor(() => expect(result.current.slashCommands).toHaveLength(1));
|
||||||
|
|
||||||
|
let actionResult;
|
||||||
|
await act(async () => {
|
||||||
|
actionResult = await result.current.handleSlashCommand('/mcpcmd');
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(actionResult).toEqual({
|
||||||
|
type: 'submit_prompt',
|
||||||
|
content: 'The actual prompt from the mcp command.',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockAddItem).toHaveBeenCalledWith(
|
||||||
|
{ type: MessageType.USER, text: '/mcpcmd' },
|
||||||
|
expect.any(Number),
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('Command Parsing and Matching', () => {
|
describe('Command Parsing and Matching', () => {
|
||||||
|
@ -441,6 +486,39 @@ describe('useSlashCommandProcessor', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('Command Precedence', () => {
|
describe('Command Precedence', () => {
|
||||||
|
it('should override mcp-based commands with file-based commands of the same name', async () => {
|
||||||
|
const mcpAction = vi.fn();
|
||||||
|
const fileAction = vi.fn();
|
||||||
|
|
||||||
|
const mcpCommand = createTestCommand(
|
||||||
|
{
|
||||||
|
name: 'override',
|
||||||
|
description: 'mcp',
|
||||||
|
action: mcpAction,
|
||||||
|
},
|
||||||
|
CommandKind.MCP_PROMPT,
|
||||||
|
);
|
||||||
|
const fileCommand = createTestCommand(
|
||||||
|
{ name: 'override', description: 'file', action: fileAction },
|
||||||
|
CommandKind.FILE,
|
||||||
|
);
|
||||||
|
|
||||||
|
const result = setupProcessorHook([], [fileCommand], [mcpCommand]);
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
// The service should only return one command with the name 'override'
|
||||||
|
expect(result.current.slashCommands).toHaveLength(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.handleSlashCommand('/override');
|
||||||
|
});
|
||||||
|
|
||||||
|
// Only the file-based command's action should be called.
|
||||||
|
expect(fileAction).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mcpAction).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
it('should prioritize a command with a primary name over a command with a matching alias', async () => {
|
it('should prioritize a command with a primary name over a command with a matching alias', async () => {
|
||||||
const quitAction = vi.fn();
|
const quitAction = vi.fn();
|
||||||
const exitAction = vi.fn();
|
const exitAction = vi.fn();
|
||||||
|
|
|
@ -23,6 +23,7 @@ import { type CommandContext, type SlashCommand } from '../commands/types.js';
|
||||||
import { CommandService } from '../../services/CommandService.js';
|
import { CommandService } from '../../services/CommandService.js';
|
||||||
import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js';
|
import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js';
|
||||||
import { FileCommandLoader } from '../../services/FileCommandLoader.js';
|
import { FileCommandLoader } from '../../services/FileCommandLoader.js';
|
||||||
|
import { McpPromptLoader } from '../../services/McpPromptLoader.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Hook to define and process slash commands (e.g., /help, /clear).
|
* Hook to define and process slash commands (e.g., /help, /clear).
|
||||||
|
@ -164,6 +165,7 @@ export const useSlashCommandProcessor = (
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
const load = async () => {
|
const load = async () => {
|
||||||
const loaders = [
|
const loaders = [
|
||||||
|
new McpPromptLoader(config),
|
||||||
new BuiltinCommandLoader(config),
|
new BuiltinCommandLoader(config),
|
||||||
new FileCommandLoader(config),
|
new FileCommandLoader(config),
|
||||||
];
|
];
|
||||||
|
@ -246,6 +248,7 @@ export const useSlashCommandProcessor = (
|
||||||
args,
|
args,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
try {
|
||||||
const result = await commandToExecute.action(
|
const result = await commandToExecute.action(
|
||||||
fullCommandContext,
|
fullCommandContext,
|
||||||
args,
|
args,
|
||||||
|
@ -319,10 +322,22 @@ export const useSlashCommandProcessor = (
|
||||||
};
|
};
|
||||||
default: {
|
default: {
|
||||||
const unhandled: never = result;
|
const unhandled: never = result;
|
||||||
throw new Error(`Unhandled slash command result: ${unhandled}`);
|
throw new Error(
|
||||||
|
`Unhandled slash command result: ${unhandled}`,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (e) {
|
||||||
|
addItem(
|
||||||
|
{
|
||||||
|
type: MessageType.ERROR,
|
||||||
|
text: e instanceof Error ? e.message : String(e),
|
||||||
|
},
|
||||||
|
Date.now(),
|
||||||
|
);
|
||||||
|
return { type: 'handled' };
|
||||||
|
}
|
||||||
|
|
||||||
return { type: 'handled' };
|
return { type: 'handled' };
|
||||||
} else if (commandToExecute.subCommands) {
|
} else if (commandToExecute.subCommands) {
|
||||||
|
|
|
@ -1100,7 +1100,7 @@ describe('useCompletion', () => {
|
||||||
result.current.handleAutocomplete(0);
|
result.current.handleAutocomplete(0);
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockBuffer.setText).toHaveBeenCalledWith('/memory');
|
expect(mockBuffer.setText).toHaveBeenCalledWith('/memory ');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should append a sub-command when the parent is complete', () => {
|
it('should append a sub-command when the parent is complete', () => {
|
||||||
|
@ -1145,7 +1145,7 @@ describe('useCompletion', () => {
|
||||||
result.current.handleAutocomplete(1); // index 1 is 'add'
|
result.current.handleAutocomplete(1); // index 1 is 'add'
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add');
|
expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add ');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should complete a command with an alternative name', () => {
|
it('should complete a command with an alternative name', () => {
|
||||||
|
@ -1190,7 +1190,7 @@ describe('useCompletion', () => {
|
||||||
result.current.handleAutocomplete(0);
|
result.current.handleAutocomplete(0);
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(mockBuffer.setText).toHaveBeenCalledWith('/help');
|
expect(mockBuffer.setText).toHaveBeenCalledWith('/help ');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should complete a file path', async () => {
|
it('should complete a file path', async () => {
|
||||||
|
|
|
@ -638,10 +638,17 @@ export function useCompletion(
|
||||||
// Determine the base path of the command.
|
// Determine the base path of the command.
|
||||||
// - If there's a trailing space, the whole command is the base.
|
// - If there's a trailing space, the whole command is the base.
|
||||||
// - If it's a known parent path, the whole command is the base.
|
// - If it's a known parent path, the whole command is the base.
|
||||||
|
// - If the last part is a complete argument, the whole command is the base.
|
||||||
// - Otherwise, the base is everything EXCEPT the last partial part.
|
// - Otherwise, the base is everything EXCEPT the last partial part.
|
||||||
|
const lastPart = parts.length > 0 ? parts[parts.length - 1] : '';
|
||||||
|
const isLastPartACompleteArg =
|
||||||
|
lastPart.startsWith('--') && lastPart.includes('=');
|
||||||
|
|
||||||
const basePath =
|
const basePath =
|
||||||
hasTrailingSpace || isParentPath ? parts : parts.slice(0, -1);
|
hasTrailingSpace || isParentPath || isLastPartACompleteArg
|
||||||
const newValue = `/${[...basePath, suggestion].join(' ')}`;
|
? parts
|
||||||
|
: parts.slice(0, -1);
|
||||||
|
const newValue = `/${[...basePath, suggestion].join(' ')} `;
|
||||||
|
|
||||||
buffer.setText(newValue);
|
buffer.setText(newValue);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -11,6 +11,7 @@ import {
|
||||||
ContentGeneratorConfig,
|
ContentGeneratorConfig,
|
||||||
createContentGeneratorConfig,
|
createContentGeneratorConfig,
|
||||||
} from '../core/contentGenerator.js';
|
} from '../core/contentGenerator.js';
|
||||||
|
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||||
import { LSTool } from '../tools/ls.js';
|
import { LSTool } from '../tools/ls.js';
|
||||||
import { ReadFileTool } from '../tools/read-file.js';
|
import { ReadFileTool } from '../tools/read-file.js';
|
||||||
|
@ -186,6 +187,7 @@ export interface ConfigParameters {
|
||||||
|
|
||||||
export class Config {
|
export class Config {
|
||||||
private toolRegistry!: ToolRegistry;
|
private toolRegistry!: ToolRegistry;
|
||||||
|
private promptRegistry!: PromptRegistry;
|
||||||
private readonly sessionId: string;
|
private readonly sessionId: string;
|
||||||
private contentGeneratorConfig!: ContentGeneratorConfig;
|
private contentGeneratorConfig!: ContentGeneratorConfig;
|
||||||
private readonly embeddingModel: string;
|
private readonly embeddingModel: string;
|
||||||
|
@ -314,6 +316,7 @@ export class Config {
|
||||||
if (this.getCheckpointingEnabled()) {
|
if (this.getCheckpointingEnabled()) {
|
||||||
await this.getGitService();
|
await this.getGitService();
|
||||||
}
|
}
|
||||||
|
this.promptRegistry = new PromptRegistry();
|
||||||
this.toolRegistry = await this.createToolRegistry();
|
this.toolRegistry = await this.createToolRegistry();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -396,6 +399,10 @@ export class Config {
|
||||||
return Promise.resolve(this.toolRegistry);
|
return Promise.resolve(this.toolRegistry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getPromptRegistry(): PromptRegistry {
|
||||||
|
return this.promptRegistry;
|
||||||
|
}
|
||||||
|
|
||||||
getDebugMode(): boolean {
|
getDebugMode(): boolean {
|
||||||
return this.debugMode;
|
return this.debugMode;
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,6 +49,9 @@ export * from './ide/ideContext.js';
|
||||||
export * from './tools/tools.js';
|
export * from './tools/tools.js';
|
||||||
export * from './tools/tool-registry.js';
|
export * from './tools/tool-registry.js';
|
||||||
|
|
||||||
|
// Export prompt logic
|
||||||
|
export * from './prompts/mcp-prompts.js';
|
||||||
|
|
||||||
// Export specific tool logic
|
// Export specific tool logic
|
||||||
export * from './tools/read-file.js';
|
export * from './tools/read-file.js';
|
||||||
export * from './tools/ls.js';
|
export * from './tools/ls.js';
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { Config } from '../config/config.js';
|
||||||
|
import { DiscoveredMCPPrompt } from '../tools/mcp-client.js';
|
||||||
|
|
||||||
|
export function getMCPServerPrompts(
|
||||||
|
config: Config,
|
||||||
|
serverName: string,
|
||||||
|
): DiscoveredMCPPrompt[] {
|
||||||
|
const promptRegistry = config.getPromptRegistry();
|
||||||
|
if (!promptRegistry) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
return promptRegistry.getPromptsByServer(serverName);
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { DiscoveredMCPPrompt } from '../tools/mcp-client.js';
|
||||||
|
|
||||||
|
export class PromptRegistry {
|
||||||
|
private prompts: Map<string, DiscoveredMCPPrompt> = new Map();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Registers a prompt definition.
|
||||||
|
* @param prompt - The prompt object containing schema and execution logic.
|
||||||
|
*/
|
||||||
|
registerPrompt(prompt: DiscoveredMCPPrompt): void {
|
||||||
|
if (this.prompts.has(prompt.name)) {
|
||||||
|
const newName = `${prompt.serverName}_${prompt.name}`;
|
||||||
|
console.warn(
|
||||||
|
`Prompt with name "${prompt.name}" is already registered. Renaming to "${newName}".`,
|
||||||
|
);
|
||||||
|
this.prompts.set(newName, { ...prompt, name: newName });
|
||||||
|
} else {
|
||||||
|
this.prompts.set(prompt.name, prompt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an array of all registered and discovered prompt instances.
|
||||||
|
*/
|
||||||
|
getAllPrompts(): DiscoveredMCPPrompt[] {
|
||||||
|
return Array.from(this.prompts.values()).sort((a, b) =>
|
||||||
|
a.name.localeCompare(b.name),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the definition of a specific prompt.
|
||||||
|
*/
|
||||||
|
getPrompt(name: string): DiscoveredMCPPrompt | undefined {
|
||||||
|
return this.prompts.get(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns an array of prompts registered from a specific MCP server.
|
||||||
|
*/
|
||||||
|
getPromptsByServer(serverName: string): DiscoveredMCPPrompt[] {
|
||||||
|
const serverPrompts: DiscoveredMCPPrompt[] = [];
|
||||||
|
for (const prompt of this.prompts.values()) {
|
||||||
|
if (prompt.serverName === serverName) {
|
||||||
|
serverPrompts.push(prompt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return serverPrompts.sort((a, b) => a.name.localeCompare(b.name));
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ import {
|
||||||
createTransport,
|
createTransport,
|
||||||
isEnabled,
|
isEnabled,
|
||||||
discoverTools,
|
discoverTools,
|
||||||
|
discoverPrompts,
|
||||||
} from './mcp-client.js';
|
} from './mcp-client.js';
|
||||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||||
|
@ -18,6 +19,7 @@ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
||||||
import * as GenAiLib from '@google/genai';
|
import * as GenAiLib from '@google/genai';
|
||||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||||
import { AuthProviderType } from '../config/config.js';
|
import { AuthProviderType } from '../config/config.js';
|
||||||
|
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||||
|
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
vi.mock('@modelcontextprotocol/sdk/client/stdio.js');
|
||||||
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
vi.mock('@modelcontextprotocol/sdk/client/index.js');
|
||||||
|
@ -50,6 +52,77 @@ describe('mcp-client', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('discoverPrompts', () => {
|
||||||
|
const mockedPromptRegistry = {
|
||||||
|
registerPrompt: vi.fn(),
|
||||||
|
} as unknown as PromptRegistry;
|
||||||
|
|
||||||
|
it('should discover and log prompts', async () => {
|
||||||
|
const mockRequest = vi.fn().mockResolvedValue({
|
||||||
|
prompts: [
|
||||||
|
{ name: 'prompt1', description: 'desc1' },
|
||||||
|
{ name: 'prompt2' },
|
||||||
|
],
|
||||||
|
});
|
||||||
|
const mockedClient = {
|
||||||
|
request: mockRequest,
|
||||||
|
} as unknown as ClientLib.Client;
|
||||||
|
|
||||||
|
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||||
|
|
||||||
|
expect(mockRequest).toHaveBeenCalledWith(
|
||||||
|
{ method: 'prompts/list', params: {} },
|
||||||
|
expect.anything(),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should do nothing if no prompts are discovered', async () => {
|
||||||
|
const mockRequest = vi.fn().mockResolvedValue({
|
||||||
|
prompts: [],
|
||||||
|
});
|
||||||
|
const mockedClient = {
|
||||||
|
request: mockRequest,
|
||||||
|
} as unknown as ClientLib.Client;
|
||||||
|
|
||||||
|
const consoleLogSpy = vi
|
||||||
|
.spyOn(console, 'debug')
|
||||||
|
.mockImplementation(() => {
|
||||||
|
// no-op
|
||||||
|
});
|
||||||
|
|
||||||
|
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||||
|
|
||||||
|
expect(mockRequest).toHaveBeenCalledOnce();
|
||||||
|
expect(consoleLogSpy).not.toHaveBeenCalled();
|
||||||
|
|
||||||
|
consoleLogSpy.mockRestore();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should log an error if discovery fails', async () => {
|
||||||
|
const testError = new Error('test error');
|
||||||
|
testError.message = 'test error';
|
||||||
|
const mockRequest = vi.fn().mockRejectedValue(testError);
|
||||||
|
const mockedClient = {
|
||||||
|
request: mockRequest,
|
||||||
|
} as unknown as ClientLib.Client;
|
||||||
|
|
||||||
|
const consoleErrorSpy = vi
|
||||||
|
.spyOn(console, 'error')
|
||||||
|
.mockImplementation(() => {
|
||||||
|
// no-op
|
||||||
|
});
|
||||||
|
|
||||||
|
await discoverPrompts('test-server', mockedClient, mockedPromptRegistry);
|
||||||
|
|
||||||
|
expect(mockRequest).toHaveBeenCalledOnce();
|
||||||
|
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||||
|
`Error discovering prompts from test-server: ${testError.message}`,
|
||||||
|
);
|
||||||
|
|
||||||
|
consoleErrorSpy.mockRestore();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('appendMcpServerCommand', () => {
|
describe('appendMcpServerCommand', () => {
|
||||||
it('should do nothing if no MCP servers or command are configured', () => {
|
it('should do nothing if no MCP servers or command are configured', () => {
|
||||||
const out = populateMcpServerCommand({}, undefined);
|
const out = populateMcpServerCommand({}, undefined);
|
||||||
|
|
|
@ -15,12 +15,20 @@ import {
|
||||||
StreamableHTTPClientTransport,
|
StreamableHTTPClientTransport,
|
||||||
StreamableHTTPClientTransportOptions,
|
StreamableHTTPClientTransportOptions,
|
||||||
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
} from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||||
|
import {
|
||||||
|
Prompt,
|
||||||
|
ListPromptsResultSchema,
|
||||||
|
GetPromptResult,
|
||||||
|
GetPromptResultSchema,
|
||||||
|
} from '@modelcontextprotocol/sdk/types.js';
|
||||||
import { parse } from 'shell-quote';
|
import { parse } from 'shell-quote';
|
||||||
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
|
import { AuthProviderType, MCPServerConfig } from '../config/config.js';
|
||||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||||
|
|
||||||
import { FunctionDeclaration, mcpToTool } from '@google/genai';
|
import { FunctionDeclaration, mcpToTool } from '@google/genai';
|
||||||
import { ToolRegistry } from './tool-registry.js';
|
import { ToolRegistry } from './tool-registry.js';
|
||||||
|
import { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||||
|
@ -28,6 +36,11 @@ import { getErrorMessage } from '../utils/errors.js';
|
||||||
|
|
||||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||||
|
|
||||||
|
export type DiscoveredMCPPrompt = Prompt & {
|
||||||
|
serverName: string;
|
||||||
|
invoke: (params: Record<string, unknown>) => Promise<GetPromptResult>;
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Enum representing the connection status of an MCP server
|
* Enum representing the connection status of an MCP server
|
||||||
*/
|
*/
|
||||||
|
@ -55,7 +68,7 @@ export enum MCPDiscoveryState {
|
||||||
/**
|
/**
|
||||||
* Map to track the status of each MCP server within the core package
|
* Map to track the status of each MCP server within the core package
|
||||||
*/
|
*/
|
||||||
const mcpServerStatusesInternal: Map<string, MCPServerStatus> = new Map();
|
const serverStatuses: Map<string, MCPServerStatus> = new Map();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Track the overall MCP discovery state
|
* Track the overall MCP discovery state
|
||||||
|
@ -104,7 +117,7 @@ function updateMCPServerStatus(
|
||||||
serverName: string,
|
serverName: string,
|
||||||
status: MCPServerStatus,
|
status: MCPServerStatus,
|
||||||
): void {
|
): void {
|
||||||
mcpServerStatusesInternal.set(serverName, status);
|
serverStatuses.set(serverName, status);
|
||||||
// Notify all listeners
|
// Notify all listeners
|
||||||
for (const listener of statusChangeListeners) {
|
for (const listener of statusChangeListeners) {
|
||||||
listener(serverName, status);
|
listener(serverName, status);
|
||||||
|
@ -115,16 +128,14 @@ function updateMCPServerStatus(
|
||||||
* Get the current status of an MCP server
|
* Get the current status of an MCP server
|
||||||
*/
|
*/
|
||||||
export function getMCPServerStatus(serverName: string): MCPServerStatus {
|
export function getMCPServerStatus(serverName: string): MCPServerStatus {
|
||||||
return (
|
return serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED;
|
||||||
mcpServerStatusesInternal.get(serverName) || MCPServerStatus.DISCONNECTED
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get all MCP server statuses
|
* Get all MCP server statuses
|
||||||
*/
|
*/
|
||||||
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
|
export function getAllMCPServerStatuses(): Map<string, MCPServerStatus> {
|
||||||
return new Map(mcpServerStatusesInternal);
|
return new Map(serverStatuses);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -307,6 +318,7 @@ export async function discoverMcpTools(
|
||||||
mcpServers: Record<string, MCPServerConfig>,
|
mcpServers: Record<string, MCPServerConfig>,
|
||||||
mcpServerCommand: string | undefined,
|
mcpServerCommand: string | undefined,
|
||||||
toolRegistry: ToolRegistry,
|
toolRegistry: ToolRegistry,
|
||||||
|
promptRegistry: PromptRegistry,
|
||||||
debugMode: boolean,
|
debugMode: boolean,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||||
|
@ -319,6 +331,7 @@ export async function discoverMcpTools(
|
||||||
mcpServerName,
|
mcpServerName,
|
||||||
mcpServerConfig,
|
mcpServerConfig,
|
||||||
toolRegistry,
|
toolRegistry,
|
||||||
|
promptRegistry,
|
||||||
debugMode,
|
debugMode,
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
|
@ -362,6 +375,7 @@ export async function connectAndDiscover(
|
||||||
mcpServerName: string,
|
mcpServerName: string,
|
||||||
mcpServerConfig: MCPServerConfig,
|
mcpServerConfig: MCPServerConfig,
|
||||||
toolRegistry: ToolRegistry,
|
toolRegistry: ToolRegistry,
|
||||||
|
promptRegistry: PromptRegistry,
|
||||||
debugMode: boolean,
|
debugMode: boolean,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||||
|
@ -378,6 +392,7 @@ export async function connectAndDiscover(
|
||||||
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
|
console.error(`MCP ERROR (${mcpServerName}):`, error.toString());
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||||
};
|
};
|
||||||
|
await discoverPrompts(mcpServerName, mcpClient, promptRegistry);
|
||||||
|
|
||||||
const tools = await discoverTools(
|
const tools = await discoverTools(
|
||||||
mcpServerName,
|
mcpServerName,
|
||||||
|
@ -393,7 +408,9 @@ export async function connectAndDiscover(
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(
|
console.error(
|
||||||
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(error)}`,
|
`Error connecting to MCP server '${mcpServerName}': ${getErrorMessage(
|
||||||
|
error,
|
||||||
|
)}`,
|
||||||
);
|
);
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||||
}
|
}
|
||||||
|
@ -441,15 +458,97 @@ export async function discoverTools(
|
||||||
),
|
),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if (discoveredTools.length === 0) {
|
|
||||||
throw Error('No enabled tools found');
|
|
||||||
}
|
|
||||||
return discoveredTools;
|
return discoveredTools;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
throw new Error(`Error discovering tools: ${error}`);
|
throw new Error(`Error discovering tools: ${error}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Discovers and logs prompts from a connected MCP client.
|
||||||
|
* It retrieves prompt declarations from the client and logs their names.
|
||||||
|
*
|
||||||
|
* @param mcpServerName The name of the MCP server.
|
||||||
|
* @param mcpClient The active MCP client instance.
|
||||||
|
*/
|
||||||
|
export async function discoverPrompts(
|
||||||
|
mcpServerName: string,
|
||||||
|
mcpClient: Client,
|
||||||
|
promptRegistry: PromptRegistry,
|
||||||
|
): Promise<void> {
|
||||||
|
try {
|
||||||
|
const response = await mcpClient.request(
|
||||||
|
{ method: 'prompts/list', params: {} },
|
||||||
|
ListPromptsResultSchema,
|
||||||
|
);
|
||||||
|
|
||||||
|
for (const prompt of response.prompts) {
|
||||||
|
promptRegistry.registerPrompt({
|
||||||
|
...prompt,
|
||||||
|
serverName: mcpServerName,
|
||||||
|
invoke: (params: Record<string, unknown>) =>
|
||||||
|
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// It's okay if this fails, not all servers will have prompts.
|
||||||
|
// Don't log an error if the method is not found, which is a common case.
|
||||||
|
if (
|
||||||
|
error instanceof Error &&
|
||||||
|
!error.message?.includes('Method not found')
|
||||||
|
) {
|
||||||
|
console.error(
|
||||||
|
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
|
||||||
|
error,
|
||||||
|
)}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Invokes a prompt on a connected MCP client.
|
||||||
|
*
|
||||||
|
* @param mcpServerName The name of the MCP server.
|
||||||
|
* @param mcpClient The active MCP client instance.
|
||||||
|
* @param promptName The name of the prompt to invoke.
|
||||||
|
* @param promptParams The parameters to pass to the prompt.
|
||||||
|
* @returns A promise that resolves to the result of the prompt invocation.
|
||||||
|
*/
|
||||||
|
export async function invokeMcpPrompt(
|
||||||
|
mcpServerName: string,
|
||||||
|
mcpClient: Client,
|
||||||
|
promptName: string,
|
||||||
|
promptParams: Record<string, unknown>,
|
||||||
|
): Promise<GetPromptResult> {
|
||||||
|
try {
|
||||||
|
const response = await mcpClient.request(
|
||||||
|
{
|
||||||
|
method: 'prompts/get',
|
||||||
|
params: {
|
||||||
|
name: promptName,
|
||||||
|
arguments: promptParams,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
GetPromptResultSchema,
|
||||||
|
);
|
||||||
|
|
||||||
|
return response;
|
||||||
|
} catch (error) {
|
||||||
|
if (
|
||||||
|
error instanceof Error &&
|
||||||
|
!error.message?.includes('Method not found')
|
||||||
|
) {
|
||||||
|
console.error(
|
||||||
|
`Error invoking prompt '${promptName}' from ${mcpServerName} ${promptParams}: ${getErrorMessage(
|
||||||
|
error,
|
||||||
|
)}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates and connects an MCP client to a server based on the provided configuration.
|
* Creates and connects an MCP client to a server based on the provided configuration.
|
||||||
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
|
* It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and
|
||||||
|
|
|
@ -344,6 +344,7 @@ describe('ToolRegistry', () => {
|
||||||
mcpServerConfigVal,
|
mcpServerConfigVal,
|
||||||
undefined,
|
undefined,
|
||||||
toolRegistry,
|
toolRegistry,
|
||||||
|
undefined,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
@ -366,6 +367,7 @@ describe('ToolRegistry', () => {
|
||||||
mcpServerConfigVal,
|
mcpServerConfigVal,
|
||||||
undefined,
|
undefined,
|
||||||
toolRegistry,
|
toolRegistry,
|
||||||
|
undefined,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
|
@ -170,6 +170,7 @@ export class ToolRegistry {
|
||||||
this.config.getMcpServers() ?? {},
|
this.config.getMcpServers() ?? {},
|
||||||
this.config.getMcpServerCommand(),
|
this.config.getMcpServerCommand(),
|
||||||
this,
|
this,
|
||||||
|
this.config.getPromptRegistry(),
|
||||||
this.config.getDebugMode(),
|
this.config.getDebugMode(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -192,6 +193,7 @@ export class ToolRegistry {
|
||||||
this.config.getMcpServers() ?? {},
|
this.config.getMcpServers() ?? {},
|
||||||
this.config.getMcpServerCommand(),
|
this.config.getMcpServerCommand(),
|
||||||
this,
|
this,
|
||||||
|
this.config.getPromptRegistry(),
|
||||||
this.config.getDebugMode(),
|
this.config.getDebugMode(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -215,6 +217,7 @@ export class ToolRegistry {
|
||||||
{ [serverName]: serverConfig },
|
{ [serverName]: serverConfig },
|
||||||
undefined,
|
undefined,
|
||||||
this,
|
this,
|
||||||
|
this.config.getPromptRegistry(),
|
||||||
this.config.getDebugMode(),
|
this.config.getDebugMode(),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue