[JUNE 25] Permanent failover to Flash model for OAuth users after persistent 429 errors (#1376)
Co-authored-by: Scott Densmore <scottdensmore@mac.com>
This commit is contained in:
parent
4bf18da2b0
commit
e356949d3f
|
@ -125,6 +125,7 @@ vi.mock('@gemini-cli/core', async (importOriginal) => {
|
||||||
getGeminiClient: vi.fn(() => ({})),
|
getGeminiClient: vi.fn(() => ({})),
|
||||||
getCheckpointingEnabled: vi.fn(() => opts.checkpointing ?? true),
|
getCheckpointingEnabled: vi.fn(() => opts.checkpointing ?? true),
|
||||||
getAllGeminiMdFilenames: vi.fn(() => ['GEMINI.md']),
|
getAllGeminiMdFilenames: vi.fn(() => ['GEMINI.md']),
|
||||||
|
setFlashFallbackHandler: vi.fn(),
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -115,6 +115,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
|
||||||
const [editorError, setEditorError] = useState<string | null>(null);
|
const [editorError, setEditorError] = useState<string | null>(null);
|
||||||
const [footerHeight, setFooterHeight] = useState<number>(0);
|
const [footerHeight, setFooterHeight] = useState<number>(0);
|
||||||
const [corgiMode, setCorgiMode] = useState(false);
|
const [corgiMode, setCorgiMode] = useState(false);
|
||||||
|
const [currentModel, setCurrentModel] = useState(config.getModel());
|
||||||
const [shellModeActive, setShellModeActive] = useState(false);
|
const [shellModeActive, setShellModeActive] = useState(false);
|
||||||
const [showErrorDetails, setShowErrorDetails] = useState<boolean>(false);
|
const [showErrorDetails, setShowErrorDetails] = useState<boolean>(false);
|
||||||
const [showToolDescriptions, setShowToolDescriptions] =
|
const [showToolDescriptions, setShowToolDescriptions] =
|
||||||
|
@ -214,6 +215,42 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
|
||||||
}
|
}
|
||||||
}, [config, addItem]);
|
}, [config, addItem]);
|
||||||
|
|
||||||
|
// Watch for model changes (e.g., from Flash fallback).
// No change notification from the config is visible here, so the config is
// polled once a second and the result mirrored into `currentModel`.
useEffect(() => {
  // React ignores a state update whose value is identical to the current
  // state, so no comparison against `currentModel` is needed. Keeping
  // `currentModel` out of the dependency list means a single interval lives
  // for the lifetime of `config` instead of being torn down and recreated
  // after every model change.
  const syncModel = () => setCurrentModel(config.getModel());

  // Check immediately and then periodically
  syncModel();
  const interval = setInterval(syncModel, 1000); // Check every second

  return () => clearInterval(interval);
}, [config]);
|
||||||
|
|
||||||
|
// Install the Flash fallback handler on the config: it posts an INFO notice
// to the UI history and always approves the switch.
useEffect(() => {
  config.setFlashFallbackHandler(
    async (currentModel: string, fallbackModel: string): Promise<boolean> => {
      // Add message to UI history
      addItem(
        {
          type: MessageType.INFO,
          text: `⚡ Rate limiting detected. Automatically switching from ${currentModel} to ${fallbackModel} for faster responses for the remainder of this session.`,
        },
        Date.now(),
      );

      // Always accept the fallback
      return true;
    },
  );
}, [config, addItem]);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
handleSlashCommand,
|
handleSlashCommand,
|
||||||
slashCommands,
|
slashCommands,
|
||||||
|
@ -787,7 +824,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
<Footer
|
<Footer
|
||||||
model={config.getModel()}
|
model={currentModel}
|
||||||
targetDir={config.getTargetDir()}
|
targetDir={config.getTargetDir()}
|
||||||
debugMode={config.getDebugMode()}
|
debugMode={config.getDebugMode()}
|
||||||
branchName={branchName}
|
branchName={branchName}
|
||||||
|
|
|
@ -171,7 +171,7 @@ export const useSlashCommandProcessor = (
|
||||||
[addMessage],
|
[addMessage],
|
||||||
);
|
);
|
||||||
|
|
||||||
const savedChatTags = async function () {
|
const savedChatTags = useCallback(async () => {
|
||||||
const geminiDir = config?.getProjectTempDir();
|
const geminiDir = config?.getProjectTempDir();
|
||||||
if (!geminiDir) {
|
if (!geminiDir) {
|
||||||
return [];
|
return [];
|
||||||
|
@ -186,7 +186,7 @@ export const useSlashCommandProcessor = (
|
||||||
} catch (_err) {
|
} catch (_err) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
};
|
}, [config]);
|
||||||
|
|
||||||
const slashCommands: SlashCommand[] = useMemo(() => {
|
const slashCommands: SlashCommand[] = useMemo(() => {
|
||||||
const commands: SlashCommand[] = [
|
const commands: SlashCommand[] = [
|
||||||
|
@ -992,6 +992,7 @@ Add any other context about the problem here.
|
||||||
addMemoryAction,
|
addMemoryAction,
|
||||||
addMessage,
|
addMessage,
|
||||||
toggleCorgiMode,
|
toggleCorgiMode,
|
||||||
|
savedChatTags,
|
||||||
config,
|
config,
|
||||||
showToolDescriptions,
|
showToolDescriptions,
|
||||||
session,
|
session,
|
||||||
|
|
|
@ -35,7 +35,10 @@ import {
|
||||||
TelemetryTarget,
|
TelemetryTarget,
|
||||||
StartSessionEvent,
|
StartSessionEvent,
|
||||||
} from '../telemetry/index.js';
|
} from '../telemetry/index.js';
|
||||||
import { DEFAULT_GEMINI_EMBEDDING_MODEL } from './models.js';
|
import {
|
||||||
|
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||||
|
DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
|
} from './models.js';
|
||||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||||
|
|
||||||
export enum ApprovalMode {
|
export enum ApprovalMode {
|
||||||
|
@ -85,6 +88,11 @@ export interface SandboxConfig {
|
||||||
image: string;
|
image: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Callback consulted before the session is moved onto the fallback model.
 *
 * Receives the model currently in use and the proposed fallback model, and
 * resolves to `true` to accept the switch; callers in this change only apply
 * the fallback when the resolved value is truthy.
 */
export type FlashFallbackHandler = (
|
||||||
|
currentModel: string,
|
||||||
|
fallbackModel: string,
|
||||||
|
) => Promise<boolean>;
|
||||||
|
|
||||||
export interface ConfigParameters {
|
export interface ConfigParameters {
|
||||||
sessionId: string;
|
sessionId: string;
|
||||||
embeddingModel?: string;
|
embeddingModel?: string;
|
||||||
|
@ -156,6 +164,8 @@ export class Config {
|
||||||
private readonly bugCommand: BugCommandSettings | undefined;
|
private readonly bugCommand: BugCommandSettings | undefined;
|
||||||
private readonly model: string;
|
private readonly model: string;
|
||||||
private readonly extensionContextFilePaths: string[];
|
private readonly extensionContextFilePaths: string[];
|
||||||
|
private modelSwitchedDuringSession: boolean = false;
|
||||||
|
flashFallbackHandler?: FlashFallbackHandler;
|
||||||
|
|
||||||
constructor(params: ConfigParameters) {
|
constructor(params: ConfigParameters) {
|
||||||
this.sessionId = params.sessionId;
|
this.sessionId = params.sessionId;
|
||||||
|
@ -216,9 +226,24 @@ export class Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
async refreshAuth(authMethod: AuthType) {
|
async refreshAuth(authMethod: AuthType) {
|
||||||
|
// Check if this is actually a switch to a different auth method
|
||||||
|
const previousAuthType = this.contentGeneratorConfig?.authType;
|
||||||
|
const _isAuthMethodSwitch =
|
||||||
|
previousAuthType && previousAuthType !== authMethod;
|
||||||
|
|
||||||
|
// Always use the original default model when switching auth methods
|
||||||
|
// This ensures users don't stay on Flash after switching between auth types
|
||||||
|
// and allows API key users to get proper fallback behavior from getEffectiveModel
|
||||||
|
const modelToUse = this.model; // Use the original default model
|
||||||
|
|
||||||
|
// Temporarily clear contentGeneratorConfig to prevent getModel() from returning
|
||||||
|
// the previous session's model (which might be Flash)
|
||||||
|
this.contentGeneratorConfig = undefined!;
|
||||||
|
|
||||||
const contentConfig = await createContentGeneratorConfig(
|
const contentConfig = await createContentGeneratorConfig(
|
||||||
this.getModel(),
|
modelToUse,
|
||||||
authMethod,
|
authMethod,
|
||||||
|
this,
|
||||||
);
|
);
|
||||||
|
|
||||||
const gc = new GeminiClient(this);
|
const gc = new GeminiClient(this);
|
||||||
|
@ -226,6 +251,11 @@ export class Config {
|
||||||
this.toolRegistry = await createToolRegistry(this);
|
this.toolRegistry = await createToolRegistry(this);
|
||||||
await gc.initialize(contentConfig);
|
await gc.initialize(contentConfig);
|
||||||
this.contentGeneratorConfig = contentConfig;
|
this.contentGeneratorConfig = contentConfig;
|
||||||
|
|
||||||
|
// Reset the session flag since we're explicitly changing auth and using default model
|
||||||
|
this.modelSwitchedDuringSession = false;
|
||||||
|
|
||||||
|
// Note: In the future, we may want to reset any cached state when switching auth methods
|
||||||
}
|
}
|
||||||
|
|
||||||
getSessionId(): string {
|
getSessionId(): string {
|
||||||
|
@ -240,6 +270,28 @@ export class Config {
|
||||||
return this.contentGeneratorConfig?.model || this.model;
|
return this.contentGeneratorConfig?.model || this.model;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Switches the active model and flags the session as having switched models.
 *
 * The model is stored on the content generator config, so the call does
 * nothing (and the flag stays untouched) while that config is not set.
 *
 * @param newModel Name of the model to use from now on.
 */
setModel(newModel: string): void {
  const generatorConfig = this.contentGeneratorConfig;
  if (!generatorConfig) {
    return;
  }

  generatorConfig.model = newModel;
  this.modelSwitchedDuringSession = true;
}
|
||||||
|
|
||||||
|
/**
 * Reports whether the model has been changed via `setModel` without a later
 * reset.
 *
 * @returns The current value of the session's model-switched flag.
 */
isModelSwitchedDuringSession(): boolean {
  return this.modelSwitchedDuringSession;
}
|
||||||
|
|
||||||
|
/**
 * Restores the model this Config was constructed with and clears the
 * model-switched flag.
 *
 * Does nothing while the content generator config is not set.
 */
resetModelToDefault(): void {
  const generatorConfig = this.contentGeneratorConfig;
  if (!generatorConfig) {
    return;
  }

  // `this.model` still holds the original default model.
  generatorConfig.model = this.model;
  this.modelSwitchedDuringSession = false;
}
|
||||||
|
|
||||||
|
/**
 * Stores the callback that is consulted before falling back to the Flash
 * model, replacing any previously stored handler.
 *
 * @param handler Callback that approves or rejects the fallback.
 */
setFlashFallbackHandler(handler: FlashFallbackHandler): void {
  this.flashFallbackHandler = handler;
}
|
||||||
|
|
||||||
getEmbeddingModel(): string {
|
getEmbeddingModel(): string {
|
||||||
return this.embeddingModel;
|
return this.embeddingModel;
|
||||||
}
|
}
|
||||||
|
@ -445,3 +497,6 @@ export function createToolRegistry(config: Config): Promise<ToolRegistry> {
|
||||||
return registry;
|
return registry;
|
||||||
})();
|
})();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Export model constants for use in CLI
|
||||||
|
export { DEFAULT_GEMINI_FLASH_MODEL };
|
||||||
|
|
|
@ -0,0 +1,139 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, beforeEach } from 'vitest';
|
||||||
|
import { Config } from './config.js';
|
||||||
|
import { DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_FLASH_MODEL } from './models.js';
|
||||||
|
|
||||||
|
describe('Flash Model Fallback Configuration', () => {
  let config: Config;

  // Builds a Config with the fixed test parameters; only the session id and
  // model vary between tests. contentGeneratorConfig is left uninitialized.
  const createConfig = (sessionId: string, model: string): Config =>
    new Config({
      sessionId,
      targetDir: '/test',
      debugMode: false,
      cwd: '/test',
      model,
    });

  beforeEach(() => {
    config = createConfig('test-session', DEFAULT_GEMINI_MODEL);

    // Initialize contentGeneratorConfig for testing
    (
      config as unknown as { contentGeneratorConfig: unknown }
    ).contentGeneratorConfig = {
      model: DEFAULT_GEMINI_MODEL,
      authType: 'oauth-personal',
    };
  });

  describe('setModel', () => {
    it('should update the model and mark as switched during session', () => {
      expect(config.getModel()).toBe(DEFAULT_GEMINI_MODEL);
      expect(config.isModelSwitchedDuringSession()).toBe(false);

      config.setModel(DEFAULT_GEMINI_FLASH_MODEL);

      expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
      expect(config.isModelSwitchedDuringSession()).toBe(true);
    });

    it('should handle multiple model switches during session', () => {
      config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
      expect(config.isModelSwitchedDuringSession()).toBe(true);

      config.setModel('gemini-1.5-pro');
      expect(config.getModel()).toBe('gemini-1.5-pro');
      expect(config.isModelSwitchedDuringSession()).toBe(true);
    });

    it('should only mark as switched if contentGeneratorConfig exists', () => {
      // Config whose contentGeneratorConfig was never initialized
      const bareConfig = createConfig('test-session-2', DEFAULT_GEMINI_MODEL);

      // Should not crash when contentGeneratorConfig is undefined
      bareConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL);
      expect(bareConfig.isModelSwitchedDuringSession()).toBe(false);
    });
  });

  describe('getModel', () => {
    it('should return contentGeneratorConfig model if available', () => {
      // The suite-level config has an initialized content generator config
      config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
      expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
    });

    it('should fallback to initial model if contentGeneratorConfig is not available', () => {
      // Fresh config where contentGeneratorConfig has not been set
      const bareConfig = createConfig('test-session-2', 'custom-model');

      expect(bareConfig.getModel()).toBe('custom-model');
    });
  });

  describe('isModelSwitchedDuringSession', () => {
    it('should start as false for new session', () => {
      expect(config.isModelSwitchedDuringSession()).toBe(false);
    });

    it('should remain false if no model switch occurs', () => {
      // No model-switching operation is performed before the check
      expect(config.isModelSwitchedDuringSession()).toBe(false);
    });

    it('should persist switched state throughout session', () => {
      config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
      expect(config.isModelSwitchedDuringSession()).toBe(true);

      // Reading the model must not clear the flag
      config.getModel();
      expect(config.isModelSwitchedDuringSession()).toBe(true);
    });
  });

  describe('resetModelToDefault', () => {
    it('should reset model to default and clear session switch flag', () => {
      // Switch to Flash first
      config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
      expect(config.getModel()).toBe(DEFAULT_GEMINI_FLASH_MODEL);
      expect(config.isModelSwitchedDuringSession()).toBe(true);

      // Reset to default
      config.resetModelToDefault();

      // Should be back to default with flag cleared
      expect(config.getModel()).toBe(DEFAULT_GEMINI_MODEL);
      expect(config.isModelSwitchedDuringSession()).toBe(false);
    });

    it('should handle case where contentGeneratorConfig is not initialized', () => {
      // Config whose contentGeneratorConfig was never initialized
      const bareConfig = createConfig('test-session-2', DEFAULT_GEMINI_MODEL);

      // Should not crash when contentGeneratorConfig is undefined
      expect(() => bareConfig.resetModelToDefault()).not.toThrow();
      expect(bareConfig.isModelSwitchedDuringSession()).toBe(false);
    });
  });
});
|
|
@ -20,6 +20,7 @@ import { Turn } from './turn.js';
|
||||||
import { getCoreSystemPrompt } from './prompts.js';
|
import { getCoreSystemPrompt } from './prompts.js';
|
||||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||||
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
||||||
|
import { setSimulate429 } from '../utils/testUtils.js';
|
||||||
|
|
||||||
// --- Mocks ---
|
// --- Mocks ---
|
||||||
const mockChatCreateFn = vi.fn();
|
const mockChatCreateFn = vi.fn();
|
||||||
|
@ -68,6 +69,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
vi.resetAllMocks();
|
vi.resetAllMocks();
|
||||||
|
|
||||||
|
// Disable 429 simulation for tests
|
||||||
|
setSimulate429(false);
|
||||||
|
|
||||||
// Set up the mock for GoogleGenAI constructor and its methods
|
// Set up the mock for GoogleGenAI constructor and its methods
|
||||||
const MockedGoogleGenAI = vi.mocked(GoogleGenAI);
|
const MockedGoogleGenAI = vi.mocked(GoogleGenAI);
|
||||||
MockedGoogleGenAI.mockImplementation(() => {
|
MockedGoogleGenAI.mockImplementation(() => {
|
||||||
|
|
|
@ -38,6 +38,7 @@ import {
|
||||||
} from './contentGenerator.js';
|
} from './contentGenerator.js';
|
||||||
import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||||
|
import { AuthType } from './contentGenerator.js';
|
||||||
|
|
||||||
function isThinkingSupported(model: string) {
|
function isThinkingSupported(model: string) {
|
||||||
if (model.startsWith('gemini-2.5')) return true;
|
if (model.startsWith('gemini-2.5')) return true;
|
||||||
|
@ -276,7 +277,11 @@ export class GeminiClient {
|
||||||
contents,
|
contents,
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await retryWithBackoff(apiCall);
|
const result = await retryWithBackoff(apiCall, {
|
||||||
|
onPersistent429: async (authType?: string) =>
|
||||||
|
await this.handleFlashFallback(authType),
|
||||||
|
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||||
|
});
|
||||||
|
|
||||||
const text = getResponseText(result);
|
const text = getResponseText(result);
|
||||||
if (!text) {
|
if (!text) {
|
||||||
|
@ -360,7 +365,11 @@ export class GeminiClient {
|
||||||
contents,
|
contents,
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await retryWithBackoff(apiCall);
|
const result = await retryWithBackoff(apiCall, {
|
||||||
|
onPersistent429: async (authType?: string) =>
|
||||||
|
await this.handleFlashFallback(authType),
|
||||||
|
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||||
|
});
|
||||||
return result;
|
return result;
|
||||||
} catch (error: unknown) {
|
} catch (error: unknown) {
|
||||||
if (abortSignal.aborted) {
|
if (abortSignal.aborted) {
|
||||||
|
@ -489,4 +498,43 @@ export class GeminiClient {
|
||||||
}
|
}
|
||||||
: null;
|
: null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Tries to move the session onto the Flash model after persistent 429 errors.
 *
 * Only the two Google-login (OAuth) auth types are eligible. Nothing happens
 * when this client is already on the Flash model or when the config carries
 * no fallback handler. If the handler approves, the new model is recorded on
 * both the config and this client. A handler that throws is logged and
 * treated as a refusal.
 *
 * @param authType Auth type of the request that kept failing.
 * @returns The fallback model name if the switch was made, otherwise null.
 */
private async handleFlashFallback(authType?: string): Promise<string | null> {
  // Only handle fallback for OAuth users
  const isOAuthUser =
    authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL ||
    authType === AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE;
  if (!isOAuthUser) {
    return null;
  }

  const activeModel = this.model;
  const flashModel = DEFAULT_GEMINI_FLASH_MODEL;

  // Don't fallback if already using Flash model
  if (activeModel === flashModel) {
    return null;
  }

  // Check if config has a fallback handler (set by CLI package)
  const handler = this.config.flashFallbackHandler;
  if (typeof handler !== 'function') {
    return null;
  }

  try {
    const accepted = await handler(activeModel, flashModel);
    if (!accepted) {
      return null;
    }
    this.config.setModel(flashModel);
    this.model = flashModel;
    return flashModel;
  } catch (error) {
    console.warn('Flash fallback handler failed:', error);
    return null;
  }
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,14 +51,18 @@ export type ContentGeneratorConfig = {
|
||||||
export async function createContentGeneratorConfig(
|
export async function createContentGeneratorConfig(
|
||||||
model: string | undefined,
|
model: string | undefined,
|
||||||
authType: AuthType | undefined,
|
authType: AuthType | undefined,
|
||||||
|
config?: { getModel?: () => string },
|
||||||
): Promise<ContentGeneratorConfig> {
|
): Promise<ContentGeneratorConfig> {
|
||||||
const geminiApiKey = process.env.GEMINI_API_KEY;
|
const geminiApiKey = process.env.GEMINI_API_KEY;
|
||||||
const googleApiKey = process.env.GOOGLE_API_KEY;
|
const googleApiKey = process.env.GOOGLE_API_KEY;
|
||||||
const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT;
|
const googleCloudProject = process.env.GOOGLE_CLOUD_PROJECT;
|
||||||
const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION;
|
const googleCloudLocation = process.env.GOOGLE_CLOUD_LOCATION;
|
||||||
|
|
||||||
|
// Use runtime model from config if available, otherwise fallback to parameter or default
|
||||||
|
const effectiveModel = config?.getModel?.() || model || DEFAULT_GEMINI_MODEL;
|
||||||
|
|
||||||
const contentGeneratorConfig: ContentGeneratorConfig = {
|
const contentGeneratorConfig: ContentGeneratorConfig = {
|
||||||
model: model || DEFAULT_GEMINI_MODEL,
|
model: effectiveModel,
|
||||||
authType,
|
authType,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ import {
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { GeminiChat } from './geminiChat.js';
|
import { GeminiChat } from './geminiChat.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
|
import { setSimulate429 } from '../utils/testUtils.js';
|
||||||
|
|
||||||
// Mocks
|
// Mocks
|
||||||
const mockModelsModule = {
|
const mockModelsModule = {
|
||||||
|
@ -29,6 +30,12 @@ const mockConfig = {
|
||||||
getTelemetryLogPromptsEnabled: () => true,
|
getTelemetryLogPromptsEnabled: () => true,
|
||||||
getUsageStatisticsEnabled: () => true,
|
getUsageStatisticsEnabled: () => true,
|
||||||
getDebugMode: () => false,
|
getDebugMode: () => false,
|
||||||
|
getContentGeneratorConfig: () => ({
|
||||||
|
authType: 'oauth-personal',
|
||||||
|
model: 'test-model',
|
||||||
|
}),
|
||||||
|
setModel: vi.fn(),
|
||||||
|
flashFallbackHandler: undefined,
|
||||||
} as unknown as Config;
|
} as unknown as Config;
|
||||||
|
|
||||||
describe('GeminiChat', () => {
|
describe('GeminiChat', () => {
|
||||||
|
@ -38,6 +45,8 @@ describe('GeminiChat', () => {
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
|
// Disable 429 simulation for tests
|
||||||
|
setSimulate429(false);
|
||||||
// Reset history for each test by creating a new instance
|
// Reset history for each test by creating a new instance
|
||||||
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []);
|
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []);
|
||||||
});
|
});
|
||||||
|
|
|
@ -18,7 +18,7 @@ import {
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { retryWithBackoff } from '../utils/retry.js';
|
import { retryWithBackoff } from '../utils/retry.js';
|
||||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||||
import { ContentGenerator } from './contentGenerator.js';
|
import { ContentGenerator, AuthType } from './contentGenerator.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
import {
|
import {
|
||||||
logApiRequest,
|
logApiRequest,
|
||||||
|
@ -34,6 +34,7 @@ import {
|
||||||
ApiRequestEvent,
|
ApiRequestEvent,
|
||||||
ApiResponseEvent,
|
ApiResponseEvent,
|
||||||
} from '../telemetry/types.js';
|
} from '../telemetry/types.js';
|
||||||
|
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns true if the response is valid, false otherwise.
|
* Returns true if the response is valid, false otherwise.
|
||||||
|
@ -181,6 +182,44 @@ export class GeminiChat {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Tries to move the session onto the Flash model after persistent 429 errors.
 *
 * Eligible only for the two Google-login (OAuth) auth types, only when this
 * chat is not already on the Flash model, and only when the config carries a
 * fallback handler. If the handler approves, the new model is stored on the
 * config. A handler that throws is logged and treated as a refusal.
 *
 * NOTE(review): unlike the GeminiClient counterpart, `this.model` is not
 * updated here, so the "already using Flash" guard keeps comparing against
 * the chat's original model — confirm this is intended.
 *
 * @param authType Auth type of the request that kept failing.
 * @returns The fallback model name if the switch was made, otherwise null.
 */
private async handleFlashFallback(authType?: string): Promise<string | null> {
  // Only handle fallback for OAuth users
  const usesGoogleLogin =
    authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL ||
    authType === AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE;
  if (!usesGoogleLogin) return null;

  // Don't fallback if already using Flash model
  const modelInUse = this.model;
  const flashModel = DEFAULT_GEMINI_FLASH_MODEL;
  if (modelInUse === flashModel) return null;

  // Check if config has a fallback handler (set by CLI package)
  const handler = this.config.flashFallbackHandler;
  if (typeof handler !== 'function') return null;

  let switched: string | null = null;
  try {
    if (await handler(modelInUse, flashModel)) {
      this.config.setModel(flashModel);
      switched = flashModel;
    }
  } catch (error) {
    console.warn('Flash fallback handler failed:', error);
  }
  return switched;
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sends a message to the model and returns the response.
|
* Sends a message to the model and returns the response.
|
||||||
*
|
*
|
||||||
|
@ -315,6 +354,9 @@ export class GeminiChat {
|
||||||
}
|
}
|
||||||
return false; // Don't retry other errors by default
|
return false; // Don't retry other errors by default
|
||||||
},
|
},
|
||||||
|
onPersistent429: async (authType?: string) =>
|
||||||
|
await this.handleFlashFallback(authType),
|
||||||
|
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Resolve the internal tracking of send completion promise - `sendPromise`
|
// Resolve the internal tracking of send completion promise - `sendPromise`
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||||
|
import { Config } from '../config/config.js';
|
||||||
|
import {
|
||||||
|
setSimulate429,
|
||||||
|
disableSimulationAfterFallback,
|
||||||
|
shouldSimulate429,
|
||||||
|
createSimulated429Error,
|
||||||
|
resetRequestCounter,
|
||||||
|
} from './testUtils.js';
|
||||||
|
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||||
|
import { retryWithBackoff } from './retry.js';
|
||||||
|
import { AuthType } from '../core/contentGenerator.js';
|
||||||
|
|
||||||
|
describe('Flash Fallback Integration', () => {
|
||||||
|
let config: Config;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
config = new Config({
|
||||||
|
sessionId: 'test-session',
|
||||||
|
targetDir: '/test',
|
||||||
|
debugMode: false,
|
||||||
|
cwd: '/test',
|
||||||
|
model: 'gemini-2.5-pro',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Reset simulation state for each test
|
||||||
|
setSimulate429(false);
|
||||||
|
resetRequestCounter();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should automatically accept fallback', async () => {
|
||||||
|
// Set up a minimal flash fallback handler for testing
|
||||||
|
const flashFallbackHandler = async (): Promise<boolean> => true;
|
||||||
|
|
||||||
|
config.setFlashFallbackHandler(flashFallbackHandler);
|
||||||
|
|
||||||
|
// Call the handler directly to test
|
||||||
|
const result = await config.flashFallbackHandler!(
|
||||||
|
'gemini-2.5-pro',
|
||||||
|
DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify it automatically accepts
|
||||||
|
expect(result).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should trigger fallback after 3 consecutive 429 errors for OAuth users', async () => {
|
||||||
|
let fallbackCalled = false;
|
||||||
|
let fallbackModel = '';
|
||||||
|
|
||||||
|
// Mock function that simulates exactly 3 429 errors, then succeeds after fallback
|
||||||
|
const mockApiCall = vi
|
||||||
|
.fn()
|
||||||
|
.mockRejectedValueOnce(createSimulated429Error())
|
||||||
|
.mockRejectedValueOnce(createSimulated429Error())
|
||||||
|
.mockRejectedValueOnce(createSimulated429Error())
|
||||||
|
.mockResolvedValueOnce('success after fallback');
|
||||||
|
|
||||||
|
// Mock fallback handler
|
||||||
|
const mockFallbackHandler = vi.fn(async (_authType?: string) => {
|
||||||
|
fallbackCalled = true;
|
||||||
|
fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||||
|
return fallbackModel;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Test with OAuth personal auth type, with maxAttempts = 3 to ensure fallback triggers
|
||||||
|
const result = await retryWithBackoff(mockApiCall, {
|
||||||
|
maxAttempts: 3,
|
||||||
|
initialDelayMs: 1,
|
||||||
|
maxDelayMs: 10,
|
||||||
|
shouldRetry: (error: Error) => {
|
||||||
|
const status = (error as Error & { status?: number }).status;
|
||||||
|
return status === 429;
|
||||||
|
},
|
||||||
|
onPersistent429: mockFallbackHandler,
|
||||||
|
authType: AuthType.LOGIN_WITH_GOOGLE_PERSONAL,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify fallback was triggered
|
||||||
|
expect(fallbackCalled).toBe(true);
|
||||||
|
expect(fallbackModel).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||||
|
expect(mockFallbackHandler).toHaveBeenCalledWith(
|
||||||
|
AuthType.LOGIN_WITH_GOOGLE_PERSONAL,
|
||||||
|
);
|
||||||
|
expect(result).toBe('success after fallback');
|
||||||
|
// Should have: 3 failures, then fallback triggered, then 1 success after retry reset
|
||||||
|
expect(mockApiCall).toHaveBeenCalledTimes(4);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not trigger fallback for API key users', async () => {
  // Every call is rejected with a simulated 429, so the retry budget is
  // guaranteed to run out.
  const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error());

  // Fallback handler that must stay untouched for API key authentication.
  const mockFallbackHandler = vi.fn(async () => DEFAULT_GEMINI_FLASH_MODEL);

  // Assert the rejection explicitly: wrapping the call in try/catch would let
  // the test pass silently if retryWithBackoff ever resolved instead of
  // throwing.
  await expect(
    retryWithBackoff(mockApiCall, {
      maxAttempts: 5,
      initialDelayMs: 10,
      maxDelayMs: 100,
      shouldRetry: (error: Error) => {
        const status = (error as Error & { status?: number }).status;
        return status === 429;
      },
      onPersistent429: mockFallbackHandler,
      authType: AuthType.USE_GEMINI, // API key auth type
    }),
  ).rejects.toThrow('Rate limit exceeded');

  // Verify fallback was NOT triggered for API key users
  expect(mockFallbackHandler).not.toHaveBeenCalled();
  // All five attempts are spent on the original call before giving up.
  expect(mockApiCall).toHaveBeenCalledTimes(5);
});
|
||||||
|
|
||||||
|
it('should properly disable simulation state after fallback', () => {
  // Switch the simulation on and confirm a request would be rejected.
  setSimulate429(true);
  const simulatedBeforeFallback = shouldSimulate429();
  expect(simulatedBeforeFallback).toBe(true);

  // Once a fallback has been recorded, no further 429s may be simulated.
  disableSimulationAfterFallback();
  const simulatedAfterFallback = shouldSimulate429();
  expect(simulatedAfterFallback).toBe(false);
});
|
||||||
|
});
|
|
@ -7,6 +7,7 @@
|
||||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||||
import { retryWithBackoff } from './retry.js';
|
import { retryWithBackoff } from './retry.js';
|
||||||
|
import { setSimulate429 } from './testUtils.js';
|
||||||
|
|
||||||
// Define an interface for the error with a status property
|
// Define an interface for the error with a status property
|
||||||
interface HttpError extends Error {
|
interface HttpError extends Error {
|
||||||
|
@ -42,10 +43,15 @@ class NonRetryableError extends Error {
|
||||||
describe('retryWithBackoff', () => {
|
describe('retryWithBackoff', () => {
|
||||||
beforeEach(() => {
  vi.useFakeTimers();
  // Disable 429 simulation for tests
  setSimulate429(false);
  // Silence console.warn output for tests that expect errors. A spy is used
  // instead of assigning to console.warn so that vi.restoreAllMocks() in
  // afterEach puts the real implementation back; a plain assignment would
  // leak the stub past the end of each test.
  vi.spyOn(console, 'warn').mockImplementation(() => {});
});

afterEach(() => {
  // Undo every spy installed during the test (including the console.warn spy).
  vi.restoreAllMocks();
  // Return to real timers so fake timers never leak outside a test.
  vi.useRealTimers();
});
|
||||||
|
|
||||||
it('should return the result on the first attempt if successful', async () => {
|
it('should return the result on the first attempt if successful', async () => {
|
||||||
|
@ -231,4 +237,197 @@ describe('retryWithBackoff', () => {
|
||||||
expect(d).toBeLessThanOrEqual(100 * 1.3);
|
expect(d).toBeLessThanOrEqual(100 * 1.3);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('Flash model fallback for OAuth users', () => {
  /** Builds an Error tagged with an HTTP status code. */
  const createHttpError = (message: string, status: number): HttpError => {
    const error: HttpError = new Error(message);
    error.status = status;
    return error;
  };

  /** Builds the 429 error used by every test in this suite. */
  const createRateLimitError = (): HttpError =>
    createHttpError('Rate limit exceeded', 429);

  it('should trigger fallback for OAuth personal users after persistent 429 errors', async () => {
    const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');

    // The call keeps failing with 429 until the fallback hook has run.
    let fallbackOccurred = false;
    const mockFn = vi.fn().mockImplementation(async () => {
      if (!fallbackOccurred) {
        throw createRateLimitError();
      }
      return 'success';
    });

    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 3,
      initialDelayMs: 100,
      onPersistent429: async (authType?: string) => {
        fallbackOccurred = true;
        return await fallbackCallback(authType);
      },
      authType: 'oauth-personal',
    });

    // Advance all timers to complete retries
    await vi.runAllTimersAsync();

    // Should succeed after fallback
    await expect(promise).resolves.toBe('success');

    // Verify callback was called with correct auth type
    expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');

    // Should retry again after fallback
    expect(mockFn).toHaveBeenCalledTimes(4); // 3 initial attempts + 1 after fallback
  });

  it('should trigger fallback for OAuth enterprise users after persistent 429 errors', async () => {
    const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');

    let fallbackOccurred = false;
    const mockFn = vi.fn().mockImplementation(async () => {
      if (!fallbackOccurred) {
        throw createRateLimitError();
      }
      return 'success';
    });

    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 3,
      initialDelayMs: 100,
      onPersistent429: async (authType?: string) => {
        fallbackOccurred = true;
        return await fallbackCallback(authType);
      },
      authType: 'oauth-enterprise',
    });

    await vi.runAllTimersAsync();

    await expect(promise).resolves.toBe('success');
    expect(fallbackCallback).toHaveBeenCalledWith('oauth-enterprise');
  });

  it('should NOT trigger fallback for API key users', async () => {
    const fallbackCallback = vi.fn();

    const mockFn = vi.fn(async () => {
      throw createRateLimitError();
    });

    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 3,
      initialDelayMs: 100,
      onPersistent429: fallbackCallback,
      authType: 'gemini-api-key',
    });

    // Handle the promise properly to avoid unhandled rejections
    const resultPromise = promise.catch((error) => error);
    await vi.runAllTimersAsync();
    const result = await resultPromise;

    // Should fail after all retries without fallback
    expect(result).toBeInstanceOf(Error);
    expect(result.message).toBe('Rate limit exceeded');

    // Callback should not be called for API key users
    expect(fallbackCallback).not.toHaveBeenCalled();
  });

  it('should reset attempt counter and continue after successful fallback', async () => {
    let fallbackCalled = false;
    const fallbackCallback = vi.fn().mockImplementation(async () => {
      fallbackCalled = true;
      return 'gemini-2.5-flash';
    });

    const mockFn = vi.fn().mockImplementation(async () => {
      if (!fallbackCalled) {
        throw createRateLimitError();
      }
      return 'success';
    });

    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 3,
      initialDelayMs: 100,
      onPersistent429: fallbackCallback,
      authType: 'oauth-personal',
    });

    await vi.runAllTimersAsync();

    await expect(promise).resolves.toBe('success');
    expect(fallbackCallback).toHaveBeenCalledOnce();

    // The attempt budget was exhausted (3 calls) and a further call was only
    // possible because the counter was reset after the fallback.
    expect(mockFn).toHaveBeenCalledTimes(4);
  });

  it('should continue with original error if fallback is rejected', async () => {
    const fallbackCallback = vi.fn().mockResolvedValue(null); // User rejected fallback

    const mockFn = vi.fn(async () => {
      throw createRateLimitError();
    });

    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 3,
      initialDelayMs: 100,
      onPersistent429: fallbackCallback,
      authType: 'oauth-personal',
    });

    // Handle the promise properly to avoid unhandled rejections
    const resultPromise = promise.catch((error) => error);
    await vi.runAllTimersAsync();
    const result = await resultPromise;

    // Should fail with original error when fallback is rejected
    expect(result).toBeInstanceOf(Error);
    expect(result.message).toBe('Rate limit exceeded');
    expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
  });

  it('should handle mixed error types (only count consecutive 429s)', async () => {
    const fallbackCallback = vi.fn().mockResolvedValue('gemini-2.5-flash');
    let attempts = 0;
    let fallbackOccurred = false;

    const mockFn = vi.fn().mockImplementation(async () => {
      attempts++;
      if (fallbackOccurred) {
        return 'success';
      }
      if (attempts === 1) {
        // First attempt: 500 error (resets consecutive count)
        throw createHttpError('Server error', 500);
      }
      // Remaining attempts: 429 errors
      throw createRateLimitError();
    });

    const promise = retryWithBackoff(mockFn, {
      maxAttempts: 5,
      initialDelayMs: 100,
      onPersistent429: async (authType?: string) => {
        fallbackOccurred = true;
        return await fallbackCallback(authType);
      },
      authType: 'oauth-personal',
    });

    await vi.runAllTimersAsync();

    await expect(promise).resolves.toBe('success');

    // Should trigger fallback after 4 consecutive 429s (attempts 2-5)
    expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
  });
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -4,11 +4,15 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import { AuthType } from '../core/contentGenerator.js';
|
||||||
|
|
||||||
/** Options accepted by retryWithBackoff (passed as Partial<RetryOptions>). */
export interface RetryOptions {
|
export interface RetryOptions {
|
||||||
maxAttempts: number;
|
maxAttempts: number;
|
||||||
initialDelayMs: number;
|
initialDelayMs: number;
|
||||||
maxDelayMs: number;
|
maxDelayMs: number;
|
||||||
shouldRetry: (error: Error) => boolean;
|
shouldRetry: (error: Error) => boolean;
|
||||||
|
// Optional hook awaited by retryWithBackoff when it is about to give up after
// at least three consecutive 429 errors and authType is one of the Google
// login types (personal or enterprise). Resolving to a truthy value resets the
// attempt counter and delay and retries; a falsy value, or a rejection (which
// is logged with console.warn), lets the original error propagate.
onPersistent429?: (authType?: string) => Promise<string | null>;
|
||||||
|
// Auth type of the current session; forwarded to onPersistent429 and used to
// decide whether the fallback hook may be invoked at all.
authType?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
|
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
|
||||||
|
@ -59,29 +63,69 @@ export async function retryWithBackoff<T>(
|
||||||
fn: () => Promise<T>,
|
fn: () => Promise<T>,
|
||||||
options?: Partial<RetryOptions>,
|
options?: Partial<RetryOptions>,
|
||||||
): Promise<T> {
|
): Promise<T> {
|
||||||
const { maxAttempts, initialDelayMs, maxDelayMs, shouldRetry } = {
|
const {
|
||||||
|
maxAttempts,
|
||||||
|
initialDelayMs,
|
||||||
|
maxDelayMs,
|
||||||
|
shouldRetry,
|
||||||
|
onPersistent429,
|
||||||
|
authType,
|
||||||
|
} = {
|
||||||
...DEFAULT_RETRY_OPTIONS,
|
...DEFAULT_RETRY_OPTIONS,
|
||||||
...options,
|
...options,
|
||||||
};
|
};
|
||||||
|
|
||||||
let attempt = 0;
|
let attempt = 0;
|
||||||
let currentDelay = initialDelayMs;
|
let currentDelay = initialDelayMs;
|
||||||
|
let consecutive429Count = 0;
|
||||||
|
|
||||||
while (attempt < maxAttempts) {
|
while (attempt < maxAttempts) {
|
||||||
attempt++;
|
attempt++;
|
||||||
try {
|
try {
|
||||||
return await fn();
|
return await fn();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
const errorStatus = getErrorStatus(error);
|
||||||
|
|
||||||
|
// Track consecutive 429 errors
|
||||||
|
if (errorStatus === 429) {
|
||||||
|
consecutive429Count++;
|
||||||
|
} else {
|
||||||
|
consecutive429Count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we've exhausted retries or shouldn't retry
|
||||||
if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
|
if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
|
||||||
|
// If we have persistent 429s and a fallback callback for OAuth
|
||||||
|
if (
|
||||||
|
consecutive429Count >= 3 &&
|
||||||
|
onPersistent429 &&
|
||||||
|
(authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL ||
|
||||||
|
authType === AuthType.LOGIN_WITH_GOOGLE_ENTERPRISE)
|
||||||
|
) {
|
||||||
|
try {
|
||||||
|
const fallbackModel = await onPersistent429(authType);
|
||||||
|
if (fallbackModel) {
|
||||||
|
// Reset attempt counter and try with new model
|
||||||
|
attempt = 0;
|
||||||
|
consecutive429Count = 0;
|
||||||
|
currentDelay = initialDelayMs;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} catch (fallbackError) {
|
||||||
|
// If fallback fails, continue with original error
|
||||||
|
console.warn('Fallback to Flash model failed:', fallbackError);
|
||||||
|
}
|
||||||
|
}
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { delayDurationMs, errorStatus } = getDelayDurationAndStatus(error);
|
const { delayDurationMs, errorStatus: delayErrorStatus } =
|
||||||
|
getDelayDurationAndStatus(error);
|
||||||
|
|
||||||
if (delayDurationMs > 0) {
|
if (delayDurationMs > 0) {
|
||||||
// Respect Retry-After header if present and parsed
|
// Respect Retry-After header if present and parsed
|
||||||
console.warn(
|
console.warn(
|
||||||
`Attempt ${attempt} failed with status ${errorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
|
`Attempt ${attempt} failed with status ${delayErrorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`,
|
||||||
error,
|
error,
|
||||||
);
|
);
|
||||||
await delay(delayDurationMs);
|
await delay(delayDurationMs);
|
||||||
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Testing utilities for simulating 429 errors in unit tests
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Number of requests counted by shouldSimulate429 since the last reset.
let requestCounter = 0;
// Master switch for the simulation.
let simulate429Enabled = false;
// When greater than zero, this many counted requests pass before 429s start.
let simulate429AfterRequests = 0;
// When set, only requests made with this auth type are simulated.
let simulate429ForAuthType: string | undefined;
// Set once a fallback has happened; suppresses any further simulated 429s.
let fallbackOccurred = false;

/**
 * Decide whether the current request should fail with a simulated 429.
 *
 * Requests that pass the enabled/fallback and auth type checks are counted;
 * requests rejected by those checks leave the counter untouched.
 *
 * @param authType Auth type of the request, compared against the configured
 *   auth type filter (if any).
 * @returns `true` when a 429 should be simulated for this request.
 */
export function shouldSimulate429(authType?: string): boolean {
  const simulationActive = simulate429Enabled && !fallbackOccurred;
  if (!simulationActive) {
    return false;
  }

  // A configured auth type filter excludes every other auth type.
  if (simulate429ForAuthType) {
    if (authType !== simulate429ForAuthType) {
      return false;
    }
  }

  requestCounter += 1;

  // With a threshold, only requests beyond it fail; without one, all do.
  return simulate429AfterRequests > 0
    ? requestCounter > simulate429AfterRequests
    : true;
}

/**
 * Set the request counter back to zero (useful for tests).
 */
export function resetRequestCounter(): void {
  requestCounter = 0;
}

/**
 * Record that a fallback happened so no further 429s are simulated.
 */
export function disableSimulationAfterFallback(): void {
  fallbackOccurred = true;
}

/**
 * Build the error thrown for a simulated rate limit.
 *
 * @returns An `Error` with message "Rate limit exceeded (simulated)" and a
 *   numeric `status` property of 429.
 */
export function createSimulated429Error(): Error {
  return Object.assign(new Error('Rate limit exceeded (simulated)'), {
    status: 429,
  });
}

/**
 * Clear the fallback flag and the request counter, e.g. when switching auth
 * methods. The enabled flag, threshold and auth type filter are kept.
 */
export function resetSimulationState(): void {
  fallbackOccurred = false;
  requestCounter = 0;
}

/**
 * Enable or disable 429 simulation programmatically (for tests).
 *
 * Every call also clears the fallback flag and the request counter.
 *
 * @param enabled Whether simulated 429s are active.
 * @param afterRequests When greater than zero, the number of counted requests
 *   allowed through before 429s are simulated; zero simulates on every request.
 * @param forAuthType When provided, restricts the simulation to this auth type.
 */
export function setSimulate429(
  enabled: boolean,
  afterRequests = 0,
  forAuthType?: string,
): void {
  simulate429ForAuthType = forAuthType;
  simulate429AfterRequests = afterRequests;
  simulate429Enabled = enabled;
  // Reset fallback state and counter whenever the simulation is reconfigured.
  resetSimulationState();
}
|
|
@ -0,0 +1,10 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { setSimulate429 } from './src/utils/testUtils.js';
|
||||||
|
|
||||||
|
// Disable 429 simulation globally for all tests: loading this module switches
// the simulation off (which also clears its request counter and fallback
// flag), so tests that need simulated rate limiting must enable it explicitly.
|
||||||
|
setSimulate429(false);
|
|
@ -10,6 +10,7 @@ export default defineConfig({
|
||||||
test: {
|
test: {
|
||||||
reporters: ['default', 'junit'],
|
reporters: ['default', 'junit'],
|
||||||
silent: true,
|
silent: true,
|
||||||
|
setupFiles: ['./test-setup.ts'],
|
||||||
outputFile: {
|
outputFile: {
|
||||||
junit: 'junit.xml',
|
junit: 'junit.xml',
|
||||||
},
|
},
|
||||||
|
|
Loading…
Reference in New Issue