Add --allowed_mcp_server_names flag (#3464)

This commit is contained in:
Tyler 2025-07-07 09:45:58 -07:00 committed by GitHub
parent 355fb4ac67
commit 229ae03631
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 96 additions and 1 deletions

View File

@ -478,3 +478,80 @@ describe('mergeExcludeTools', () => {
expect(settings).toEqual(originalSettings); expect(settings).toEqual(originalSettings);
}); });
}); });
describe('loadCliConfig with allowed_mcp_server_names', () => {
const originalArgv = process.argv;
const originalEnv = { ...process.env };
beforeEach(() => {
vi.resetAllMocks();
vi.mocked(os.homedir).mockReturnValue('/mock/home/user');
process.env.GEMINI_API_KEY = 'test-api-key';
});
afterEach(() => {
process.argv = originalArgv;
process.env = originalEnv;
vi.restoreAllMocks();
});
const baseSettings: Settings = {
mcpServers: {
server1: { url: 'http://localhost:8080' },
server2: { url: 'http://localhost:8081' },
server3: { url: 'http://localhost:8082' },
},
};
it('should allow all MCP servers if the flag is not provided', async () => {
process.argv = ['node', 'script.js'];
const config = await loadCliConfig(baseSettings, [], 'test-session');
expect(config.getMcpServers()).toEqual(baseSettings.mcpServers);
});
it('should allow only the specified MCP server', async () => {
process.argv = [
'node',
'script.js',
'--allowed_mcp_server_names',
'server1',
];
const config = await loadCliConfig(baseSettings, [], 'test-session');
expect(config.getMcpServers()).toEqual({
server1: { url: 'http://localhost:8080' },
});
});
it('should allow multiple specified MCP servers', async () => {
process.argv = [
'node',
'script.js',
'--allowed_mcp_server_names',
'server1,server3',
];
const config = await loadCliConfig(baseSettings, [], 'test-session');
expect(config.getMcpServers()).toEqual({
server1: { url: 'http://localhost:8080' },
server3: { url: 'http://localhost:8082' },
});
});
it('should handle server names that do not exist', async () => {
process.argv = [
'node',
'script.js',
'--allowed_mcp_server_names',
'server1,server4',
];
const config = await loadCliConfig(baseSettings, [], 'test-session');
expect(config.getMcpServers()).toEqual({
server1: { url: 'http://localhost:8080' },
});
});
it('should allow all MCP servers if the flag is an empty string', async () => {
process.argv = ['node', 'script.js', '--allowed_mcp_server_names', ''];
const config = await loadCliConfig(baseSettings, [], 'test-session');
expect(config.getMcpServers()).toEqual(baseSettings.mcpServers);
});
});

View File

@ -48,6 +48,7 @@ interface CliArgs {
telemetryTarget: string | undefined; telemetryTarget: string | undefined;
telemetryOtlpEndpoint: string | undefined; telemetryOtlpEndpoint: string | undefined;
telemetryLogPrompts: boolean | undefined; telemetryLogPrompts: boolean | undefined;
allowed_mcp_server_names: string | undefined;
} }
async function parseArguments(): Promise<CliArgs> { async function parseArguments(): Promise<CliArgs> {
@ -123,6 +124,10 @@ async function parseArguments(): Promise<CliArgs> {
description: 'Enables checkpointing of file edits', description: 'Enables checkpointing of file edits',
default: false, default: false,
}) })
.option('allowed_mcp_server_names', {
type: 'string',
description: 'Allowed MCP server names',
})
.version(await getCliVersion()) // This will enable the --version flag based on package.json .version(await getCliVersion()) // This will enable the --version flag based on package.json
.alias('v', 'version') .alias('v', 'version')
.help() .help()
@ -186,9 +191,22 @@ export async function loadCliConfig(
extensionContextFilePaths, extensionContextFilePaths,
); );
const mcpServers = mergeMcpServers(settings, extensions); let mcpServers = mergeMcpServers(settings, extensions);
const excludeTools = mergeExcludeTools(settings, extensions); const excludeTools = mergeExcludeTools(settings, extensions);
if (argv.allowed_mcp_server_names) {
const allowedNames = new Set(
argv.allowed_mcp_server_names.split(',').filter(Boolean),
);
if (allowedNames.size > 0) {
mcpServers = Object.fromEntries(
Object.entries(mcpServers).filter(([key]) => allowedNames.has(key)),
);
} else {
mcpServers = {};
}
}
const sandboxConfig = await loadSandboxConfig(settings, argv); const sandboxConfig = await loadSandboxConfig(settings, argv);
return new Config({ return new Config({