Add --allowed_mcp_server_names flag (#3464)
This commit is contained in:
parent
355fb4ac67
commit
229ae03631
|
@ -478,3 +478,80 @@ describe('mergeExcludeTools', () => {
|
||||||
expect(settings).toEqual(originalSettings);
|
expect(settings).toEqual(originalSettings);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('loadCliConfig with allowed_mcp_server_names', () => {
|
||||||
|
const originalArgv = process.argv;
|
||||||
|
const originalEnv = { ...process.env };
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.resetAllMocks();
|
||||||
|
vi.mocked(os.homedir).mockReturnValue('/mock/home/user');
|
||||||
|
process.env.GEMINI_API_KEY = 'test-api-key';
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
process.argv = originalArgv;
|
||||||
|
process.env = originalEnv;
|
||||||
|
vi.restoreAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
const baseSettings: Settings = {
|
||||||
|
mcpServers: {
|
||||||
|
server1: { url: 'http://localhost:8080' },
|
||||||
|
server2: { url: 'http://localhost:8081' },
|
||||||
|
server3: { url: 'http://localhost:8082' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
it('should allow all MCP servers if the flag is not provided', async () => {
|
||||||
|
process.argv = ['node', 'script.js'];
|
||||||
|
const config = await loadCliConfig(baseSettings, [], 'test-session');
|
||||||
|
expect(config.getMcpServers()).toEqual(baseSettings.mcpServers);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow only the specified MCP server', async () => {
|
||||||
|
process.argv = [
|
||||||
|
'node',
|
||||||
|
'script.js',
|
||||||
|
'--allowed_mcp_server_names',
|
||||||
|
'server1',
|
||||||
|
];
|
||||||
|
const config = await loadCliConfig(baseSettings, [], 'test-session');
|
||||||
|
expect(config.getMcpServers()).toEqual({
|
||||||
|
server1: { url: 'http://localhost:8080' },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow multiple specified MCP servers', async () => {
|
||||||
|
process.argv = [
|
||||||
|
'node',
|
||||||
|
'script.js',
|
||||||
|
'--allowed_mcp_server_names',
|
||||||
|
'server1,server3',
|
||||||
|
];
|
||||||
|
const config = await loadCliConfig(baseSettings, [], 'test-session');
|
||||||
|
expect(config.getMcpServers()).toEqual({
|
||||||
|
server1: { url: 'http://localhost:8080' },
|
||||||
|
server3: { url: 'http://localhost:8082' },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle server names that do not exist', async () => {
|
||||||
|
process.argv = [
|
||||||
|
'node',
|
||||||
|
'script.js',
|
||||||
|
'--allowed_mcp_server_names',
|
||||||
|
'server1,server4',
|
||||||
|
];
|
||||||
|
const config = await loadCliConfig(baseSettings, [], 'test-session');
|
||||||
|
expect(config.getMcpServers()).toEqual({
|
||||||
|
server1: { url: 'http://localhost:8080' },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should allow all MCP servers if the flag is an empty string', async () => {
|
||||||
|
process.argv = ['node', 'script.js', '--allowed_mcp_server_names', ''];
|
||||||
|
const config = await loadCliConfig(baseSettings, [], 'test-session');
|
||||||
|
expect(config.getMcpServers()).toEqual(baseSettings.mcpServers);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
|
@ -48,6 +48,7 @@ interface CliArgs {
|
||||||
telemetryTarget: string | undefined;
|
telemetryTarget: string | undefined;
|
||||||
telemetryOtlpEndpoint: string | undefined;
|
telemetryOtlpEndpoint: string | undefined;
|
||||||
telemetryLogPrompts: boolean | undefined;
|
telemetryLogPrompts: boolean | undefined;
|
||||||
|
allowed_mcp_server_names: string | undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
async function parseArguments(): Promise<CliArgs> {
|
async function parseArguments(): Promise<CliArgs> {
|
||||||
|
@ -123,6 +124,10 @@ async function parseArguments(): Promise<CliArgs> {
|
||||||
description: 'Enables checkpointing of file edits',
|
description: 'Enables checkpointing of file edits',
|
||||||
default: false,
|
default: false,
|
||||||
})
|
})
|
||||||
|
.option('allowed_mcp_server_names', {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Allowed MCP server names',
|
||||||
|
})
|
||||||
.version(await getCliVersion()) // This will enable the --version flag based on package.json
|
.version(await getCliVersion()) // This will enable the --version flag based on package.json
|
||||||
.alias('v', 'version')
|
.alias('v', 'version')
|
||||||
.help()
|
.help()
|
||||||
|
@ -186,9 +191,22 @@ export async function loadCliConfig(
|
||||||
extensionContextFilePaths,
|
extensionContextFilePaths,
|
||||||
);
|
);
|
||||||
|
|
||||||
const mcpServers = mergeMcpServers(settings, extensions);
|
let mcpServers = mergeMcpServers(settings, extensions);
|
||||||
const excludeTools = mergeExcludeTools(settings, extensions);
|
const excludeTools = mergeExcludeTools(settings, extensions);
|
||||||
|
|
||||||
|
if (argv.allowed_mcp_server_names) {
|
||||||
|
const allowedNames = new Set(
|
||||||
|
argv.allowed_mcp_server_names.split(',').filter(Boolean),
|
||||||
|
);
|
||||||
|
if (allowedNames.size > 0) {
|
||||||
|
mcpServers = Object.fromEntries(
|
||||||
|
Object.entries(mcpServers).filter(([key]) => allowedNames.has(key)),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
mcpServers = {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const sandboxConfig = await loadSandboxConfig(settings, argv);
|
const sandboxConfig = await loadSandboxConfig(settings, argv);
|
||||||
|
|
||||||
return new Config({
|
return new Config({
|
||||||
|
|
Loading…
Reference in New Issue