diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx index 22547ae1..32f13329 100644 --- a/packages/cli/src/ui/App.test.tsx +++ b/packages/cli/src/ui/App.test.tsx @@ -67,6 +67,7 @@ interface MockServerConfig { getAccessibility: Mock<() => AccessibilitySettings>; getProjectRoot: Mock<() => string | undefined>; getAllGeminiMdFilenames: Mock<() => string[]>; + getUserTier: Mock<() => Promise>; } // Mock @google/gemini-cli-core and its Config class @@ -129,6 +130,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { getAllGeminiMdFilenames: vi.fn(() => ['GEMINI.md']), setFlashFallbackHandler: vi.fn(), getSessionId: vi.fn(() => 'test-session-id'), + getUserTier: vi.fn().mockResolvedValue(undefined), }; }); return { @@ -155,6 +157,8 @@ vi.mock('./hooks/useAuthCommand', () => ({ openAuthDialog: vi.fn(), handleAuthSelect: vi.fn(), handleAuthHighlight: vi.fn(), + isAuthenticating: false, + cancelAuthentication: vi.fn(), })), })); diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index 2a6bf088..4e2c7242 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -139,6 +139,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { const [showPrivacyNotice, setShowPrivacyNotice] = useState(false); const [modelSwitchedFromQuotaError, setModelSwitchedFromQuotaError] = useState(false); + const [userTier, setUserTier] = useState(undefined); const openPrivacyNotice = useCallback(() => { setShowPrivacyNotice(true); @@ -174,6 +175,29 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { } }, [settings.merged.selectedAuthType, openAuthDialog, setAuthError]); + // Sync user tier from config when authentication changes + useEffect(() => { + const syncUserTier = async () => { + try { + const configUserTier = await config.getUserTier(); + if (configUserTier !== userTier) { + setUserTier(configUserTier); + } + } catch (error) { + // Silently fail - this is not critical functionality + // Only log in debug mode to avoid cluttering the console + if (config.getDebugMode()) { + console.debug('Failed to sync user tier:', error); + } + } + }; + + // Only sync when not currently authenticating + if (!isAuthenticating) { + syncUserTier(); + } + }, [config, userTier, isAuthenticating]); + const { isEditorDialogOpen, openEditorDialog, @@ -254,9 +278,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { ): Promise => { let message: string; - // For quota errors, assume FREE tier (safe default) - only show upgrade messaging to free tier users - // TODO: Get actual user tier from config when available - const userTier = undefined; // Defaults to FREE tier behavior + // Use actual user tier if available, otherwise default to FREE tier behavior (safe default) const isPaidTier = userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD; @@ -320,7 +342,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { }; config.setFlashFallbackHandler(flashFallbackHandler); - }, [config, addItem]); + }, [config, addItem, userTier]); const { handleSlashCommand, diff --git a/packages/core/src/code_assist/server.ts b/packages/core/src/code_assist/server.ts index 01fd2462..fe8661f1 100644 --- a/packages/core/src/code_assist/server.ts +++ b/packages/core/src/code_assist/server.ts @@ -23,6 +23,7 @@ import { } from '@google/genai'; import * as readline from 'readline'; import { ContentGenerator } from '../core/contentGenerator.js'; +import { UserTierId } from './types.js'; import { CaCountTokenResponse, CaGenerateContentResponse, @@ -59,6 +60,8 @@ export const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com'; export const CODE_ASSIST_API_VERSION = 'v1internal'; export class CodeAssistServer implements ContentGenerator { + private userTier: UserTierId | undefined = undefined; + constructor( readonly client: OAuth2Client, readonly projectId?: string, @@ -253,6 +256,40 @@ export class CodeAssistServer implements ContentGenerator { })(); } + async getTier(): Promise { + if (this.userTier === undefined) { + await this.detectUserTier(); + } + return this.userTier; + } + + private async detectUserTier(): Promise { + try { + // Reset user tier when detection runs + this.userTier = undefined; + + // Only attempt tier detection if we have a project ID + if (this.projectId) { + const loadRes = await this.loadCodeAssist({ + cloudaicompanionProject: this.projectId, + metadata: { + ideType: 'IDE_UNSPECIFIED', + platform: 'PLATFORM_UNSPECIFIED', + pluginType: 'GEMINI', + duetProject: this.projectId, + }, + }); + if (loadRes.currentTier) { + this.userTier = loadRes.currentTier.id; + } + } + } catch (error) { + // Silently fail - this is not critical functionality + // We'll default to FREE tier behavior if tier detection fails + console.debug('User tier detection failed:', error); + } + } + getMethodUrl(method: string): string { const endpoint = process.env.CODE_ASSIST_ENDPOINT ?? CODE_ASSIST_ENDPOINT; return `${endpoint}/${CODE_ASSIST_API_VERSION}:${method}`; diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 12767133..dc85c61a 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -11,6 +11,7 @@ import { ContentGeneratorConfig, createContentGeneratorConfig, } from '../core/contentGenerator.js'; +import { UserTierId } from '../code_assist/types.js'; import { ToolRegistry } from '../tools/tool-registry.js'; import { LSTool } from '../tools/ls.js'; import { ReadFileTool } from '../tools/read-file.js'; @@ -323,6 +324,14 @@ export class Config { return this.quotaErrorOccurred; } + async getUserTier(): Promise { + if (!this.geminiClient) { + return undefined; + } + const generator = this.geminiClient.getContentGenerator(); + return await generator.getTier?.(); + } + getEmbeddingModel(): string { return this.embeddingModel; } diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index fee10fad..d3434c23 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -17,6 +17,7 @@ import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js'; import { DEFAULT_GEMINI_MODEL } from '../config/models.js'; import { Config } from '../config/config.js'; import { getEffectiveModel } from './modelCheck.js'; +import { UserTierId } from '../code_assist/types.js'; /** * Interface abstracting the core functionalities for generating content and counting tokens. @@ -33,6 +34,8 @@ export interface ContentGenerator { countTokens(request: CountTokensParameters): Promise; embedContent(request: EmbedContentParameters): Promise; + + getTier?(): Promise; } export enum AuthType {