Added in proper checks for customer tiers in 429/Quota error messaging (#3863)

Co-authored-by: Ioannis Papapanagiotou <iduckhd@hotmail.com>
This commit is contained in:
Bryan Morgan 2025-07-11 11:25:30 -04:00 committed by GitHub
parent c9e1e6d3bd
commit cdbe2fffd9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 79 additions and 4 deletions

View File

@ -67,6 +67,7 @@ interface MockServerConfig {
getAccessibility: Mock<() => AccessibilitySettings>; getAccessibility: Mock<() => AccessibilitySettings>;
getProjectRoot: Mock<() => string | undefined>; getProjectRoot: Mock<() => string | undefined>;
getAllGeminiMdFilenames: Mock<() => string[]>; getAllGeminiMdFilenames: Mock<() => string[]>;
getUserTier: Mock<() => Promise<string | undefined>>;
} }
// Mock @google/gemini-cli-core and its Config class // Mock @google/gemini-cli-core and its Config class
@ -129,6 +130,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
getAllGeminiMdFilenames: vi.fn(() => ['GEMINI.md']), getAllGeminiMdFilenames: vi.fn(() => ['GEMINI.md']),
setFlashFallbackHandler: vi.fn(), setFlashFallbackHandler: vi.fn(),
getSessionId: vi.fn(() => 'test-session-id'), getSessionId: vi.fn(() => 'test-session-id'),
getUserTier: vi.fn().mockResolvedValue(undefined),
}; };
}); });
return { return {
@ -155,6 +157,8 @@ vi.mock('./hooks/useAuthCommand', () => ({
openAuthDialog: vi.fn(), openAuthDialog: vi.fn(),
handleAuthSelect: vi.fn(), handleAuthSelect: vi.fn(),
handleAuthHighlight: vi.fn(), handleAuthHighlight: vi.fn(),
isAuthenticating: false,
cancelAuthentication: vi.fn(),
})), })),
})); }));

View File

