[JUNE 25] Permanent failover to Flash model for OAuth users after persistent 429 errors (#1376)

Co-authored-by: Scott Densmore <scottdensmore@mac.com>
This commit is contained in:
Bryan Morgan 2025-06-24 18:48:55 -04:00 committed by GitHub
parent 4bf18da2b0
commit e356949d3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 837 additions and 12 deletions

View File

@ -125,6 +125,7 @@ vi.mock('@gemini-cli/core', async (importOriginal) => {
getGeminiClient: vi.fn(() => ({})), getGeminiClient: vi.fn(() => ({})),
getCheckpointingEnabled: vi.fn(() => opts.checkpointing ?? true), getCheckpointingEnabled: vi.fn(() => opts.checkpointing ?? true),
getAllGeminiMdFilenames: vi.fn(() => ['GEMINI.md']), getAllGeminiMdFilenames: vi.fn(() => ['GEMINI.md']),
setFlashFallbackHandler: vi.fn(),
}; };
}); });
return { return {

View File

@ -115,6 +115,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
const [editorError, setEditorError] = useState<string | null>(null); const [editorError, setEditorError] = useState<string | null>(null);
const [footerHeight, setFooterHeight] = useState<number>(0); const [footerHeight, setFooterHeight] = useState<number>(0);
const [corgiMode, setCorgiMode] = useState(false); const [corgiMode, setCorgiMode] = useState(false);
const [currentModel, setCurrentModel] = useState(config.getModel());
const [shellModeActive, setShellModeActive] = useState(false); const [shellModeActive, setShellModeActive] = useState(false);
const [showErrorDetails, setShowErrorDetails] = useState<boolean>(false); const [showErrorDetails, setShowErrorDetails] = useState<boolean>(false);
const [showToolDescriptions, setShowToolDescriptions] = const [showToolDescriptions, setShowToolDescriptions] =
@ -214,6 +215,42 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
} }
}, [config, addItem]); }, [config, addItem]);
// Watch for model changes (e.g., from Flash fallback)
useEffect(() => {
const checkModelChange = () => {
const configModel = config.getModel();
if (configModel !== currentModel) {
setCurrentModel(configModel);
}
};
// Check immediately and then periodically
checkModelChange();
const interval = setInterval(checkModelChange, 1000); // Check every second
return () => clearInterval(interval);
}, [config, currentModel]);
// Set up Flash fallback handler
useEffect(() => {
const flashFallbackHandler = async (
currentModel: string,
fallbackModel: string,
): Promise<boolean> => {
// Add message to UI history
addItem(
{
type: MessageType.INFO,
text: `⚡ Rate limiting detected. Automatically switching from ${currentModel} to ${fallbackModel} for faster responses for the remainder of this session.`,
},
Date.now(),
);
return true; // Always accept the fallback
};
config.setFlashFallbackHandler(flashFallbackHandler);
}, [config, addItem]);
const { const {
handleSlashCommand, handleSlashCommand,
slashCommands, slashCommands,
@ -787,7 +824,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
</Box> </Box>
)} )}
<Footer <Footer
model={config.getModel()} model={currentModel}
targetDir={config.getTargetDir()} targetDir={config.getTargetDir()}
debugMode={config.getDebugMode()} debugMode={config.getDebugMode()}
branchName={branchName} branchName={branchName}

View File

