From 7f20425c98d5adb5531e6c33ed92975b71b34c90 Mon Sep 17 00:00:00 2001 From: Allen Hutchison Date: Mon, 2 Jun 2025 13:55:54 -0700 Subject: [PATCH] feat(cli): add pro model availability check and fallback to flash (#608) --- packages/cli/src/config/config.test.ts | 16 +- packages/cli/src/config/config.ts | 46 ++++- packages/cli/src/config/settings.ts | 2 +- packages/cli/src/gemini.tsx | 18 +- packages/cli/src/ui/hooks/useGeminiStream.ts | 2 +- packages/cli/src/utils/modelCheck.test.ts | 179 +++++++++++++++++++ packages/cli/src/utils/modelCheck.ts | 75 ++++++++ packages/cli/tsconfig.json | 2 +- packages/core/src/index.ts | 1 + 9 files changed, 322 insertions(+), 19 deletions(-) create mode 100644 packages/cli/src/utils/modelCheck.test.ts create mode 100644 packages/cli/src/utils/modelCheck.ts diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index 9f288372..a39278bc 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -82,29 +82,29 @@ describe('loadCliConfig', () => { it('should set showMemoryUsage to true when --memory flag is present', async () => { process.argv = ['node', 'script.js', '--show_memory_usage']; const settings: Settings = {}; - const config = await loadCliConfig(settings); - expect(config.getShowMemoryUsage()).toBe(true); + const result = await loadCliConfig(settings); + expect(result.config.getShowMemoryUsage()).toBe(true); }); it('should set showMemoryUsage to false when --memory flag is not present', async () => { process.argv = ['node', 'script.js']; const settings: Settings = {}; - const config = await loadCliConfig(settings); - expect(config.getShowMemoryUsage()).toBe(false); + const result = await loadCliConfig(settings); + expect(result.config.getShowMemoryUsage()).toBe(false); }); it('should set showMemoryUsage to false by default from settings if CLI flag is not present', async () => { process.argv = ['node', 'script.js']; const settings: Settings = { 
showMemoryUsage: false }; - const config = await loadCliConfig(settings); - expect(config.getShowMemoryUsage()).toBe(false); + const result = await loadCliConfig(settings); + expect(result.config.getShowMemoryUsage()).toBe(false); }); it('should prioritize CLI flag over settings for showMemoryUsage (CLI true, settings false)', async () => { process.argv = ['node', 'script.js', '--show_memory_usage']; const settings: Settings = { showMemoryUsage: false }; - const config = await loadCliConfig(settings); - expect(config.getShowMemoryUsage()).toBe(true); + const result = await loadCliConfig(settings); + expect(result.config.getShowMemoryUsage()).toBe(true); }); }); diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index ee1c9d36..2429ad64 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -19,6 +19,10 @@ import { } from '@gemini-code/core'; import { Settings } from './settings.js'; import { readPackageUp } from 'read-package-up'; +import { + getEffectiveModel, + type EffectiveModelCheckResult, +} from '../utils/modelCheck.js'; // Simple console logger for now - replace with actual logger if available const logger = { @@ -30,7 +34,8 @@ const logger = { error: (...args: any[]) => console.error('[ERROR]', ...args), }; -const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-05-06'; +export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-05-06'; +export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash-preview-05-20'; interface CliArgs { model: string | undefined; @@ -114,7 +119,16 @@ export async function loadHierarchicalGeminiMemory( return loadServerHierarchicalMemory(currentWorkingDirectory, debugMode); } -export async function loadCliConfig(settings: Settings): Promise { +export interface LoadCliConfigResult { + config: Config; + modelWasSwitched: boolean; + originalModelBeforeSwitch?: string; + finalModel: string; +} + +export async function loadCliConfig( + settings: Settings, +): Promise { 
loadEnvironment(); const geminiApiKey = process.env.GEMINI_API_KEY; @@ -164,9 +178,27 @@ export async function loadCliConfig(settings: Settings): Promise { const apiKeyForServer = geminiApiKey || googleApiKey || ''; const useVertexAI = hasGeminiApiKey ? false : undefined; + let modelToUse = argv.model || DEFAULT_GEMINI_MODEL; + let modelSwitched = false; + let originalModel: string | undefined = undefined; + + if (apiKeyForServer) { + const checkResult: EffectiveModelCheckResult = await getEffectiveModel( + apiKeyForServer, + modelToUse, + ); + if (checkResult.switched) { + modelSwitched = true; + originalModel = checkResult.originalModelIfSwitched; + modelToUse = checkResult.effectiveModel; + } + } else { + // logger.debug('API key not available during config load. Skipping model availability check.'); + } + const configParams: ConfigParameters = { apiKey: apiKeyForServer, - model: argv.model || DEFAULT_GEMINI_MODEL, + model: modelToUse, sandbox: argv.sandbox ?? settings.sandbox ?? argv.yolo ?? 
false, targetDir: process.cwd(), debugMode, @@ -186,7 +218,13 @@ export async function loadCliConfig(settings: Settings): Promise { argv.show_memory_usage || settings.showMemoryUsage || false, }; - return createServerConfig(configParams); + const config = createServerConfig(configParams); + return { + config, + modelWasSwitched: modelSwitched, + originalModelBeforeSwitch: originalModel, + finalModel: modelToUse, + }; } async function createUserAgent(): Promise { diff --git a/packages/cli/src/config/settings.ts b/packages/cli/src/config/settings.ts index 5d51ba15..6c14a6dc 100644 --- a/packages/cli/src/config/settings.ts +++ b/packages/cli/src/config/settings.ts @@ -7,7 +7,7 @@ import * as fs from 'fs'; import * as path from 'path'; import { homedir } from 'os'; -import { MCPServerConfig } from '@gemini-code/core/src/config/config.js'; +import { MCPServerConfig } from '@gemini-code/core'; import stripJsonComments from 'strip-json-comments'; import { DefaultLight } from '../ui/themes/default-light.js'; import { DefaultDark } from '../ui/themes/default.js'; diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 07551813..f8cc77b6 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -50,11 +50,19 @@ async function main() { console.warn( 'GEMINI_CODE_SANDBOX_IMAGE is deprecated. Use GEMINI_SANDBOX_IMAGE_NAME instead.', ); - process.env.GEMINI_SANDBOX_IMAGE = process.env.GEMINI_CODE_SANDBOX_IMAGE; + process.env.GEMINI_SANDBOX_IMAGE_NAME = + process.env.GEMINI_CODE_SANDBOX_IMAGE; // Corrected to GEMINI_SANDBOX_IMAGE_NAME } const settings = loadSettings(process.cwd()); - const config = await loadCliConfig(settings.merged); + const { config, modelWasSwitched, originalModelBeforeSwitch, finalModel } = + await loadCliConfig(settings.merged); + + if (modelWasSwitched && originalModelBeforeSwitch) { + console.log( + `[INFO] Your configured model (${originalModelBeforeSwitch}) was temporarily unavailable. 
Switched to ${finalModel} for this session.`, + ); + } if (settings.merged.theme) { if (!themeManager.setActiveTheme(settings.merged.theme)) { @@ -128,8 +136,10 @@ async function main() { ...settings.merged, coreTools: nonInteractiveTools, }; - const nonInteractiveConfig = await loadCliConfig(nonInteractiveSettings); - await runNonInteractive(nonInteractiveConfig, input); + const nonInteractiveConfigResult = await loadCliConfig( + nonInteractiveSettings, + ); // Ensure config is reloaded with non-interactive tools + await runNonInteractive(nonInteractiveConfigResult.config, input); } // --- Global Unhandled Rejection Handler --- diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index b6ef1481..423f3489 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -17,6 +17,7 @@ import { Config, MessageSenderType, ToolCallRequestInfo, + GeminiChat, } from '@gemini-code/core'; import { type PartListUnion } from '@google/genai'; import { @@ -40,7 +41,6 @@ import { TrackedCompletedToolCall, TrackedCancelledToolCall, } from './useReactToolScheduler.js'; -import { GeminiChat } from '@gemini-code/core/src/core/geminiChat.js'; export function mergePartListUnions(list: PartListUnion[]): PartListUnion { const resultParts: PartListUnion = []; diff --git a/packages/cli/src/utils/modelCheck.test.ts b/packages/cli/src/utils/modelCheck.test.ts new file mode 100644 index 00000000..3b1cded8 --- /dev/null +++ b/packages/cli/src/utils/modelCheck.test.ts @@ -0,0 +1,179 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + getEffectiveModel, + type EffectiveModelCheckResult, +} from './modelCheck.js'; +import { + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, +} from '../config/config.js'; + +// Mock global fetch +global.fetch = vi.fn(); + +// Mock 
AbortController +const mockAbort = vi.fn(); +global.AbortController = vi.fn(() => ({ + signal: { aborted: false }, // Start with not aborted + abort: mockAbort, + // eslint-disable-next-line @typescript-eslint/no-explicit-any +})) as any; + +describe('getEffectiveModel', () => { + const apiKey = 'test-api-key'; + + beforeEach(() => { + vi.useFakeTimers(); + vi.clearAllMocks(); + // Reset signal for each test if AbortController mock is more complex + global.AbortController = vi.fn(() => ({ + signal: { aborted: false }, + abort: mockAbort, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + })) as any; + }); + + afterEach(() => { + vi.restoreAllMocks(); + vi.useRealTimers(); + }); + + describe('when currentConfiguredModel is not DEFAULT_GEMINI_MODEL', () => { + it('should return the currentConfiguredModel and switched: false without fetching', async () => { + const customModel = 'custom-model-name'; + const result = await getEffectiveModel(apiKey, customModel); + expect(result).toEqual({ + effectiveModel: customModel, + switched: false, + }); + expect(fetch).not.toHaveBeenCalled(); + }); + }); + + describe('when currentConfiguredModel is DEFAULT_GEMINI_MODEL', () => { + it('should switch to DEFAULT_GEMINI_FLASH_MODEL if fetch returns 429', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ + ok: false, + status: 429, + }); + const result: EffectiveModelCheckResult = await getEffectiveModel( + apiKey, + DEFAULT_GEMINI_MODEL, + ); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_FLASH_MODEL, + switched: true, + originalModelIfSwitched: DEFAULT_GEMINI_MODEL, + }); + expect(fetch).toHaveBeenCalledTimes(1); + expect(fetch).toHaveBeenCalledWith( + `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_MODEL}:generateContent?key=${apiKey}`, + expect.any(Object), + ); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch returns 200', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, 
+ }); + const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch returns a non-429 error status (e.g., 500)', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ + ok: false, + status: 500, + }); + const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch throws a network error', async () => { + (fetch as vi.Mock).mockRejectedValueOnce(new Error('Network error')); + const result = await getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should return DEFAULT_GEMINI_MODEL if fetch times out', async () => { + // Simulate AbortController's signal changing and fetch throwing AbortError + const abortControllerInstance = { + signal: { aborted: false }, // mutable signal + abort: vi.fn(() => { + abortControllerInstance.signal.aborted = true; // Use abortControllerInstance + }), + }; + (global.AbortController as vi.Mock).mockImplementationOnce( + () => abortControllerInstance, + ); + + (fetch as vi.Mock).mockImplementationOnce( + async (_endpoint: string, { signal }: { signal: AbortSignal }) => { + // fetch is called as fetch(endpoint, init), so the signal is read from the second argument; advance the timeout so abort gets called + vi.advanceTimersByTime(2000); + if (signal.aborted) { + throw new DOMException('Aborted', 'AbortError'); + } + // Should not reach here in a timeout scenario + return { ok: true, status: 200 }; + }, + ); + + const resultPromise = getEffectiveModel(apiKey, DEFAULT_GEMINI_MODEL); + // 
Ensure timers are advanced to trigger the timeout within getEffectiveModel + await 
vi.advanceTimersToNextTimerAsync(); // Or advanceTimersByTime(2000) if more precise control is needed + + const result = await resultPromise; + + expect(mockAbort).toHaveBeenCalledTimes(0); // this test installs its own controller, so the shared mockAbort must stay untouched + expect(abortControllerInstance.abort).toHaveBeenCalledTimes(1); + expect(result).toEqual({ + effectiveModel: DEFAULT_GEMINI_MODEL, + switched: false, + }); + expect(fetch).toHaveBeenCalledTimes(1); + }); + + it('should correctly pass API key and model in the fetch request', async () => { + (fetch as vi.Mock).mockResolvedValueOnce({ ok: true, status: 200 }); + const specificApiKey = 'specific-key-for-this-test'; + await getEffectiveModel(specificApiKey, DEFAULT_GEMINI_MODEL); + + expect(fetch).toHaveBeenCalledWith( + `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_MODEL}:generateContent?key=${specificApiKey}`, + expect.objectContaining({ + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + contents: [{ parts: [{ text: 'test' }] }], + generationConfig: { + maxOutputTokens: 1, + temperature: 0, + topK: 1, + thinkingConfig: { thinkingBudget: 0, includeThoughts: false }, + }, + }), + }), + ); + }); + }); +}); diff --git a/packages/cli/src/utils/modelCheck.ts b/packages/cli/src/utils/modelCheck.ts new file mode 100644 index 00000000..1634656e --- /dev/null +++ b/packages/cli/src/utils/modelCheck.ts @@ -0,0 +1,75 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, +} from '../config/config.js'; + +export interface EffectiveModelCheckResult { + effectiveModel: string; + switched: boolean; + originalModelIfSwitched?: string; +} + +/** + * Checks if the default "pro" model is rate-limited and returns a fallback "flash" + * model if necessary. This function is designed to be silent. 
+ * @param apiKey The API key to use for the check. 
+ * @param currentConfiguredModel The model currently configured in settings. + * @returns An object indicating the model to use, whether a switch occurred, + * and the original model if a switch happened. + */ +export async function getEffectiveModel( + apiKey: string, + currentConfiguredModel: string, +): Promise<EffectiveModelCheckResult> { + if (currentConfiguredModel !== DEFAULT_GEMINI_MODEL) { + // Only check if the user is trying to use the specific pro model we want to fallback from. + return { effectiveModel: currentConfiguredModel, switched: false }; + } + + const modelToTest = DEFAULT_GEMINI_MODEL; + const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; + const endpoint = `https://generativelanguage.googleapis.com/v1beta/models/${modelToTest}:generateContent?key=${apiKey}`; + const body = JSON.stringify({ + contents: [{ parts: [{ text: 'test' }] }], + generationConfig: { + maxOutputTokens: 1, + temperature: 0, + topK: 1, + thinkingConfig: { thinkingBudget: 0, includeThoughts: false }, + }, + }); + + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), 2000); // 2-second timeout for the request + + try { + const response = await fetch(endpoint, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body, + signal: controller.signal, + }); + + clearTimeout(timeoutId); + + if (response.status === 429) { + return { + effectiveModel: fallbackModel, + switched: true, + originalModelIfSwitched: modelToTest, + }; + } + // For any other case (success, other error codes), we stick to the original model. + return { effectiveModel: currentConfiguredModel, switched: false }; + } catch (_error) { + clearTimeout(timeoutId); + // On timeout or any other fetch error, stick to the original model. 
+ return { effectiveModel: currentConfiguredModel, switched: false }; + } +} diff --git a/packages/cli/tsconfig.json b/packages/cli/tsconfig.json index 1ffbb402..a186b89d 100644 --- a/packages/cli/tsconfig.json +++ b/packages/cli/tsconfig.json @@ -7,6 +7,6 @@ "types": ["node", "vitest/globals"] }, "include": ["index.ts", "src/**/*.ts", "src/**/*.tsx", "src/**/*.json"], - "exclude": ["node_modules", "dist"], + "exclude": ["node_modules", "dist", "src/**/*.test.ts"], "references": [{ "path": "../core" }] } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index bd28c864..ba95e490 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -9,6 +9,7 @@ export * from './config/config.js'; // Export Core Logic export * from './core/client.js'; +export * from './core/geminiChat.js'; export * from './core/logger.js'; export * from './core/prompts.js'; export * from './core/turn.js';