@ -139,6 +139,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
const [showPrivacyNotice, setShowPrivacyNotice] = useState<boolean>(false); const [showPrivacyNotice, setShowPrivacyNotice] = useState<boolean>(false);
const [modelSwitchedFromQuotaError, setModelSwitchedFromQuotaError] = const [modelSwitchedFromQuotaError, setModelSwitchedFromQuotaError] =
useState<boolean>(false); useState<boolean>(false);
const [userTier, setUserTier] = useState<UserTierId | undefined>(undefined);
const openPrivacyNotice = useCallback(() => { const openPrivacyNotice = useCallback(() => {
setShowPrivacyNotice(true); setShowPrivacyNotice(true);
@ -174,6 +175,29 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
} }
}, [settings.merged.selectedAuthType, openAuthDialog, setAuthError]); }, [settings.merged.selectedAuthType, openAuthDialog, setAuthError]);
// Sync user tier from config when authentication changes
useEffect(() => {
const syncUserTier = async () => {
try {
const configUserTier = await config.getUserTier();
if (configUserTier !== userTier) {
setUserTier(configUserTier);
}
} catch (error) {
// Silently fail - this is not critical functionality
// Only log in debug mode to avoid cluttering the console
if (config.getDebugMode()) {
console.debug('Failed to sync user tier:', error);
}
}
};
// Only sync when not currently authenticating
if (!isAuthenticating) {
syncUserTier();
}
}, [config, userTier, isAuthenticating]);
const { const {
isEditorDialogOpen, isEditorDialogOpen,
openEditorDialog, openEditorDialog,
@ -254,9 +278,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
): Promise<boolean> => { ): Promise<boolean> => {
let message: string; let message: string;
// For quota errors, assume FREE tier (safe default) - only show upgrade messaging to free tier users // Use actual user tier if available, otherwise default to FREE tier behavior (safe default)
// TODO: Get actual user tier from config when available
const userTier = undefined; // Defaults to FREE tier behavior
const isPaidTier = const isPaidTier =
userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD; userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;
@ -320,7 +342,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
}; };
config.setFlashFallbackHandler(flashFallbackHandler); config.setFlashFallbackHandler(flashFallbackHandler);
}, [config, addItem]); }, [config, addItem, userTier]);
const { const {
handleSlashCommand, handleSlashCommand,

View File

@ -23,6 +23,7 @@ import {
} from '@google/genai'; } from '@google/genai';
import * as readline from 'readline'; import * as readline from 'readline';
import { ContentGenerator } from '../core/contentGenerator.js'; import { ContentGenerator } from '../core/contentGenerator.js';
import { UserTierId } from './types.js';
import { import {
CaCountTokenResponse, CaCountTokenResponse,
CaGenerateContentResponse, CaGenerateContentResponse,
@ -59,6 +60,8 @@ export const CODE_ASSIST_ENDPOINT = 'https://cloudcode-pa.googleapis.com';
export const CODE_ASSIST_API_VERSION = 'v1internal'; export const CODE_ASSIST_API_VERSION = 'v1internal';
export class CodeAssistServer implements ContentGenerator { export class CodeAssistServer implements ContentGenerator {
private userTier: UserTierId | undefined = undefined;
constructor( constructor(
readonly client: OAuth2Client, readonly client: OAuth2Client,
readonly projectId?: string, readonly projectId?: string,
@ -253,6 +256,40 @@ export class CodeAssistServer implements ContentGenerator {
})(); })();
} }
async getTier(): Promise<UserTierId | undefined> {
if (this.userTier === undefined) {
await this.detectUserTier();
}
return this.userTier;
}
private async detectUserTier(): Promise<void> {
try {
// Reset user tier when detection runs
this.userTier = undefined;
// Only attempt tier detection if we have a project ID
if (this.projectId) {
const loadRes = await this.loadCodeAssist({
cloudaicompanionProject: this.projectId,
metadata: {
ideType: 'IDE_UNSPECIFIED',
platform: 'PLATFORM_UNSPECIFIED',
pluginType: 'GEMINI',
duetProject: this.projectId,
},
});
if (loadRes.currentTier) {
this.userTier = loadRes.currentTier.id;
}
}
} catch (error) {
// Silently fail - this is not critical functionality
// We'll default to FREE tier behavior if tier detection fails
console.debug('User tier detection failed:', error);
}
}
getMethodUrl(method: string): string { getMethodUrl(method: string): string {
const endpoint = process.env.CODE_ASSIST_ENDPOINT ?? CODE_ASSIST_ENDPOINT; const endpoint = process.env.CODE_ASSIST_ENDPOINT ?? CODE_ASSIST_ENDPOINT;
return `${endpoint}/${CODE_ASSIST_API_VERSION}:${method}`; return `${endpoint}/${CODE_ASSIST_API_VERSION}:${method}`;

View File

@ -11,6 +11,7 @@ import {
ContentGeneratorConfig, ContentGeneratorConfig,
createContentGeneratorConfig, createContentGeneratorConfig,
} from '../core/contentGenerator.js'; } from '../core/contentGenerator.js';
import { UserTierId } from '../code_assist/types.js';
import { ToolRegistry } from '../tools/tool-registry.js'; import { ToolRegistry } from '../tools/tool-registry.js';
import { LSTool } from '../tools/ls.js'; import { LSTool } from '../tools/ls.js';
import { ReadFileTool } from '../tools/read-file.js'; import { ReadFileTool } from '../tools/read-file.js';
@ -323,6 +324,14 @@ export class Config {
return this.quotaErrorOccurred; return this.quotaErrorOccurred;
} }
async getUserTier(): Promise<UserTierId | undefined> {
if (!this.geminiClient) {
return undefined;
}
const generator = this.geminiClient.getContentGenerator();
return await generator.getTier?.();
}
getEmbeddingModel(): string { getEmbeddingModel(): string {
return this.embeddingModel; return this.embeddingModel;
} }

View File

@ -17,6 +17,7 @@ import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
import { DEFAULT_GEMINI_MODEL } from '../config/models.js'; import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { getEffectiveModel } from './modelCheck.js'; import { getEffectiveModel } from './modelCheck.js';
import { UserTierId } from '../code_assist/types.js';
/** /**
* Interface abstracting the core functionalities for generating content and counting tokens. * Interface abstracting the core functionalities for generating content and counting tokens.
@ -33,6 +34,8 @@ export interface ContentGenerator {
countTokens(request: CountTokensParameters): Promise<CountTokensResponse>; countTokens(request: CountTokensParameters): Promise<CountTokensResponse>;
embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>; embedContent(request: EmbedContentParameters): Promise<EmbedContentResponse>;
getTier?(): Promise<UserTierId | undefined>;
} }
export enum AuthType { export enum AuthType {