@ -171,7 +171,7 @@ export const useSlashCommandProcessor = (
[addMessage], [addMessage],
); );
const savedChatTags = async function () { const savedChatTags = useCallback(async () => {
const geminiDir = config?.getProjectTempDir(); const geminiDir = config?.getProjectTempDir();
if (!geminiDir) { if (!geminiDir) {
return []; return [];
@ -186,7 +186,7 @@ export const useSlashCommandProcessor = (
} catch (_err) { } catch (_err) {
return []; return [];
} }
}; }, [config]);
const slashCommands: SlashCommand[] = useMemo(() => { const slashCommands: SlashCommand[] = useMemo(() => {
const commands: SlashCommand[] = [ const commands: SlashCommand[] = [
@ -992,6 +992,7 @@ Add any other context about the problem here.
addMemoryAction, addMemoryAction,
addMessage, addMessage,
toggleCorgiMode, toggleCorgiMode,
savedChatTags,
config, config,
showToolDescriptions, showToolDescriptions,
session, session,

View File

@ -35,7 +35,10 @@ import {
TelemetryTarget, TelemetryTarget,
StartSessionEvent, StartSessionEvent,
} from '../telemetry/index.js'; } from '../telemetry/index.js';
import { DEFAULT_GEMINI_EMBEDDING_MODEL } from './models.js'; import {
DEFAULT_GEMINI_EMBEDDING_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
} from './models.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
export enum ApprovalMode { export enum ApprovalMode {
@ -85,6 +88,11 @@ export interface SandboxConfig {
image: string; image: string;
} }
export type FlashFallbackHandler = (
currentModel: string,
fallbackModel: string,
) => Promise<boolean>;
export interface ConfigParameters { export interface ConfigParameters {
sessionId: string; sessionId: string;
embeddingModel?: string; embeddingModel?: string;
@ -156,6 +164,8 @@ export class Config {
private readonly bugCommand: BugCommandSettings | undefined; private readonly bugCommand: BugCommandSettings | undefined;
private readonly model: string; private readonly model: string;
private readonly extensionContextFilePaths: string[]; private readonly extensionContextFilePaths: string[];
private modelSwitchedDuringSession: boolean = false;
flashFallbackHandler?: FlashFallbackHandler;
constructor(params: ConfigParameters) { constructor(params: ConfigParameters) {
this.sessionId = params.sessionId; this.sessionId = params.sessionId;
@ -216,9 +226,24 @@ export class Config {
} }
async refreshAuth(authMethod: AuthType) { async refreshAuth(authMethod: AuthType) {
// Check if this is actually a switch to a different auth method
const previousAuthType = this.contentGeneratorConfig?.authType;
const _isAuthMethodSwitch =
previousAuthType && previousAuthType !== authMethod;
// Always use the original default model when switching auth methods
// This ensures users don't stay on Flash after switching between auth types
// and allows API key users to get proper fallback behavior from getEffectiveModel
const modelToUse = this.model; // Use the original default model
// Temporarily clear contentGeneratorConfig to prevent getModel() from returning
// the previous session's model (which might be Flash)
this.contentGeneratorConfig = undefined!;
const contentConfig = await createContentGeneratorConfig( const contentConfig = await createContentGeneratorConfig(
this.getModel(), modelToUse,
authMethod, authMethod,
this,
); );
const gc = new GeminiClient(this); const gc = new GeminiClient(this);
@ -226,6 +251,11 @@ export class Config {
this.toolRegistry = await createToolRegistry(this); this.toolRegistry = await createToolRegistry(this);
await gc.initialize(contentConfig); await gc.initialize(contentConfig);
this.contentGeneratorConfig = contentConfig; this.contentGeneratorConfig = contentConfig;
// Reset the session flag since we're explicitly changing auth and using default model
this.modelSwitchedDuringSession = false;
// Note: In the future, we may want to reset any cached state when switching auth methods
} }
getSessionId(): string { getSessionId(): string {
@ -240,6 +270,28 @@ export class Config {
return this.contentGeneratorConfig?.model || this.model; return this.contentGeneratorConfig?.model || this.model;
} }
setModel(newModel: string): void {
if (this.contentGeneratorConfig) {
this.contentGeneratorConfig.model = newModel;
this.modelSwitchedDuringSession = true;
}
}
isModelSwitchedDuringSession(): boolean {
return this.modelSwitchedDuringSession;
}
resetModelToDefault(): void {
if (this.contentGeneratorConfig) {
this.contentGeneratorConfig.model = this.model; // Reset to the original default model
this.modelSwitchedDuringSession = false;
}
}
setFlashFallbackHandler(handler: FlashFallbackHandler): void {
this.flashFallbackHandler = handler;
}
getEmbeddingModel(): string { getEmbeddingModel(): string {
return this.embeddingModel; return this.embeddingModel;
} }
@ -445,3 +497,6 @@ export function createToolRegistry(config: Config): Promise<ToolRegistry> {
return registry; return registry;
})(); })();
} }
// Export model constants for use in CLI
export { DEFAULT_GEMINI_FLASH_MODEL };

View File

@ -0,0 +1,139 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach } from 'vitest';
import { Config } from './config.js';
import { DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_FLASH_MODEL } from './models.js';
describe('Flash Model Fallback Configuration', () => {
let config: Config;
beforeEach(() => {
config = new Config({
sessionId: 'test-session',
targetDir: '/test',
debugMode: false,
cwd: '/test',
model: DEFAULT_GEMINI_MODEL,
});
// Initialize contentGeneratorConfig for testing
(
config as unknown as { contentGeneratorConfig: unknown }
).contentGeneratorConfig = {
model: DEFAULT_GEMINI_MODEL,
authType: 'oauth-personal',
};
});
describe('setModel', () => {
it('should update the model and mark as switched during session', () => {
expect(config.getModel()).toBe(DEFAULT_GEMINI_MODEL);
expect(config.isModelSwitchedDuringSession()).toBe(false);
config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
expect(config.isModelSwitchedDuringSession()).toBe(true);
});
it('should handle multiple model switches during session', () => {
config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
expect(config.isModelSwitchedDuringSession()).toBe(true);
config.setModel('gemini-1.5-pro');
expect(config.getModel()).toBe('gemini-1.5-pro');
expect(config.isModelSwitchedDuringSession()).toBe(true);
});
it('should only mark as switched if contentGeneratorConfig exists', () => {
// Create config without initializing contentGeneratorConfig
const newConfig = new Config({
sessionId: 'test-session-2',
targetDir: '/test',
debugMode: false,
cwd: '/test',
model: DEFAULT_GEMINI_MODEL,
});
// Should not crash when contentGeneratorConfig is undefined
newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL);
expect(newConfig.isModelSwitchedDuringSession()).toBe(false);
});
});
describe('getModel', () => {
it('should return contentGeneratorConfig model if available', () => {
// Simulate initialized content generator config
config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
});
it('should fallback to initial model if contentGeneratorConfig is not available', () => {
// Test with fresh config where contentGeneratorConfig might not be set
const newConfig = new Config({
sessionId: 'test-session-2',
targetDir: '/test',
debugMode: false,
cwd: '/test',
model: 'custom-model',
});
expect(newConfig.getModel()).toBe('custom-model');
});
});
describe('isModelSwitchedDuringSession', () => {
it('should start as false for new session', () => {
expect(config.isModelSwitchedDuringSession()).toBe(false);
});
it('should remain false if no model switch occurs', () => {
// Perform other operations that don't involve model switching
expect(config.isModelSwitchedDuringSession()).toBe(false);
});
it('should persist switched state throughout session', () => {
config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
expect(config.isModelSwitchedDuringSession()).toBe(true);
// Should remain true even after getting model
config.getModel();
expect(config.isModelSwitchedDuringSession()).toBe(true);
});
});
describe('resetModelToDefault', () => {
it('should reset model to default and clear session switch flag', () => {
// Switch to Flash first
config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
expect(config.isModelSwitchedDuringSession()).toBe(true);
// Reset to default
config.resetModelToDefault();
// Should be back to default with flag cleared
expect(config.getModel()).toBe(DEFAULT_GEMINI_MODEL);
expect(config.isModelSwitchedDuringSession()).toBe(false);
});
it('should handle case where contentGeneratorConfig is not initialized', () => {
// Create config without initializing contentGeneratorConfig
const newConfig = new Config({
sessionId: 'test-session-2',
targetDir: '/test',
debugMode: false,
cwd: '/test',
model: DEFAULT_GEMINI_MODEL,
});
// Should not crash when contentGeneratorConfig is undefined
expect(() => newConfig.resetModelToDefault()).not.toThrow();
expect(newConfig.isModelSwitchedDuringSession()).toBe(false);
});
});
});

View File

@ -20,6 +20,7 @@ import { Turn } from './turn.js';
import { getCoreSystemPrompt } from './prompts.js'; import { getCoreSystemPrompt } from './prompts.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
import { setSimulate429 } from '../utils/testUtils.js';
// --- Mocks --- // --- Mocks ---
const mockChatCreateFn = vi.fn(); const mockChatCreateFn = vi.fn();
@ -68,6 +69,9 @@ describe('Gemini Client (client.ts)', () => {
beforeEach(async () => { beforeEach(async () => {
vi.resetAllMocks(); vi.resetAllMocks();
// Disable 429 simulation for tests
setSimulate429(false);
// Set up the mock for GoogleGenAI constructor and its methods // Set up the mock for GoogleGenAI constructor and its methods
const MockedGoogleGenAI = vi.mocked(GoogleGenAI); const MockedGoogleGenAI = vi.mocked(GoogleGenAI);
MockedGoogleGenAI.mockImplementation(() => { MockedGoogleGenAI.mockImplementation(() => {

View File

@ -38,6 +38,7 @@ import {
} from './contentGenerator.js'; } from './contentGenerator.js';
import { ProxyAgent, setGlobalDispatcher } from 'undici'; import { ProxyAgent, setGlobalDispatcher } from 'undici';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { AuthType } from './contentGenerator.js';
function isThinkingSupported(model: string) { function isThinkingSupported(model: string) {
if (model.startsWith('gemini-2.5')) return true; if (model.startsWith('gemini-2.5')) return true;
@ -276,7 +277,11 @@ export class GeminiClient {
contents, contents,
}); });
const result = await retryWithBackoff(apiCall); const result = await retryWithBackoff(apiCall, {
onPersistent429: async (authType?: string) =>
await this.handleFlashFallback(authType),
authType: this.config.getContentGeneratorConfig()?.authType,
});
const text = getResponseText(result); const text = getResponseText(result);
if (!text) { if (!text) {
@ -360,7 +365,11 @@ export class GeminiClient {
contents, contents,
}); });
const result = await retryWithBackoff(apiCall); const result = await retryWithBackoff(apiCall, {
onPersistent429: async (authType?: string) =>
await this.handleFlashFallback(authType),
authType: this.config.getContentGeneratorConfig()?.authType,
});
return result; return result;
} catch (error: unknown) { } catch (error: unknown) {
if (abortSignal.aborted) { if (abortSignal.aborted) {
@ -489,4 +498,43 @@ export class GeminiClient {
} }
: null; : null;
} }
/**
* Handles fallback to Flash model when persistent 429 errors occur for OAuth users.
* Uses a fallback handler if provided by the config, otherwise returns null.
*/
private async handleFlashFallback(authType?: string): Promise<string | null> {
// Only handle fallback for OAuth users
if (
authType !== AuthType.LOGIN_WITH_GOOGLE_PERSONAL &&
authType !== AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE
) {
return null;
}
const currentModel = this.model;
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
// Don't fallback if already using Flash model
if (currentModel === fallbackModel) {
return null;
}
// Check if config has a fallback handler (set by CLI package)
const fallbackHandler = this.config.flashFallbackHandler;
if (typeof fallbackHandler === 'function') {
try {
const accepted = await fallbackHandler(currentModel, fallbackModel);
if (accepted) {
this.config.setModel(fallbackModel);
this.model = fallbackModel;
return fallbackModel;
}
} catch (error) {
console.warn('Flash fallback handler failed:', error);
}
}
return null;
}
} }

View File

@ -51,14 +51,18 @@ export type ContentGeneratorConfig = {
export async function createContentGeneratorConfig( export async function createContentGeneratorConfig(
model: string | undefined, model: string | undefined,
authType: AuthType | undefined, authType: AuthType | undefined,
config?: { getModel?: () => string },
): Promise<ContentGeneratorConfig> { ): Promise<ContentGeneratorConfig> {
const geminiApiKey = process.env.GEMINI_API_KEY; const geminiApiKey = process.env.GEMINI_API_KEY;
const googleApiKey = process.env.GOOGLE_API_KEY; const googleApiKey = process.env.GOOGLE_API_KEY;
const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT; const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT;
const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION; const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION;
// Use runtime model from config if available, otherwise fallback to parameter or default
const effectiveModel = config?.getModel?.() || model || DEFAULT_GEMINI_MODEL;
const contentGeneratorConfig: ContentGeneratorConfig = { const contentGeneratorConfig: ContentGeneratorConfig = {
model: model || DEFAULT_GEMINI_MODEL, model: effectiveModel,
authType, authType,
}; };

View File

@ -14,6 +14,7 @@ import {
} from '@google/genai'; } from '@google/genai';
import { GeminiChat } from './geminiChat.js'; import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { setSimulate429 } from '../utils/testUtils.js';
// Mocks // Mocks
const mockModelsModule = { const mockModelsModule = {
@ -29,6 +30,12 @@ const mockConfig = {
getTelemetryLogPromptsEnabled: () => true, getTelemetryLogPromptsEnabled: () => true,
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
getDebugMode: () => false, getDebugMode: () => false,
getContentGeneratorConfig: () => ({
authType: 'oauth-personal',
model: 'test-model',
}),
setModel: vi.fn(),
flashFallbackHandler: undefined,
} as unknown as Config; } as unknown as Config;
describe('GeminiChat', () => { describe('GeminiChat', () => {
@ -38,6 +45,8 @@ describe('GeminiChat', () => {
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks(); vi.clearAllMocks();
// Disable 429 simulation for tests
setSimulate429(false);
// Reset history for each test by creating a new instance // Reset history for each test by creating a new instance
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []); chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []);
}); });

View File

@ -18,7 +18,7 @@ import {
} from '@google/genai'; } from '@google/genai';
import { retryWithBackoff } from '../utils/retry.js'; import { retryWithBackoff } from '../utils/retry.js';
import { isFunctionResponse } from '../utils/messageInspectors.js'; import { isFunctionResponse } from '../utils/messageInspectors.js';
import { ContentGenerator } from './contentGenerator.js'; import { ContentGenerator, AuthType } from './contentGenerator.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { import {
logApiRequest, logApiRequest,
@ -34,6 +34,7 @@ import {
ApiRequestEvent, ApiRequestEvent,
ApiResponseEvent, ApiResponseEvent,
} from '../telemetry/types.js'; } from '../telemetry/types.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
/** /**
* Returns true if the response is valid, false otherwise. * Returns true if the response is valid, false otherwise.
@ -181,6 +182,44 @@ export class GeminiChat {
); );
} }
/**
* Handles fallback to Flash model when persistent 429 errors occur for OAuth users.
* Uses a fallback handler if provided by the config, otherwise returns null.
*/
private async handleFlashFallback(authType?: string): Promise<string | null> {
// Only handle fallback for OAuth users
if (
authType !== AuthType.LOGIN_WITH_GOOGLE_PERSONAL &&
authType !== AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE
) {
return null;
}
const currentModel = this.model;
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
// Don't fallback if already using Flash model
if (currentModel === fallbackModel) {
return null;
}
// Check if config has a fallback handler (set by CLI package)
const fallbackHandler = this.config.flashFallbackHandler;
if (typeof fallbackHandler === 'function') {
try {
const accepted = await fallbackHandler(currentModel, fallbackModel);
if (accepted) {
this.config.setModel(fallbackModel);
return fallbackModel;
}
} catch (error) {
console.warn('Flash fallback handler failed:', error);
}
}
return null;
}
/** /**
* Sends a message to the model and returns the response. * Sends a message to the model and returns the response.
* *
@ -315,6 +354,9 @@ export class GeminiChat {
} }
return false; // Don't retry other errors by default return false; // Don't retry other errors by default
}, },
onPersistent429: async (authType?: string) =>
await this.handleFlashFallback(authType),
authType: this.config.getContentGeneratorConfig()?.authType,
}); });
// Resolve the internal tracking of send completion promise - `sendPromise` // Resolve the internal tracking of send completion promise - `sendPromise`

View File

@ -0,0 +1,144 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, vi } from 'vitest';
import { Config } from '../config/config.js';
import {
setSimulate429,
disableSimulationAfterFallback,
shouldSimulate429,
createSimulated429Error,
resetRequestCounter,
} from './testUtils.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { retryWithBackoff } from './retry.js';
import { AuthType } from '../core/contentGenerator.js';
describe('Flash Fallback Integration', () => {
let config: Config;
beforeEach(() => {
config = new Config({
sessionId: 'test-session',
targetDir: '/test',
debugMode: false,
cwd: '/test',
model: 'gemini-2.5-pro',
});
// Reset simulation state for each test
setSimulate429(false);
resetRequestCounter();
});
it('should automatically accept fallback', async () => {
// Set up a minimal flash fallback handler for testing
const flashFallbackHandler = async (): Promise<boolean> => true;
config.setFlashFallbackHandler(flashFallbackHandler);
// Call the handler directly to test
const result = await config.flashFallbackHandler!(
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
// Verify it automatically accepts
expect(result).toBe(true);
});
it('should trigger fallback after 3 consecutive 429 errors for OAuth users', async () => {
let fallbackCalled = false;
let fallbackModel = '';
// Mock function that simulates exactly 3 429 errors, then succeeds after fallback
const mockApiCall = vi
.fn()
.mockRejectedValueOnce(createSimulated429Error())
.mockRejectedValueOnce(createSimulated429Error())
.mockRejectedValueOnce(createSimulated429Error())
.mockResolvedValueOnce('success after fallback');
// Mock fallback handler
const mockFallbackHandler = vi.fn(async (_authType?: string) => {
fallbackCalled = true;
fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
return fallbackModel;
});
// Test with OAuth personal auth type, with maxAttempts = 3 to ensure fallback triggers
const result = await retryWithBackoff(mockApiCall, {
maxAttempts: 3,
initialDelayMs: 1,
maxDelayMs: 10,
shouldRetry: (error: Error) => {
const status = (error as Error & { status?: number }).status;
return status === 429;
},
onPersistent429: mockFallbackHandler,
authType: AuthType.LOGIN_WITH_GOOGLE_PERSONAL,
});
// Verify fallback was triggered
expect(fallbackCalled).toBe(true);
expect(fallbackModel).toBe(DEFAULT_GEMINI_FLASH_MODEL);
expect(mockFallbackHandler).toHaveBeenCalledWith(
AuthType.LOGIN_WITH_GOOGLE_PERSONAL,
);
expect(result).toBe('success after fallback');
// Should have: 3 failures, then fallback triggered, then 1 success after retry reset
expect(mockApiCall).toHaveBeenCalledTimes(4);
});
it('should not trigger fallback for API key users', async () => {
let fallbackCalled = false;
// Mock function that simulates 429 errors
const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error());
// Mock fallback handler
const mockFallbackHandler = vi.fn(async () => {
fallbackCalled = true;
return DEFAULT_GEMINI_FLASH_MODEL;
});
// Test with API key auth type - should not trigger fallback
try {
await retryWithBackoff(mockApiCall, {
maxAttempts: 5,
initialDelayMs: 10,
maxDelayMs: 100,
shouldRetry: (error: Error) => {
const status = (error as Error & { status?: number }).status;
return status === 429;
},
onPersistent429: mockFallbackHandler,
authType: AuthType.USE_GEMINI, // API key auth type
});
} catch (error) {
// Expected to throw after max attempts
expect((error as Error).message).toContain('Rate limit exceeded');
}
// Verify fallback was NOT triggered for API key users
expect(fallbackCalled).toBe(false);
expect(mockFallbackHandler).not.toHaveBeenCalled();
});
it('should properly disable simulation state after fallback', () => {
// Enable simulation
setSimulate429(true);
// Verify simulation is enabled
expect(shouldSimulate429()).toBe(true);
// Disable simulation after fallback
disableSimulationAfterFallback();
// Verify simulation is now disabled
expect(shouldSimulate429()).toBe(false);
});
});

View File

@ -7,6 +7,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { retryWithBackoff } from './retry.js'; import { retryWithBackoff } from './retry.js';
import { setSimulate429 } from './testUtils.js';
// Define an interface for the error with a status property // Define an interface for the error with a status property
interface HttpError extends Error { interface HttpError extends Error {
@ -42,10 +43,15 @@ class NonRetryableError extends Error {
describe('retryWithBackoff', () => { describe('retryWithBackoff', () => {
beforeEach(() => { beforeEach(() => {
vi.useFakeTimers(); vi.useFakeTimers();
// Disable 429 simulation for tests
setSimulate429(false);
// Suppress unhandled promise rejection warnings for tests that expect errors
console.warn = vi.fn();
}); });
afterEach(() => { afterEach(() => {
vi.restoreAllMocks(); vi.restoreAllMocks();
vi.useRealTimers();
}); });
it('should return the result on the first attempt if successful', async () => { it('should return the result on the first attempt if successful', async () => {
@ -231,4 +237,197 @@ describe('retryWithBackoff', () => {
expect(d).toBeLessThanOrEqual(100 * 1.3); expect(d).toBeLessThanOrEqual(100 * 1.3);
}); });
}); });
describe('Flash model fallback for OAuth users', () => {
it('should trigger fallback for OAuth personal users after persistent 429 errors', async () => {
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
let fallbackOccurred = false;
const mockFn = vi.fn().mockImplementation(async () => {
if (!fallbackOccurred) {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
return 'success';
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: async (authType?: string) => {
fallbackOccurred = true;
return await fallbackCallback(authType);
},
authType: 'oauth-personal',
});
// Advance all timers to complete retries
await vi.runAllTimersAsync();
// Should succeed after fallback
await expect(promise).resolves.toBe('success');
// Verify callback was called with correct auth type
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
// Should retry again after fallback
expect(mockFn).toHaveBeenCalledTimes(4); // 3 initial attempts + 1 after fallback
});
it('should trigger fallback for OAuth enterprise users after persistent 429 errors', async () => {
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
let fallbackOccurred = false;
const mockFn = vi.fn().mockImplementation(async () => {
if (!fallbackOccurred) {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
return 'success';
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: async (authType?: string) => {
fallbackOccurred = true;
return await fallbackCallback(authType);
},
authType: 'oauth-enterprise',
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
expect(fallbackCallback).toHaveBeenCalledWith('oauth-enterprise');
});
it('should NOT trigger fallback for API key users', async () => {
const fallbackCallback = vi.fn();
const mockFn = vi.fn(async () => {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'gemini-api-key',
});
// Handle the promise properly to avoid unhandled rejections
const resultPromise = promise.catch((error) => error);
await vi.runAllTimersAsync();
const result = await resultPromise;
// Should fail after all retries without fallback
expect(result).toBeInstanceOf(Error);
expect(result.message).toBe('Rate limit exceeded');
// Callback should not be called for API key users
expect(fallbackCallback).not.toHaveBeenCalled();
});
it('should reset attempt counter and continue after successful fallback', async () => {
let fallbackCalled = false;
const fallbackCallback = vi.fn().mockImplementation(async () => {
fallbackCalled = true;
return 'gemini-2.5-flash';
});
const mockFn = vi.fn().mockImplementation(async () => {
if (!fallbackCalled) {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
return 'success';
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'oauth-personal',
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
expect(fallbackCallback).toHaveBeenCalledOnce();
});
it('should continue with original error if fallback is rejected', async () => {
const fallbackCallback = vi.fn().mockResolvedValue(null); // User rejected fallback
const mockFn = vi.fn(async () => {
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 3,
initialDelayMs: 100,
onPersistent429: fallbackCallback,
authType: 'oauth-personal',
});
// Handle the promise properly to avoid unhandled rejections
const resultPromise = promise.catch((error) => error);
await vi.runAllTimersAsync();
const result = await resultPromise;
// Should fail with original error when fallback is rejected
expect(result).toBeInstanceOf(Error);
expect(result.message).toBe('Rate limit exceeded');
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
});
it('should handle mixed error types (only count consecutive 429s)', async () => {
const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
let attempts = 0;
let fallbackOccurred = false;
const mockFn = vi.fn().mockImplementation(async () => {
attempts++;
if (fallbackOccurred) {
return 'success';
}
if (attempts === 1) {
// First attempt: 500 error (resets consecutive count)
const error: HttpError = new Error('Server error');
error.status = 500;
throw error;
} else {
// Remaining attempts: 429 errors
const error: HttpError = new Error('Rate limit exceeded');
error.status = 429;
throw error;
}
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 5,
initialDelayMs: 100,
onPersistent429: async (authType?: string) => {
fallbackOccurred = true;
return await fallbackCallback(authType);
},
authType: 'oauth-personal',
});
await vi.runAllTimersAsync();
await expect(promise).resolves.toBe('success');
// Should trigger fallback after 4 consecutive 429s (attempts 2-5)
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
});
});
}); });

View File

@ -4,11 +4,15 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { AuthType } from '../core/contentGenerator.js';
export interface RetryOptions { export interface RetryOptions {
maxAttempts: number; maxAttempts: number;
initialDelayMs: number; initialDelayMs: number;
maxDelayMs: number; maxDelayMs: number;
shouldRetry: (error: Error) => boolean; shouldRetry: (error: Error) => boolean;
onPersistent429?: (authType?: string) => Promise<string | null>;
authType?: string;
} }
const DEFAULT_RETRY_OPTIONS: RetryOptions = { const DEFAULT_RETRY_OPTIONS: RetryOptions = {
@ -59,29 +63,69 @@ export async function retryWithBackoff<T>(
fn: () => Promise<T>, fn: () => Promise<T>,
options?: Partial<RetryOptions>, options?: Partial<RetryOptions>,
): Promise<T> { ): Promise<T> {
const { maxAttempts, initialDelayMs, maxDelayMs, shouldRetry } = { const {
maxAttempts,
initialDelayMs,
maxDelayMs,
shouldRetry,
onPersistent429,
authType,
} = {
...DEFAULT_RETRY_OPTIONS, ...DEFAULT_RETRY_OPTIONS,
...options, ...options,
}; };
let attempt = 0; let attempt = 0;
let currentDelay = initialDelayMs; let currentDelay = initialDelayMs;
let consecutive429Count = 0;
while (attempt < maxAttempts) { while (attempt < maxAttempts) {
attempt++; attempt++;
try { try {
return await fn(); return await fn();
} catch (error) { } catch (error) {
const errorStatus = getErrorStatus(error);
// Track consecutive 429 errors
if (errorStatus === 429) {
consecutive429Count++;
} else {
consecutive429Count = 0;
}
// Check if we've exhausted retries or shouldn't retry
if (attempt >= maxAttempts || !shouldRetry(error as Error)) { if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
// If we have persistent 429s and a fallback callback for OAuth
if (
consecutive429Count >= 3 &&
onPersistent429 &&
(authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL ||
authType === AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE)
) {
try {
const fallbackModel = await onPersistent429(authType);
if (fallbackModel) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
continue;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
throw error; throw error;
} }
const { delayDurationMs, errorStatus } = getDelayDurationAndStatus(error); const { delayDurationMs, errorStatus: delayErrorStatus } =
getDelayDurationAndStatus(error);
if (delayDurationMs > 0) { if (delayDurationMs > 0) {
// Respect Retry-After header if present and parsed // Respect Retry-After header if present and parsed
console.warn( console.warn(
`Attempt ${attempt} failed with status ${errorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`, `Attempt ${attempt} failed with status ${delayErrorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
error, error,
); );
await delay(delayDurationMs); await delay(delayDurationMs);

View File

@ -0,0 +1,87 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
* Testing utilities for simulating 429 errors in unit tests
*/
let requestCounter = 0;
let simulate429Enabled = false;
let simulate429AfterRequests = 0;
let simulate429ForAuthType: string | undefined;
let fallbackOccurred = false;
/**
* Check if we should simulate a 429 error for the current request
*/
export function shouldSimulate429(authType?: string): boolean {
if (!simulate429Enabled || fallbackOccurred) {
return false;
}
// If auth type filter is set, only simulate for that auth type
if (simulate429ForAuthType && authType !== simulate429ForAuthType) {
return false;
}
requestCounter++;
// If afterRequests is set, only simulate after that many requests
if (simulate429AfterRequests > 0) {
return requestCounter > simulate429AfterRequests;
}
// Otherwise, simulate for every request
return true;
}
/**
* Reset the request counter (useful for tests)
*/
export function resetRequestCounter(): void {
requestCounter = 0;
}
/**
* Disable 429 simulation after successful fallback
*/
export function disableSimulationAfterFallback(): void {
fallbackOccurred = true;
}
/**
* Create a simulated 429 error response
*/
export function createSimulated429Error(): Error {
const error = new Error('Rate limit exceeded (simulated)') as Error & {
status: number;
};
error.status = 429;
return error;
}
/**
* Reset simulation state when switching auth methods
*/
export function resetSimulationState(): void {
fallbackOccurred = false;
resetRequestCounter();
}
/**
* Enable/disable 429 simulation programmatically (for tests)
*/
export function setSimulate429(
enabled: boolean,
afterRequests = 0,
forAuthType?: string,
): void {
simulate429Enabled = enabled;
simulate429AfterRequests = afterRequests;
simulate429ForAuthType = forAuthType;
fallbackOccurred = false; // Reset fallback state when simulation is re-enabled
resetRequestCounter();
}

View File

@ -0,0 +1,10 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { setSimulate429 } from './src/utils/testUtils.js';
// Disable 429 simulation globally for all tests
setSimulate429(false);

View File

@ -10,6 +10,7 @@ export default defineConfig({
test: { test: {
reporters: ['default', 'junit'], reporters: ['default', 'junit'],
silent: true, silent: true,
setupFiles: ['./test-setup.ts'],
outputFile: { outputFile: {
junit: 'junit.xml', junit: 'junit.xml',
}, },