From 229ae03631b40f6997ca7244517a6a6f9b368f74 Mon Sep 17 00:00:00 2001 From: Tyler Date: Mon, 7 Jul 2025 09:45:58 -0700 Subject: [PATCH] Add --allowed_mcp_server_names flag (#3464) --- packages/cli/src/config/config.test.ts | 77 ++++++++++++++++++++++++++ packages/cli/src/config/config.ts | 20 ++++++- 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index d4820726..c08266d2 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -478,3 +478,80 @@ describe('mergeExcludeTools', () => { expect(settings).toEqual(originalSettings); }); }); + +describe('loadCliConfig with allowed_mcp_server_names', () => { + const originalArgv = process.argv; + const originalEnv = { ...process.env }; + + beforeEach(() => { + vi.resetAllMocks(); + vi.mocked(os.homedir).mockReturnValue('/mock/home/user'); + process.env.GEMINI_API_KEY = 'test-api-key'; + }); + + afterEach(() => { + process.argv = originalArgv; + process.env = originalEnv; + vi.restoreAllMocks(); + }); + + const baseSettings: Settings = { + mcpServers: { + server1: { url: 'http://localhost:8080' }, + server2: { url: 'http://localhost:8081' }, + server3: { url: 'http://localhost:8082' }, + }, + }; + + it('should allow all MCP servers if the flag is not provided', async () => { + process.argv = ['node', 'script.js']; + const config = await loadCliConfig(baseSettings, [], 'test-session'); + expect(config.getMcpServers()).toEqual(baseSettings.mcpServers); + }); + + it('should allow only the specified MCP server', async () => { + process.argv = [ + 'node', + 'script.js', + '--allowed_mcp_server_names', + 'server1', + ]; + const config = await loadCliConfig(baseSettings, [], 'test-session'); + expect(config.getMcpServers()).toEqual({ + server1: { url: 'http://localhost:8080' }, + }); + }); + + it('should allow multiple specified MCP servers', async () => { + process.argv = [ + 'node', + 'script.js', + '--allowed_mcp_server_names', + 'server1,server3', + ]; + const config = await loadCliConfig(baseSettings, [], 'test-session'); + expect(config.getMcpServers()).toEqual({ + server1: { url: 'http://localhost:8080' }, + server3: { url: 'http://localhost:8082' }, + }); + }); + + it('should handle server names that do not exist', async () => { + process.argv = [ + 'node', + 'script.js', + '--allowed_mcp_server_names', + 'server1,server4', + ]; + const config = await loadCliConfig(baseSettings, [], 'test-session'); + expect(config.getMcpServers()).toEqual({ + server1: { url: 'http://localhost:8080' }, + }); + }); + + it('should allow all MCP servers if the flag is an empty string', async () => { + process.argv = ['node', 'script.js', '--allowed_mcp_server_names', '']; + const config = await loadCliConfig(baseSettings, [], 'test-session'); + expect(config.getMcpServers()).toEqual(baseSettings.mcpServers); + }); +}); diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 7eed1db7..b32ae50c 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -48,6 +48,7 @@ interface CliArgs { telemetryTarget: string | undefined; telemetryOtlpEndpoint: string | undefined; telemetryLogPrompts: boolean | undefined; + allowed_mcp_server_names: string | undefined; } async function parseArguments(): Promise { @@ -123,6 +124,10 @@ async function parseArguments(): Promise { description: 'Enables checkpointing of file edits', default: false, }) + .option('allowed_mcp_server_names', { + type: 'string', + description: 'Allowed MCP server names', + }) .version(await getCliVersion()) // This will enable the --version flag based on package.json .alias('v', 'version') .help() @@ -186,9 +191,22 @@ export async function loadCliConfig( extensionContextFilePaths, ); - const mcpServers = mergeMcpServers(settings, extensions); + let mcpServers = mergeMcpServers(settings, extensions); const excludeTools = mergeExcludeTools(settings, extensions); + if (argv.allowed_mcp_server_names) { + const allowedNames = new Set( + argv.allowed_mcp_server_names.split(',').filter(Boolean), + ); + if (allowedNames.size > 0) { + mcpServers = Object.fromEntries( + Object.entries(mcpServers).filter(([key]) => allowedNames.has(key)), + ); + } else { + mcpServers = {}; + } + } + const sandboxConfig = await loadSandboxConfig(settings, argv); return new Config({