From c0940a194ea002742cb12d88dee9328a0d2da153 Mon Sep 17 00:00:00 2001 From: Billy Biggs Date: Tue, 8 Jul 2025 12:57:34 -0400 Subject: [PATCH] Add a command line option to enable and list extensions (#3191) --- docs/cli/configuration.md | 6 ++ packages/cli/src/config/config.test.ts | 38 +++++++++++++ packages/cli/src/config/config.ts | 34 ++++++++++-- packages/cli/src/config/extension.test.ts | 42 ++++++++++++++ packages/cli/src/config/extension.ts | 55 +++++++++++++++++-- packages/cli/src/gemini.tsx | 8 +++ .../cli/src/ui/hooks/slashCommandProcessor.ts | 28 ++++++++++ packages/core/src/config/config.ts | 19 +++++++ 8 files changed, 220 insertions(+), 10 deletions(-) diff --git a/docs/cli/configuration.md b/docs/cli/configuration.md index 5f0514b0..1b2f9680 100644 --- a/docs/cli/configuration.md +++ b/docs/cli/configuration.md @@ -311,6 +311,12 @@ Arguments passed directly when running the CLI can override other configurations - Enables logging of prompts for telemetry. See [telemetry](../telemetry.md) for more information. - **`--checkpointing`**: - Enables [checkpointing](./commands.md#checkpointing-commands). +- **`--extensions `** (**`-e `**): + - Specifies a list of extensions to use for the session. If not provided, all available extensions are used. + - Use the special term `gemini -e none` to disable all extensions. + - Example: `gemini -e my-extension -e my-other-extension` +- **`--list-extensions`** (**`-l`**): + - Lists all available extensions and exits. - **`--version`**: - Displays the version of the CLI. diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index ca5c9fdf..1ea48760 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -555,3 +555,41 @@ describe('loadCliConfig with allowed-mcp-server-names', () => { expect(config.getMcpServers()).toEqual(baseSettings.mcpServers); }); }); + +describe('loadCliConfig extensions', () => { + const mockExtensions: Extension[] = [ + { + config: { name: 'ext1', version: '1.0.0' }, + contextFiles: ['/path/to/ext1.md'], + }, + { + config: { name: 'ext2', version: '1.0.0' }, + contextFiles: ['/path/to/ext2.md'], + }, + ]; + + it('should not filter extensions if --extensions flag is not used', async () => { + process.argv = ['node', 'script.js']; + const settings: Settings = {}; + const config = await loadCliConfig( + settings, + mockExtensions, + 'test-session', + ); + expect(config.getExtensionContextFilePaths()).toEqual([ + '/path/to/ext1.md', + '/path/to/ext2.md', + ]); + }); + + it('should filter extensions if --extensions flag is used', async () => { + process.argv = ['node', 'script.js', '--extensions', 'ext1']; + const settings: Settings = {}; + const config = await loadCliConfig( + settings, + mockExtensions, + 'test-session', + ); + expect(config.getExtensionContextFilePaths()).toEqual(['/path/to/ext1.md']); + }); +}); diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index b79d20fc..7d1af7e1 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -20,7 +20,7 @@ import { } from '@google/gemini-cli-core'; import { Settings } from './settings.js'; -import { Extension } from './extension.js'; +import { Extension, filterActiveExtensions } from './extension.js'; import { getCliVersion } from '../utils/version.js'; import { loadSandboxConfig } from './sandboxConfig.js'; @@ -49,6 +49,8 @@ interface CliArgs { telemetryOtlpEndpoint: string | undefined; telemetryLogPrompts: boolean | undefined; 'allowed-mcp-server-names': string | undefined; + extensions: string[] | undefined; + listExtensions: boolean | undefined; } async function parseArguments(): Promise { @@ -133,6 +135,18 @@ async function parseArguments(): Promise { type: 'string', description: 'Allowed MCP server names', }) + .option('extensions', { + alias: 'e', + type: 'array', + string: true, + description: + 'A list of extensions to use. If not provided, all extensions are used.', + }) + .option('list-extensions', { + alias: 'l', + type: 'boolean', + description: 'List all available extensions and exit.', + }) .version(await getCliVersion()) // This will enable the --version flag based on package.json .alias('v', 'version') .help() @@ -174,6 +188,11 @@ export async function loadCliConfig( const argv = await parseArguments(); const debugMode = argv.debug || false; + const activeExtensions = filterActiveExtensions( + extensions, + argv.extensions || [], + ); + // Set the context filename in the server's memoryTool module BEFORE loading memory // TODO(b/343434939): This is a bit of a hack. The contextFileName should ideally be passed // directly to the Config constructor in core, and have core handle setGeminiMdFilename. @@ -185,7 +204,9 @@ export async function loadCliConfig( setServerGeminiMdFilename(getCurrentGeminiMdFilename()); } - const extensionContextFilePaths = extensions.flatMap((e) => e.contextFiles); + const extensionContextFilePaths = activeExtensions.flatMap( + (e) => e.contextFiles, + ); const fileService = new FileDiscoveryService(process.cwd()); // Call the (now wrapper) loadHierarchicalGeminiMemory which calls the server's version @@ -196,8 +217,8 @@ export async function loadCliConfig( extensionContextFilePaths, ); - let mcpServers = mergeMcpServers(settings, extensions); - const excludeTools = mergeExcludeTools(settings, extensions); + let mcpServers = mergeMcpServers(settings, activeExtensions); + const excludeTools = mergeExcludeTools(settings, activeExtensions); if (argv['allowed-mcp-server-names']) { const allowedNames = new Set( @@ -262,6 +283,11 @@ export async function loadCliConfig( bugCommand: settings.bugCommand, model: argv.model!, extensionContextFilePaths, + listExtensions: argv.listExtensions || false, + activeExtensions: activeExtensions.map((e) => ({ + name: e.config.name, + version: e.config.version, + })), }); } diff --git a/packages/cli/src/config/extension.test.ts b/packages/cli/src/config/extension.test.ts index 7d299c78..ab68e3f5 100644 --- a/packages/cli/src/config/extension.test.ts +++ b/packages/cli/src/config/extension.test.ts @@ -11,6 +11,7 @@ import * as path from 'path'; import { EXTENSIONS_CONFIG_FILENAME, EXTENSIONS_DIRECTORY_NAME, + filterActiveExtensions, loadExtensions, } from './extension.js'; @@ -85,6 +86,47 @@ describe('loadExtensions', () => { }); }); +describe('filterActiveExtensions', () => { + const extensions = [ + { config: { name: 'ext1', version: '1.0.0' }, contextFiles: [] }, + { config: { name: 'ext2', version: '1.0.0' }, contextFiles: [] }, + { config: { name: 'ext3', version: '1.0.0' }, contextFiles: [] }, + ]; + + it('should return all extensions if no enabled extensions are provided', () => { + const activeExtensions = filterActiveExtensions(extensions, []); + expect(activeExtensions).toHaveLength(3); + }); + + it('should return only the enabled extensions', () => { + const activeExtensions = filterActiveExtensions(extensions, [ + 'ext1', + 'ext3', + ]); + expect(activeExtensions).toHaveLength(2); + expect(activeExtensions.some((e) => e.config.name === 'ext1')).toBe(true); + expect(activeExtensions.some((e) => e.config.name === 'ext3')).toBe(true); + }); + + it('should return no extensions when "none" is provided', () => { + const activeExtensions = filterActiveExtensions(extensions, ['none']); + expect(activeExtensions).toHaveLength(0); + }); + + it('should handle case-insensitivity', () => { + const activeExtensions = filterActiveExtensions(extensions, ['EXT1']); + expect(activeExtensions).toHaveLength(1); + expect(activeExtensions[0].config.name).toBe('ext1'); + }); + + it('should log an error for unknown extensions', () => { + const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {}); + filterActiveExtensions(extensions, ['ext4']); + expect(consoleSpy).toHaveBeenCalledWith('Extension not found: ext4'); + consoleSpy.mockRestore(); + }); +}); + function createExtension( extensionsDir: string, name: string, diff --git a/packages/cli/src/config/extension.ts b/packages/cli/src/config/extension.ts index 57e6632b..aa540cfd 100644 --- a/packages/cli/src/config/extension.ts +++ b/packages/cli/src/config/extension.ts @@ -31,19 +31,17 @@ export function loadExtensions(workspaceDir: string): Extension[] { ...loadExtensionsFromDir(os.homedir()), ]; - const uniqueExtensions: Extension[] = []; - const seenNames = new Set(); + const uniqueExtensions = new Map(); for (const extension of allExtensions) { - if (!seenNames.has(extension.config.name)) { + if (!uniqueExtensions.has(extension.config.name)) { console.log( `Loading extension: ${extension.config.name} (version: ${extension.config.version})`, ); - uniqueExtensions.push(extension); - seenNames.add(extension.config.name); + uniqueExtensions.set(extension.config.name, extension); } } - return uniqueExtensions; + return Array.from(uniqueExtensions.values()); } function loadExtensionsFromDir(dir: string): Extension[] { @@ -114,3 +112,48 @@ function getContextFileNames(config: ExtensionConfig): string[] { } return config.contextFileName; } + +export function filterActiveExtensions( + extensions: Extension[], + enabledExtensionNames: string[], +): Extension[] { + if (enabledExtensionNames.length === 0) { + return extensions; + } + + const lowerCaseEnabledExtensions = new Set( + enabledExtensionNames.map((e) => e.trim().toLowerCase()), + ); + + if ( + lowerCaseEnabledExtensions.size === 1 && + lowerCaseEnabledExtensions.has('none') + ) { + if (extensions.length > 0) { + console.log('All extensions are disabled.'); + } + return []; + } + + const activeExtensions: Extension[] = []; + const notFoundNames = new Set(lowerCaseEnabledExtensions); + + for (const extension of extensions) { + const lowerCaseName = extension.config.name.toLowerCase(); + if (lowerCaseEnabledExtensions.has(lowerCaseName)) { + console.log( + `Activated extension: ${extension.config.name} (version: ${extension.config.version})`, + ); + activeExtensions.push(extension); + notFoundNames.delete(lowerCaseName); + } else { + console.log(`Disabled extension: ${extension.config.name}`); + } + } + + for (const requestedName of notFoundNames) { + console.log(`Extension not found: ${requestedName}`); + } + + return activeExtensions; +} diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 7e86a8ca..89f5eb3a 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -103,6 +103,14 @@ export async function main() { const extensions = loadExtensions(workspaceRoot); const config = await loadCliConfig(settings.merged, extensions, sessionId); + if (config.getListExtensions()) { + console.log('Installed extensions:'); + for (const extension of extensions) { + console.log(`- ${extension.config.name}`); + } + process.exit(0); + } + // Set a default auth type if one isn't set for a couple of known cases. if (!settings.merged.selectedAuthType) { if (process.env.GEMINI_API_KEY) { diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index c174b8a4..66cf4e39 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -493,6 +493,34 @@ export const useSlashCommandProcessor = ( }); }, }, + { + name: 'extensions', + description: 'list active extensions', + action: async () => { + const activeExtensions = config?.getActiveExtensions(); + if (!activeExtensions || activeExtensions.length === 0) { + addMessage({ + type: MessageType.INFO, + content: 'No active extensions.', + timestamp: new Date(), + }); + return; + } + + let message = 'Active extensions:\n\n'; + for (const ext of activeExtensions) { + message += ` - \u001b[36m${ext.name} (v${ext.version})\u001b[0m\n`; + } + // Make sure to reset any ANSI formatting at the end to prevent it from affecting the terminal + message += '\u001b[0m'; + + addMessage({ + type: MessageType.INFO, + content: message, + timestamp: new Date(), + }); + }, + }, { name: 'tools', description: 'list available Gemini CLI tools', diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index ca0714f0..2cea70ca 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -66,6 +66,11 @@ export interface TelemetrySettings { logPrompts?: boolean; } +export interface ActiveExtension { + name: string; + version: string; +} + export class MCPServerConfig { constructor( // For stdio transport @@ -133,6 +138,8 @@ export interface ConfigParameters { bugCommand?: BugCommandSettings; model: string; extensionContextFilePaths?: string[]; + listExtensions?: boolean; + activeExtensions?: ActiveExtension[]; } export class Config { @@ -172,6 +179,8 @@ export class Config { private readonly model: string; private readonly extensionContextFilePaths: string[]; private modelSwitchedDuringSession: boolean = false; + private readonly listExtensions: boolean; + private readonly _activeExtensions: ActiveExtension[]; flashFallbackHandler?: FlashFallbackHandler; constructor(params: ConfigParameters) { @@ -214,6 +223,8 @@ export class Config { this.bugCommand = params.bugCommand; this.model = params.model; this.extensionContextFilePaths = params.extensionContextFilePaths ?? []; + this.listExtensions = params.listExtensions ?? false; + this._activeExtensions = params.activeExtensions ?? []; if (params.contextFileName) { setGeminiMdFilename(params.contextFileName); @@ -446,6 +457,14 @@ export class Config { return this.extensionContextFilePaths; } + getListExtensions(): boolean { + return this.listExtensions; + } + + getActiveExtensions(): ActiveExtension[] { + return this._activeExtensions; + } + async getGitService(): Promise { if (!this.gitService) { this.gitService = new GitService(this.targetDir);