diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 543801f0..2d33daa3 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -22,7 +22,7 @@ import { } from '@google/gemini-cli-core'; import { Settings } from './settings.js'; -import { Extension, filterActiveExtensions } from './extension.js'; +import { Extension, annotateActiveExtensions } from './extension.js'; import { getCliVersion } from '../utils/version.js'; import { loadSandboxConfig } from './sandboxConfig.js'; @@ -252,11 +252,15 @@ export async function loadCliConfig( process.env.TERM_PROGRAM === 'vscode' && !process.env.SANDBOX; - const activeExtensions = filterActiveExtensions( + const allExtensions = annotateActiveExtensions( extensions, argv.extensions || [], ); + const activeExtensions = extensions.filter( + (_, i) => allExtensions[i].isActive, + ); + // Set the context filename in the server's memoryTool module BEFORE loading memory // TODO(b/343434939): This is a bit of a hack. The contextFileName should ideally be passed // directly to the Config constructor in core, and have core handle setGeminiMdFilename. @@ -283,6 +287,7 @@ export async function loadCliConfig( let mcpServers = mergeMcpServers(settings, activeExtensions); const excludeTools = mergeExcludeTools(settings, activeExtensions); + const blockedMcpServers: Array<{ name: string; extensionName: string }> = []; if (!argv.allowedMcpServerNames) { if (settings.allowMCPServers) { @@ -308,9 +313,24 @@ export async function loadCliConfig( const allowedNames = new Set(argv.allowedMcpServerNames.filter(Boolean)); if (allowedNames.size > 0) { mcpServers = Object.fromEntries( - Object.entries(mcpServers).filter(([key]) => allowedNames.has(key)), + Object.entries(mcpServers).filter(([key, server]) => { + const isAllowed = allowedNames.has(key); + if (!isAllowed) { + blockedMcpServers.push({ + name: key, + extensionName: server.extensionName || '', + }); + } + return isAllowed; + }), ); } else { + blockedMcpServers.push( + ...Object.entries(mcpServers).map(([key, server]) => ({ + name: key, + extensionName: server.extensionName || '', + })), + ); mcpServers = {}; } } @@ -403,10 +423,8 @@ export async function loadCliConfig( maxSessionTurns: settings.maxSessionTurns ?? -1, experimentalAcp: argv.experimentalAcp || false, listExtensions: argv.listExtensions || false, - activeExtensions: activeExtensions.map((e) => ({ - name: e.config.name, - version: e.config.version, - })), + extensions: allExtensions, + blockedMcpServers, noBrowser: !!process.env.NO_BROWSER, summarizeToolOutput: settings.summarizeToolOutput, ideMode, @@ -424,7 +442,10 @@ function mergeMcpServers(settings: Settings, extensions: Extension[]) { ); return; } - mcpServers[key] = server; + mcpServers[key] = { + ...server, + extensionName: extension.config.name, + }; }, ); } diff --git a/packages/cli/src/config/extension.test.ts b/packages/cli/src/config/extension.test.ts index ab68e3f5..6b2a3f83 100644 --- a/packages/cli/src/config/extension.test.ts +++ b/packages/cli/src/config/extension.test.ts @@ -11,7 +11,7 @@ import * as path from 'path'; import { EXTENSIONS_CONFIG_FILENAME, EXTENSIONS_DIRECTORY_NAME, - filterActiveExtensions, + annotateActiveExtensions, loadExtensions, } from './extension.js'; @@ -86,42 +86,52 @@ describe('loadExtensions', () => { }); }); -describe('filterActiveExtensions', () => { +describe('annotateActiveExtensions', () => { const extensions = [ { config: { name: 'ext1', version: '1.0.0' }, contextFiles: [] }, { config: { name: 'ext2', version: '1.0.0' }, contextFiles: [] }, { config: { name: 'ext3', version: '1.0.0' }, contextFiles: [] }, ]; - it('should return all extensions if no enabled extensions are provided', () => { - const activeExtensions = filterActiveExtensions(extensions, []); + it('should mark all extensions as active if no enabled extensions are provided', () => { + const activeExtensions = annotateActiveExtensions(extensions, []); expect(activeExtensions).toHaveLength(3); + expect(activeExtensions.every((e) => e.isActive)).toBe(true); }); - it('should return only the enabled extensions', () => { - const activeExtensions = filterActiveExtensions(extensions, [ + it('should mark only the enabled extensions as active', () => { + const activeExtensions = annotateActiveExtensions(extensions, [ 'ext1', 'ext3', ]); - expect(activeExtensions).toHaveLength(2); - expect(activeExtensions.some((e) => e.config.name === 'ext1')).toBe(true); - expect(activeExtensions.some((e) => e.config.name === 'ext3')).toBe(true); + expect(activeExtensions).toHaveLength(3); + expect(activeExtensions.find((e) => e.name === 'ext1')?.isActive).toBe( + true, + ); + expect(activeExtensions.find((e) => e.name === 'ext2')?.isActive).toBe( + false, + ); + expect(activeExtensions.find((e) => e.name === 'ext3')?.isActive).toBe( + true, + ); }); - it('should return no extensions when "none" is provided', () => { - const activeExtensions = filterActiveExtensions(extensions, ['none']); - expect(activeExtensions).toHaveLength(0); + it('should mark all extensions as inactive when "none" is provided', () => { + const activeExtensions = annotateActiveExtensions(extensions, ['none']); + expect(activeExtensions).toHaveLength(3); + expect(activeExtensions.every((e) => !e.isActive)).toBe(true); }); it('should handle case-insensitivity', () => { - const activeExtensions = filterActiveExtensions(extensions, ['EXT1']); - expect(activeExtensions).toHaveLength(1); - expect(activeExtensions[0].config.name).toBe('ext1'); + const activeExtensions = annotateActiveExtensions(extensions, ['EXT1']); + expect(activeExtensions.find((e) => e.name === 'ext1')?.isActive).toBe( + true, + ); }); it('should log an error for unknown extensions', () => { - const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {}); - filterActiveExtensions(extensions, ['ext4']); + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + annotateActiveExtensions(extensions, ['ext4']); expect(consoleSpy).toHaveBeenCalledWith('Extension not found: ext4'); consoleSpy.mockRestore(); }); diff --git a/packages/cli/src/config/extension.ts b/packages/cli/src/config/extension.ts index aa540cfd..adefec29 100644 --- a/packages/cli/src/config/extension.ts +++ b/packages/cli/src/config/extension.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { MCPServerConfig } from '@google/gemini-cli-core'; +import { MCPServerConfig, GeminiCLIExtension } from '@google/gemini-cli-core'; import * as fs from 'fs'; import * as path from 'path'; import * as os from 'os'; @@ -34,9 +34,6 @@ export function loadExtensions(workspaceDir: string): Extension[] { const uniqueExtensions = new Map(); for (const extension of allExtensions) { if (!uniqueExtensions.has(extension.config.name)) { - console.log( - `Loading extension: ${extension.config.name} (version: ${extension.config.version})`, - ); uniqueExtensions.set(extension.config.name, extension); } } @@ -113,12 +110,18 @@ function getContextFileNames(config: ExtensionConfig): string[] { return config.contextFileName; } -export function filterActiveExtensions( +export function annotateActiveExtensions( extensions: Extension[], enabledExtensionNames: string[], -): Extension[] { +): GeminiCLIExtension[] { + const annotatedExtensions: GeminiCLIExtension[] = []; + if (enabledExtensionNames.length === 0) { - return extensions; + return extensions.map((extension) => ({ + name: extension.config.name, + version: extension.config.version, + isActive: true, + })); } const lowerCaseEnabledExtensions = new Set( @@ -129,31 +132,33 @@ export function filterActiveExtensions( lowerCaseEnabledExtensions.size === 1 && lowerCaseEnabledExtensions.has('none') ) { - if (extensions.length > 0) { - console.log('All extensions are disabled.'); - } - return []; + return extensions.map((extension) => ({ + name: extension.config.name, + version: extension.config.version, + isActive: false, + })); } - const activeExtensions: Extension[] = []; const notFoundNames = new Set(lowerCaseEnabledExtensions); for (const extension of extensions) { const lowerCaseName = extension.config.name.toLowerCase(); - if (lowerCaseEnabledExtensions.has(lowerCaseName)) { - console.log( - `Activated extension: ${extension.config.name} (version: ${extension.config.version})`, - ); - activeExtensions.push(extension); + const isActive = lowerCaseEnabledExtensions.has(lowerCaseName); + + if (isActive) { notFoundNames.delete(lowerCaseName); - } else { - console.log(`Disabled extension: ${extension.config.name}`); } + + annotatedExtensions.push({ + name: extension.config.name, + version: extension.config.version, + isActive, + }); } for (const requestedName of notFoundNames) { - console.log(`Extension not found: ${requestedName}`); + console.error(`Extension not found: ${requestedName}`); } - return activeExtensions; + return annotatedExtensions; } diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx index 0c18b042..ed4418e9 100644 --- a/packages/cli/src/ui/App.test.tsx +++ b/packages/cli/src/ui/App.test.tsx @@ -58,6 +58,12 @@ interface MockServerConfig { getToolCallCommand: Mock<() => string | undefined>; getMcpServerCommand: Mock<() => string | undefined>; getMcpServers: Mock<() => Record | undefined>; + getExtensions: Mock< + () => Array<{ name: string; version: string; isActive: boolean }> + >; + getBlockedMcpServers: Mock< + () => Array<{ name: string; extensionName: string }> + >; getUserAgent: Mock<() => string>; getUserMemory: Mock<() => string>; setUserMemory: Mock<(newUserMemory: string) => void>; @@ -118,6 +124,8 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { getToolCallCommand: vi.fn(() => opts.toolCallCommand), getMcpServerCommand: vi.fn(() => opts.mcpServerCommand), getMcpServers: vi.fn(() => opts.mcpServers), + getExtensions: vi.fn(() => []), + getBlockedMcpServers: vi.fn(() => []), getUserAgent: vi.fn(() => opts.userAgent || 'test-agent'), getUserMemory: vi.fn(() => opts.userMemory || ''), setUserMemory: vi.fn(), diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index 5d8ab39d..782e2ff8 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -886,6 +886,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { geminiMdFileCount={geminiMdFileCount} contextFileNames={contextFileNames} mcpServers={config.getMcpServers()} + blockedMcpServers={config.getBlockedMcpServers()} showToolDescriptions={showToolDescriptions} /> )} diff --git a/packages/cli/src/ui/commands/extensionsCommand.test.ts b/packages/cli/src/ui/commands/extensionsCommand.test.ts index a989d9b0..0a69e01c 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.test.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.test.ts @@ -17,7 +17,7 @@ describe('extensionsCommand', () => { mockContext = createMockCommandContext({ services: { config: { - getActiveExtensions: () => [], + getExtensions: () => [], }, }, }); @@ -36,13 +36,14 @@ describe('extensionsCommand', () => { it('should list active extensions when they are found', async () => { const mockExtensions = [ - { name: 'ext-one', version: '1.0.0' }, - { name: 'ext-two', version: '2.1.0' }, + { name: 'ext-one', version: '1.0.0', isActive: true }, + { name: 'ext-two', version: '2.1.0', isActive: true }, + { name: 'ext-three', version: '3.0.0', isActive: false }, ]; mockContext = createMockCommandContext({ services: { config: { - getActiveExtensions: () => mockExtensions, + getExtensions: () => mockExtensions, }, }, }); diff --git a/packages/cli/src/ui/commands/extensionsCommand.ts b/packages/cli/src/ui/commands/extensionsCommand.ts index 87d23afb..09241e5f 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.ts @@ -11,7 +11,9 @@ export const extensionsCommand: SlashCommand = { name: 'extensions', description: 'list active extensions', action: async (context: CommandContext): Promise => { - const activeExtensions = context.services.config?.getActiveExtensions(); + const activeExtensions = context.services.config + ?.getExtensions() + .filter((ext) => ext.isActive); if (!activeExtensions || activeExtensions.length === 0) { context.ui.addItem( { diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index 0a8d8306..f23cf3ab 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -63,6 +63,7 @@ describe('mcpCommand', () => { let mockConfig: { getToolRegistry: ReturnType; getMcpServers: ReturnType; + getBlockedMcpServers: ReturnType; }; beforeEach(() => { @@ -83,6 +84,7 @@ describe('mcpCommand', () => { getAllTools: vi.fn().mockReturnValue([]), }), getMcpServers: vi.fn().mockReturnValue({}), + getBlockedMcpServers: vi.fn().mockReturnValue([]), }; mockContext = createMockCommandContext({ @@ -419,6 +421,61 @@ describe('mcpCommand', () => { ); } }); + + it('should display the extension name for servers from extensions', async () => { + const mockMcpServers = { + server1: { command: 'cmd1', extensionName: 'my-extension' }, + }; + mockConfig.getMcpServers = vi.fn().mockReturnValue(mockMcpServers); + + const result = await mcpCommand.action!(mockContext, ''); + + expect(isMessageAction(result)).toBe(true); + if (isMessageAction(result)) { + const message = result.content; + expect(message).toContain('server1 (from my-extension)'); + } + }); + + it('should display blocked MCP servers', async () => { + mockConfig.getMcpServers = vi.fn().mockReturnValue({}); + const blockedServers = [ + { name: 'blocked-server', extensionName: 'my-extension' }, + ]; + mockConfig.getBlockedMcpServers = vi.fn().mockReturnValue(blockedServers); + + const result = await mcpCommand.action!(mockContext, ''); + + expect(isMessageAction(result)).toBe(true); + if (isMessageAction(result)) { + const message = result.content; + expect(message).toContain( + '🔴 \u001b[1mblocked-server (from my-extension)\u001b[0m - Blocked', + ); + } + }); + + it('should display both active and blocked servers correctly', async () => { + const mockMcpServers = { + server1: { command: 'cmd1', extensionName: 'my-extension' }, + }; + mockConfig.getMcpServers = vi.fn().mockReturnValue(mockMcpServers); + const blockedServers = [ + { name: 'blocked-server', extensionName: 'another-extension' }, + ]; + mockConfig.getBlockedMcpServers = vi.fn().mockReturnValue(blockedServers); + + const result = await mcpCommand.action!(mockContext, ''); + + expect(isMessageAction(result)).toBe(true); + if (isMessageAction(result)) { + const message = result.content; + expect(message).toContain('server1 (from my-extension)'); + expect(message).toContain( + '🔴 \u001b[1mblocked-server (from another-extension)\u001b[0m - Blocked', + ); + } + }); }); describe('schema functionality', () => { diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index 5ff77c4b..891227b0 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -49,8 +49,9 @@ const getMcpStatus = async ( const mcpServers = config.getMcpServers() || {}; const serverNames = Object.keys(mcpServers); + const blockedMcpServers = config.getBlockedMcpServers() || []; - if (serverNames.length === 0) { + if (serverNames.length === 0 && blockedMcpServers.length === 0) { const docsUrl = 'https://goo.gle/gemini-cli-docs-mcp'; if (process.env.SANDBOX && process.env.SANDBOX !== 'sandbox-exec') { return { @@ -118,9 +119,13 @@ const getMcpStatus = async ( // Get server description if available const server = mcpServers[serverName]; + let serverDisplayName = serverName; + if (server.extensionName) { + serverDisplayName += ` (from ${server.extensionName})`; + } // Format server header with bold formatting and status - message += `${statusIndicator} \u001b[1m${serverName}\u001b[0m - ${statusText}`; + message += `${statusIndicator} \u001b[1m${serverDisplayName}\u001b[0m - ${statusText}`; // Add tool count with conditional messaging if (status === MCPServerStatus.CONNECTED) { @@ -192,6 +197,14 @@ const getMcpStatus = async ( message += '\n'; } + for (const server of blockedMcpServers) { + let serverDisplayName = server.name; + if (server.extensionName) { + serverDisplayName += ` (from ${server.extensionName})`; + } + message += `🔴 \u001b[1m${serverDisplayName}\u001b[0m - Blocked\n\n`; + } + // Add helpful tips when no arguments are provided if (showTips) { message += '\n'; diff --git a/packages/cli/src/ui/components/ContextSummaryDisplay.tsx b/packages/cli/src/ui/components/ContextSummaryDisplay.tsx index 00a95e19..314e8ebd 100644 --- a/packages/cli/src/ui/components/ContextSummaryDisplay.tsx +++ b/packages/cli/src/ui/components/ContextSummaryDisplay.tsx @@ -13,6 +13,7 @@ interface ContextSummaryDisplayProps { geminiMdFileCount: number; contextFileNames: string[]; mcpServers?: Record; + blockedMcpServers?: Array<{ name: string; extensionName: string }>; showToolDescriptions?: boolean; } @@ -20,11 +21,17 @@ export const ContextSummaryDisplay: React.FC = ({ geminiMdFileCount, contextFileNames, mcpServers, + blockedMcpServers, showToolDescriptions, }) => { const mcpServerCount = Object.keys(mcpServers || {}).length; + const blockedMcpServerCount = blockedMcpServers?.length || 0; - if (geminiMdFileCount === 0 && mcpServerCount === 0) { + if ( + geminiMdFileCount === 0 && + mcpServerCount === 0 && + blockedMcpServerCount === 0 + ) { return ; // Render an empty space to reserve height } @@ -39,10 +46,27 @@ export const ContextSummaryDisplay: React.FC = ({ }`; })(); - const mcpText = - mcpServerCount > 0 - ? `${mcpServerCount} MCP server${mcpServerCount > 1 ? 's' : ''}` - : ''; + const mcpText = (() => { + if (mcpServerCount === 0 && blockedMcpServerCount === 0) { + return ''; + } + + const parts = []; + if (mcpServerCount > 0) { + parts.push( + `${mcpServerCount} MCP server${mcpServerCount > 1 ? 's' : ''}`, + ); + } + + if (blockedMcpServerCount > 0) { + let blockedText = `${blockedMcpServerCount} blocked`; + if (mcpServerCount === 0) { + blockedText += ` MCP server${blockedMcpServerCount > 1 ? 's' : ''}`; + } + parts.push(blockedText); + } + return parts.join(', '); + })(); let summaryText = 'Using '; if (geminiMdText) { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 9d47fb08..f81b3e32 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -71,9 +71,10 @@ export interface TelemetrySettings { logPrompts?: boolean; } -export interface ActiveExtension { +export interface GeminiCLIExtension { name: string; version: string; + isActive: boolean; } export class MCPServerConfig { @@ -97,6 +98,7 @@ export class MCPServerConfig { readonly description?: string, readonly includeTools?: string[], readonly excludeTools?: string[], + readonly extensionName?: string, ) {} } @@ -147,7 +149,8 @@ export interface ConfigParameters { maxSessionTurns?: number; experimentalAcp?: boolean; listExtensions?: boolean; - activeExtensions?: ActiveExtension[]; + extensions?: GeminiCLIExtension[]; + blockedMcpServers?: Array<{ name: string; extensionName: string }>; noBrowser?: boolean; summarizeToolOutput?: Record; ideMode?: boolean; @@ -194,7 +197,11 @@ export class Config { private modelSwitchedDuringSession: boolean = false; private readonly maxSessionTurns: number; private readonly listExtensions: boolean; - private readonly _activeExtensions: ActiveExtension[]; + private readonly _extensions: GeminiCLIExtension[]; + private readonly _blockedMcpServers: Array<{ + name: string; + extensionName: string; + }>; flashFallbackHandler?: FlashFallbackHandler; private quotaErrorOccurred: boolean = false; private readonly summarizeToolOutput: @@ -245,7 +252,8 @@ export class Config { this.maxSessionTurns = params.maxSessionTurns ?? -1; this.experimentalAcp = params.experimentalAcp ?? false; this.listExtensions = params.listExtensions ?? false; - this._activeExtensions = params.activeExtensions ?? []; + this._extensions = params.extensions ?? []; + this._blockedMcpServers = params.blockedMcpServers ?? []; this.noBrowser = params.noBrowser ?? false; this.summarizeToolOutput = params.summarizeToolOutput; this.ideMode = params.ideMode ?? false; @@ -505,8 +513,12 @@ export class Config { return this.listExtensions; } - getActiveExtensions(): ActiveExtension[] { - return this._activeExtensions; + getExtensions(): GeminiCLIExtension[] { + return this._extensions; + } + + getBlockedMcpServers(): Array<{ name: string; extensionName: string }> { + return this._blockedMcpServers; } getNoBrowser(): boolean {