diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts
index bfaeb8f6..67fa676d 100644
--- a/packages/core/src/core/geminiChat.test.ts
+++ b/packages/core/src/core/geminiChat.test.ts
@@ -14,6 +14,7 @@ import {
 } from '@google/genai';
 import { GeminiChat } from './geminiChat.js';
 import { Config } from '../config/config.js';
+import { AuthType } from '../core/contentGenerator.js';
 import { setSimulate429 } from '../utils/testUtils.js';
 
 // Mocks
@@ -38,11 +39,14 @@ describe('GeminiChat', () => {
     getUsageStatisticsEnabled: () => true,
     getDebugMode: () => false,
     getContentGeneratorConfig: () => ({
-      authType: 'oauth-personal',
+      authType: AuthType.USE_GEMINI,
       model: 'test-model',
     }),
     getModel: vi.fn().mockReturnValue('gemini-pro'),
     setModel: vi.fn(),
+    getGeminiClient: vi.fn().mockReturnValue({
+      generateJson: vi.fn().mockResolvedValue({ model: 'pro' }),
+    }),
     flashFallbackHandler: undefined,
   } as unknown as Config;
 
@@ -110,7 +114,7 @@ describe('GeminiChat', () => {
     await chat.sendMessageStream({ message: 'hello' });
 
     expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({
-      model: 'gemini-pro',
+      model: 'gemini-2.5-pro',
       contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
       config: {},
     });
diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts
index 19b87805..770f8bb6 100644
--- a/packages/core/src/core/geminiChat.ts
+++ b/packages/core/src/core/geminiChat.ts
@@ -34,7 +34,10 @@ import {
   ApiRequestEvent,
   ApiResponseEvent,
 } from '../telemetry/types.js';
-import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
+import {
+  DEFAULT_GEMINI_FLASH_MODEL,
+  DEFAULT_GEMINI_MODEL,
+} from '../config/models.js';
 
 /**
  * Returns true if the response is valid, false otherwise.
@@ -346,14 +349,20 @@ export class GeminiChat {
     await this.sendPromise;
     const userContent = createUserContent(params.message);
     const requestContents = this.getHistory(true).concat(userContent);
-    this._logApiRequest(requestContents, this.config.getModel());
+
+    const model = await this._selectModel(
+      requestContents,
+      params.config?.abortSignal ?? new AbortController().signal,
+    );
+
+    this._logApiRequest(requestContents, model);
     const startTime = Date.now();
     try {
       const apiCall = () =>
         this.contentGenerator.generateContentStream({
-          model: this.config.getModel(),
+          model,
           contents: requestContents,
           config: { ...this.generationConfig, ...params.config },
         });
@@ -397,6 +406,82 @@ export class GeminiChat {
     }
   }
 
+  /**
+   * Selects the model to use for the request.
+   *
+   * This routing heuristic is an initial implementation and is expected to
+   * evolve.
+   */
+  private async _selectModel(
+    history: Content[],
+    signal: AbortSignal,
+  ): Promise<string> {
+    const currentModel = this.config.getModel();
+    if (currentModel === DEFAULT_GEMINI_FLASH_MODEL) {
+      return DEFAULT_GEMINI_FLASH_MODEL;
+    }
+
+    if (
+      history.length < 5 &&
+      this.config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI
+    ) {
+      // There is currently a bug with Gemini API key usage: if Flash serves
+      // one of the first requests in a sequence, it returns an empty token.
+      return DEFAULT_GEMINI_MODEL;
+    }
+
+    const flashIndicator = 'flash';
+    const proIndicator = 'pro';
+    const modelChoicePrompt = `You are a super-intelligent router that decides which model to use for a given request. You have two models to choose from: "${flashIndicator}" and "${proIndicator}". "${flashIndicator}" is a smaller and faster model that is good for simple or well-defined requests.
+"${proIndicator}" is a larger and slower model that is good for complex or open-ended requests.
+
+Based on the user request, which model should be used? Respond with a JSON object that contains a single field, \`model\`, whose value is the name of the model to be used.
+
+For example, if you think "${flashIndicator}" should be used, respond with: { "model": "${flashIndicator}" }`;
+    const modelChoiceContent: Content[] = [
+      {
+        role: 'user',
+        parts: [{ text: modelChoicePrompt }],
+      },
+    ];
+
+    const client = this.config.getGeminiClient();
+    try {
+      const choice = await client.generateJson(
+        [...history, ...modelChoiceContent],
+        {
+          type: 'object',
+          properties: {
+            model: {
+              type: 'string',
+              enum: [flashIndicator, proIndicator],
+            },
+          },
+          required: ['model'],
+        },
+        signal,
+        DEFAULT_GEMINI_FLASH_MODEL,
+        {
+          temperature: 0,
+          maxOutputTokens: 25,
+          thinkingConfig: {
+            thinkingBudget: 0,
+          },
+        },
+      );
+
+      switch (choice.model) {
+        case flashIndicator:
+          return DEFAULT_GEMINI_FLASH_MODEL;
+        case proIndicator:
+          return DEFAULT_GEMINI_MODEL;
+        default:
+          return currentModel;
+      }
+    } catch (_e) {
+      // If model selection fails, fall back to the default Flash model.
+      return DEFAULT_GEMINI_FLASH_MODEL;
+    }
+  }
+
   /**
    * Returns the chat history.
    *