diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 6170f319..df655b59 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -196,7 +196,6 @@ export class GeminiClient { return new GeminiChat( this.config, this.getContentGenerator(), - this.model, { systemInstruction, ...generateContentConfigWithThinking, diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 18b3729e..bfaeb8f6 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -25,34 +25,36 @@ const mockModelsModule = { batchEmbedContents: vi.fn(), } as unknown as Models; -const mockConfig = { - getSessionId: () => 'test-session-id', - getTelemetryLogPromptsEnabled: () => true, - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getContentGeneratorConfig: () => ({ - authType: 'oauth-personal', - model: 'test-model', - }), - setModel: vi.fn(), - flashFallbackHandler: undefined, -} as unknown as Config; - describe('GeminiChat', () => { let chat: GeminiChat; - const model = 'gemini-pro'; + let mockConfig: Config; const config: GenerateContentConfig = {}; beforeEach(() => { vi.clearAllMocks(); + mockConfig = { + getSessionId: () => 'test-session-id', + getTelemetryLogPromptsEnabled: () => true, + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getContentGeneratorConfig: () => ({ + authType: 'oauth-personal', + model: 'test-model', + }), + getModel: vi.fn().mockReturnValue('gemini-pro'), + setModel: vi.fn(), + flashFallbackHandler: undefined, + } as unknown as Config; + // Disable 429 simulation for tests setSimulate429(false); // Reset history for each test by creating a new instance - chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []); + chat = new GeminiChat(mockConfig, mockModelsModule, config, []); }); afterEach(() => { vi.restoreAllMocks(); + vi.resetAllMocks(); }); describe('sendMessage', () => { @@ -203,7 +205,7 @@ describe('GeminiChat', () => { chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed // Reset and set up a more realistic scenario for merging with existing history - chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []); + chat = new GeminiChat(mockConfig, mockModelsModule, config, []); const firstUserInput: Content = { role: 'user', parts: [{ text: 'First user input' }], @@ -246,7 +248,7 @@ describe('GeminiChat', () => { role: 'model', parts: [{ text: 'Initial model answer.' }], }; - chat = new GeminiChat(mockConfig, mockModelsModule, model, config, [ + chat = new GeminiChat(mockConfig, mockModelsModule, config, [ initialUser, initialModel, ]); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index ce5accf4..ac4f4898 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -138,7 +138,6 @@ export class GeminiChat { constructor( private readonly config: Config, private readonly contentGenerator: ContentGenerator, - private readonly model: string, private readonly generationConfig: GenerateContentConfig = {}, private history: Content[] = [], ) { @@ -168,7 +167,12 @@ export class GeminiChat { ): Promise { logApiResponse( this.config, - new ApiResponseEvent(this.model, durationMs, usageMetadata, responseText), + new ApiResponseEvent( + this.config.getModel(), + durationMs, + usageMetadata, + responseText, + ), ); } @@ -178,7 +182,12 @@ export class GeminiChat { logApiError( this.config, - new ApiErrorEvent(this.model, errorMessage, durationMs, errorType), + new ApiErrorEvent( + this.config.getModel(), + errorMessage, + durationMs, + errorType, + ), ); } @@ -192,7 +201,7 @@ export class GeminiChat { return null; } - const currentModel = this.model; + const currentModel = this.config.getModel(); const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; // Don't fallback if already using Flash model @@ -244,7 +253,7 @@ export class GeminiChat { const userContent = createUserContent(params.message); const requestContents = this.getHistory(true).concat(userContent); - this._logApiRequest(requestContents, this.model); + this._logApiRequest(requestContents, this.config.getModel()); const startTime = Date.now(); let response: GenerateContentResponse; @@ -252,12 +261,23 @@ export class GeminiChat { try { const apiCall = () => this.contentGenerator.generateContent({ - model: this.model, + model: this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL, contents: requestContents, config: { ...this.generationConfig, ...params.config }, }); - response = await retryWithBackoff(apiCall); + response = await retryWithBackoff(apiCall, { + shouldRetry: (error: Error) => { + if (error && error.message) { + if (error.message.includes('429')) return true; + if (error.message.match(/5\d{2}/)) return true; + } + return false; + }, + onPersistent429: async (authType?: string) => + await this.handleFlashFallback(authType), + authType: this.config.getContentGeneratorConfig()?.authType, + }); const durationMs = Date.now() - startTime; await this._logApiResponse( durationMs, @@ -326,14 +346,14 @@ export class GeminiChat { await this.sendPromise; const userContent = createUserContent(params.message); const requestContents = this.getHistory(true).concat(userContent); - this._logApiRequest(requestContents, this.model); + this._logApiRequest(requestContents, this.config.getModel()); const startTime = Date.now(); try { const apiCall = () => this.contentGenerator.generateContentStream({ - model: this.model, + model: this.config.getModel(), contents: requestContents, config: { ...this.generationConfig, ...params.config }, }); diff --git a/packages/core/src/utils/nextSpeakerChecker.test.ts b/packages/core/src/utils/nextSpeakerChecker.test.ts index 83ce97fd..475b5662 100644 --- a/packages/core/src/utils/nextSpeakerChecker.test.ts +++ b/packages/core/src/utils/nextSpeakerChecker.test.ts @@ -71,7 +71,6 @@ describe('checkNextSpeaker', () => { chatInstance = new GeminiChat( mockConfigInstance, mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel - 'gemini-pro', // model name {}, [], // initial history ); diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts index 1988c02a..a0294c31 100644 --- a/packages/core/src/utils/retry.test.ts +++ b/packages/core/src/utils/retry.test.ts @@ -272,7 +272,7 @@ describe('retryWithBackoff', () => { expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal'); // Should retry again after fallback - expect(mockFn).toHaveBeenCalledTimes(4); // 3 initial attempts + 1 after fallback + expect(mockFn).toHaveBeenCalledTimes(3); // 2 initial attempts + 1 after fallback }); it('should NOT trigger fallback for API key users', async () => { diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts index ebe18510..372a7976 100644 --- a/packages/core/src/utils/retry.ts +++ b/packages/core/src/utils/retry.ts @@ -67,9 +67,9 @@ export async function retryWithBackoff( maxAttempts, initialDelayMs, maxDelayMs, - shouldRetry, onPersistent429, authType, + shouldRetry, } = { ...DEFAULT_RETRY_OPTIONS, ...options, @@ -93,28 +93,30 @@ export async function retryWithBackoff( consecutive429Count = 0; } + // If we have persistent 429s and a fallback callback for OAuth + if ( + consecutive429Count >= 2 && + onPersistent429 && + authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL + ) { + try { + const fallbackModel = await onPersistent429(authType); + if (fallbackModel) { + // Reset attempt counter and try with new model + attempt = 0; + consecutive429Count = 0; + currentDelay = initialDelayMs; + // With the model updated, we continue to the next attempt + continue; + } + } catch (fallbackError) { + // If fallback fails, continue with original error + console.warn('Fallback to Flash model failed:', fallbackError); + } + } + // Check if we've exhausted retries or shouldn't retry if (attempt >= maxAttempts || !shouldRetry(error as Error)) { - // If we have persistent 429s and a fallback callback for OAuth - if ( - consecutive429Count >= 2 && - onPersistent429 && - authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL - ) { - try { - const fallbackModel = await onPersistent429(authType); - if (fallbackModel) { - // Reset attempt counter and try with new model - attempt = 0; - consecutive429Count = 0; - currentDelay = initialDelayMs; - continue; - } - } catch (fallbackError) { - // If fallback fails, continue with original error - console.warn('Fallback to Flash model failed:', fallbackError); - } - } throw error; }