From ab63a5f183ca6f787971219190db326043f6a502 Mon Sep 17 00:00:00 2001
From: SunskyXH
Date: Fri, 4 Jul 2025 04:43:48 +0900
Subject: [PATCH] fix(client): get model from config in flashFallbackHandler (#2118)

Co-authored-by: Jacob Richman
---
 packages/core/src/core/client.test.ts | 110 ++++++++++++++++++++++++++
 packages/core/src/core/client.ts      |  22 +++---
 2 files changed, 122 insertions(+), 10 deletions(-)

diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index 0adbf986..dc3b8455 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -687,4 +687,114 @@ describe('Gemini Client (client.ts)', () => {
       );
     });
   });
+
+  describe('generateContent', () => {
+    it('should use current model from config for content generation', async () => {
+      const initialModel = client['config'].getModel();
+      const contents = [{ role: 'user', parts: [{ text: 'test' }] }];
+      const currentModel = initialModel + '-changed';
+
+      vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
+
+      const mockGenerator: Partial<ContentGenerator> = {
+        countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
+        generateContent: mockGenerateContentFn,
+      };
+      client['contentGenerator'] = mockGenerator as ContentGenerator;
+
+      await client.generateContent(contents, {}, new AbortController().signal);
+
+      expect(mockGenerateContentFn).not.toHaveBeenCalledWith({
+        model: initialModel,
+        config: expect.any(Object),
+        contents,
+      });
+      expect(mockGenerateContentFn).toHaveBeenCalledWith({
+        model: currentModel,
+        config: expect.any(Object),
+        contents,
+      });
+    });
+  });
+
+  describe('tryCompressChat', () => {
+    it('should use current model from config for token counting after sendMessage', async () => {
+      const initialModel = client['config'].getModel();
+
+      const mockCountTokens = vi
+        .fn()
+        .mockResolvedValueOnce({ totalTokens: 100000 })
+        .mockResolvedValueOnce({ totalTokens: 5000 });
+
+      const mockSendMessage = vi.fn().mockResolvedValue({ text: 'Summary' });
+
+      const mockChatHistory = [
+        { role: 'user', parts: [{ text: 'Long conversation' }] },
+        { role: 'model', parts: [{ text: 'Long response' }] },
+      ];
+
+      const mockChat: Partial<GeminiChat> = {
+        getHistory: vi.fn().mockReturnValue(mockChatHistory),
+        sendMessage: mockSendMessage,
+      };
+
+      const mockGenerator: Partial<ContentGenerator> = {
+        countTokens: mockCountTokens,
+      };
+
+      // simulate the model being changed between the two `countTokens` calls
+      const firstCurrentModel = initialModel + '-changed-1';
+      const secondCurrentModel = initialModel + '-changed-2';
+      vi.spyOn(client['config'], 'getModel')
+        .mockReturnValueOnce(firstCurrentModel)
+        .mockReturnValueOnce(secondCurrentModel);
+
+      client['chat'] = mockChat as GeminiChat;
+      client['contentGenerator'] = mockGenerator as ContentGenerator;
+      client['startChat'] = vi.fn().mockResolvedValue(mockChat);
+
+      const result = await client.tryCompressChat(true);
+
+      expect(mockCountTokens).toHaveBeenCalledTimes(2);
+      expect(mockCountTokens).toHaveBeenNthCalledWith(1, {
+        model: firstCurrentModel,
+        contents: mockChatHistory,
+      });
+      expect(mockCountTokens).toHaveBeenNthCalledWith(2, {
+        model: secondCurrentModel,
+        contents: expect.any(Array),
+      });
+
+      expect(result).toEqual({
+        originalTokenCount: 100000,
+        newTokenCount: 5000,
+      });
+    });
+  });
+
+  describe('handleFlashFallback', () => {
+    it('should use current model from config when checking for fallback', async () => {
+      const initialModel = client['config'].getModel();
+      const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
+
+      // simulate the model in config being changed after client creation
+      const currentModel = initialModel + '-changed';
+      vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
+
+      const mockFallbackHandler = vi.fn().mockResolvedValue(true);
+      client['config'].flashFallbackHandler = mockFallbackHandler;
+      client['config'].setModel = vi.fn();
+
+      const result = await client['handleFlashFallback'](
+        AuthType.LOGIN_WITH_GOOGLE,
+      );
+
+      expect(result).toBe(fallbackModel);
+
+      expect(mockFallbackHandler).toHaveBeenCalledWith(
+        currentModel,
+        fallbackModel,
+      );
+    });
+  });
 });
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index b39b10a0..69ed0dff 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -48,7 +48,6 @@ function isThinkingSupported(model: string) {
 export class GeminiClient {
   private chat?: GeminiChat;
   private contentGenerator?: ContentGenerator;
-  private model: string;
   private embeddingModel: string;
   private generateContentConfig: GenerateContentConfig = {
     temperature: 0,
@@ -62,7 +61,6 @@
       setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
     }
 
-    this.model = config.getModel();
     this.embeddingModel = config.getEmbeddingModel();
   }
 
@@ -187,7 +185,9 @@
     try {
       const userMemory = this.config.getUserMemory();
       const systemInstruction = getCoreSystemPrompt(userMemory);
-      const generateContentConfigWithThinking = isThinkingSupported(this.model)
+      const generateContentConfigWithThinking = isThinkingSupported(
+        this.config.getModel(),
+      )
         ? {
             ...this.generateContentConfig,
             thinkingConfig: {
@@ -345,7 +345,7 @@ generationConfig: GenerateContentConfig,
     abortSignal: AbortSignal,
   ): Promise<GenerateContentResponse> {
-    const modelToUse = this.model;
+    const modelToUse = this.config.getModel();
     const configToUse: GenerateContentConfig = {
       ...this.generateContentConfig,
       ...generationConfig,
     };
@@ -439,13 +439,15 @@
       return null;
     }
 
+    const model = this.config.getModel();
+
     let { totalTokens: originalTokenCount } =
       await this.getContentGenerator().countTokens({
-        model: this.model,
+        model,
         contents: curatedHistory,
       });
 
     if (originalTokenCount === undefined) {
-      console.warn(`Could not determine token count for model ${this.model}.`);
+      console.warn(`Could not determine token count for model ${model}.`);
       originalTokenCount = 0;
     }
@@ -453,7 +455,7 @@
     if (
       !force &&
       originalTokenCount <
-        this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(this.model)
+        this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(model)
     ) {
       return null;
     }
@@ -479,7 +481,8 @@
 
     const { totalTokens: newTokenCount } =
       await this.getContentGenerator().countTokens({
-        model: this.model,
+        // model might change after calling `sendMessage`, so we get the newest value from config
+        model: this.config.getModel(),
         contents: this.getChat().getHistory(),
       });
     if (newTokenCount === undefined) {
@@ -503,7 +506,7 @@
       return null;
     }
 
-    const currentModel = this.model;
+    const currentModel = this.config.getModel();
    const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
 
    // Don't fallback if already using Flash model
@@ -518,7 +521,6 @@
       const accepted = await fallbackHandler(currentModel, fallbackModel);
       if (accepted) {
         this.config.setModel(fallbackModel);
-        this.model = fallbackModel;
         return fallbackModel;
       }
     } catch (error) {
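
Context for the change: the root cause is a stale cached field. GeminiClient copied config.getModel() into a private this.model at construction, so any later config.setModel() call (including the one made by the Flash fallback itself) was invisible to subsequent reads. Below is a minimal sketch of the before/after behavior; the classes are simplified stand-ins, not the real gemini-cli types.

// Simplified, hypothetical stand-ins for Config and GeminiClient.
class Config {
  private model = 'gemini-2.5-pro';
  getModel(): string { return this.model; }
  setModel(model: string): void { this.model = model; }
}

// Before this patch: the model was snapshotted once in the constructor,
// so a later config.setModel() was never observed by the client.
class StaleClient {
  private model: string;
  constructor(private config: Config) {
    this.model = config.getModel(); // cached copy, never refreshed
  }
  modelForRequest(): string { return this.model; }
}

// After this patch: every read goes through config, so a model switch takes effect.
class FreshClient {
  constructor(private config: Config) {}
  modelForRequest(): string { return this.config.getModel(); }
}

const config = new Config();
const stale = new StaleClient(config);
const fresh = new FreshClient(config);
config.setModel('gemini-2.5-flash'); // e.g. an accepted Flash fallback mid-session
console.log(stale.modelForRequest()); // 'gemini-2.5-pro'   -- the bug
console.log(fresh.modelForRequest()); // 'gemini-2.5-flash' -- the fix

The tests in this patch rely on the same property: config.getModel() is stubbed with mockReturnValueOnce so each call can return a different value, which only passes when the client re-reads the config at call time instead of using a constructor-time snapshot.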