feat: Add model selection logic (#1678)

N. Taylor Mullen 2025-06-26 16:51:32 +02:00 committed by GitHub
parent 121bba3464
commit 24ccc9c457
2 changed files with 94 additions and 5 deletions

geminiChat.test.ts

@@ -14,6 +14,7 @@ import {
 } from '@google/genai';
 import { GeminiChat } from './geminiChat.js';
 import { Config } from '../config/config.js';
+import { AuthType } from '../core/contentGenerator.js';
 import { setSimulate429 } from '../utils/testUtils.js';

 // Mocks
@@ -38,11 +39,14 @@ describe('GeminiChat', () => {
     getUsageStatisticsEnabled: () => true,
     getDebugMode: () => false,
     getContentGeneratorConfig: () => ({
-      authType: 'oauth-personal',
+      authType: AuthType.USE_GEMINI,
       model: 'test-model',
     }),
     getModel: vi.fn().mockReturnValue('gemini-pro'),
     setModel: vi.fn(),
+    getGeminiClient: vi.fn().mockReturnValue({
+      generateJson: vi.fn().mockResolvedValue({ model: 'pro' }),
+    }),
     flashFallbackHandler: undefined,
   } as unknown as Config;
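Why the mock grew a `getGeminiClient` stub: the new `_selectModel` method in geminiChat.ts (second file below) routes via `client.generateJson`, and resolving it with `{ model: 'pro' }` drives the branch the updated assertion checks. A hedged sketch of that call's shape, inferred from the call site below; this is illustrative only, not the real `GeminiClient` declaration:

```ts
// Illustrative interface inferred from _selectModel's call site below;
// the actual GeminiClient type in the package may differ.
import type { Content } from '@google/genai';

interface ModelRouterClient {
  generateJson(
    contents: Content[],
    schema: Record<string, unknown>,
    signal: AbortSignal,
    model: string,
    config?: {
      temperature?: number;
      maxOutputTokens?: number;
      thinkingConfig?: { thinkingBudget?: number };
    },
  ): Promise<Record<string, unknown>>;
}
```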
@@ -110,7 +114,7 @@ describe('GeminiChat', () => {
     await chat.sendMessageStream({ message: 'hello' });
     expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({
-      model: 'gemini-pro',
+      model: 'gemini-2.5-pro',
       contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
       config: {},
     });
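With that stub in place, the assertion flips from the configured model to the routed one: `generateJson` resolves `{ model: 'pro' }`, `_selectModel` maps `'pro'` to `DEFAULT_GEMINI_MODEL`, and the stream request goes out with `gemini-2.5-pro` even though `getModel()` still returns `'gemini-pro'`. A minimal vitest sketch of that flow (the literal model names are assumptions read off the assertion above):

```ts
import { expect, it, vi } from 'vitest';

it('streams with the routed model, not the configured one', async () => {
  // Stand-in for the router call inside _selectModel (hypothetical wiring).
  const generateJson = vi.fn().mockResolvedValue({ model: 'pro' });

  const choice = await generateJson();
  const model = choice.model === 'pro' ? 'gemini-2.5-pro' : 'gemini-2.5-flash';

  expect(model).toBe('gemini-2.5-pro'); // mirrors the updated expectation
});
```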

geminiChat.ts

@@ -34,7 +34,10 @@ import {
   ApiRequestEvent,
   ApiResponseEvent,
 } from '../telemetry/types.js';
-import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
+import {
+  DEFAULT_GEMINI_FLASH_MODEL,
+  DEFAULT_GEMINI_MODEL,
+} from '../config/models.js';

 /**
  * Returns true if the response is valid, false otherwise.
@@ -346,14 +349,20 @@ export class GeminiChat {
     await this.sendPromise;
     const userContent = createUserContent(params.message);
     const requestContents = this.getHistory(true).concat(userContent);
-    this._logApiRequest(requestContents, this.config.getModel());
+
+    const model = await this._selectModel(
+      requestContents,
+      params.config?.abortSignal ?? new AbortController().signal,
+    );
+
+    this._logApiRequest(requestContents, model);

     const startTime = Date.now();
     try {
       const apiCall = () =>
         this.contentGenerator.generateContentStream({
-          model: this.config.getModel(),
+          model,
           contents: requestContents,
           config: { ...this.generationConfig, ...params.config },
         });
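A small pattern worth noting in this hunk: when the caller supplies no `abortSignal`, the code falls back to `new AbortController().signal`, a signal that nothing will ever abort, so `_selectModel` can always assume a real `AbortSignal`. As a standalone sketch:

```ts
// Fallback used above: prefer the caller's signal; otherwise hand the callee
// a fresh signal that will never fire.
function effectiveSignal(callerSignal?: AbortSignal): AbortSignal {
  return callerSignal ?? new AbortController().signal;
}

console.assert(effectiveSignal(undefined).aborted === false);
```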
@@ -397,6 +406,82 @@ export class GeminiChat {
     }
   }

+  /**
+   * Selects the model to use for the request.
+   *
+   * This is a placeholder for now.
+   */
+  private async _selectModel(
+    history: Content[],
+    signal: AbortSignal,
+  ): Promise<string> {
+    const currentModel = this.config.getModel();
+    if (currentModel === DEFAULT_GEMINI_FLASH_MODEL) {
+      return DEFAULT_GEMINI_FLASH_MODEL;
+    }
+
+    if (
+      history.length < 5 &&
+      this.config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI
+    ) {
+      // There's currently a bug where, for Gemini API key usage, if we try to
+      // use flash as one of the first requests in our sequence, it will return
+      // an empty token.
+      return DEFAULT_GEMINI_MODEL;
+    }
+
+    const flashIndicator = 'flash';
+    const proIndicator = 'pro';
+    const modelChoicePrompt = `You are a super-intelligent router that decides which model to use for a given request. You have two models to choose from: "${flashIndicator}" and "${proIndicator}". "${flashIndicator}" is a smaller and faster model that is good for simple or well defined requests. "${proIndicator}" is a larger and slower model that is good for complex or undefined requests.
+
+Based on the user request, which model should be used? Respond with a JSON object that contains a single field, \`model\`, whose value is the name of the model to be used.
+
+For example, if you think "${flashIndicator}" should be used, respond with: { "model": "${flashIndicator}" }`;
+    const modelChoiceContent: Content[] = [
+      {
+        role: 'user',
+        parts: [{ text: modelChoicePrompt }],
+      },
+    ];
+
+    const client = this.config.getGeminiClient();
+    try {
+      const choice = await client.generateJson(
+        [...history, ...modelChoiceContent],
+        {
+          type: 'object',
+          properties: {
+            model: {
+              type: 'string',
+              enum: [flashIndicator, proIndicator],
+            },
+          },
+          required: ['model'],
+        },
+        signal,
+        DEFAULT_GEMINI_FLASH_MODEL,
+        {
+          temperature: 0,
+          maxOutputTokens: 25,
+          thinkingConfig: {
+            thinkingBudget: 0,
+          },
+        },
+      );
+
+      switch (choice.model) {
+        case flashIndicator:
+          return DEFAULT_GEMINI_FLASH_MODEL;
+        case proIndicator:
+          return DEFAULT_GEMINI_MODEL;
+        default:
+          return currentModel;
+      }
+    } catch (_e) {
+      // If the model selection fails, just use the default flash model.
+      return DEFAULT_GEMINI_FLASH_MODEL;
+    }
+  }
+
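Taken as a whole, the policy is: stay on flash if the session is already pinned there; force pro for the first few turns under a Gemini API key (the empty-token workaround noted in the code); otherwise let flash itself classify the request under a strict JSON schema with `temperature: 0`, a 25-token output cap, and a zero thinking budget, falling back to flash if the router call fails. A condensed, self-contained restatement (the model-name literals are assumptions; the structure mirrors `_selectModel`):

```ts
// Condensed sketch of the selection policy; askRouter stands in for
// client.generateJson and the model-name literals are assumed values.
async function selectModelPolicy(
  currentModel: string,
  historyLength: number,
  usesGeminiApiKey: boolean,
  askRouter: () => Promise<{ model: string }>,
): Promise<string> {
  const FLASH = 'gemini-2.5-flash'; // assumed DEFAULT_GEMINI_FLASH_MODEL
  const PRO = 'gemini-2.5-pro'; // assumed DEFAULT_GEMINI_MODEL

  if (currentModel === FLASH) return FLASH; // already on flash: skip the router
  if (historyLength < 5 && usesGeminiApiKey) return PRO; // empty-token workaround

  try {
    const { model } = await askRouter();
    if (model === 'flash') return FLASH;
    if (model === 'pro') return PRO;
    return currentModel; // unrecognized answer: keep the configured model
  } catch {
    return FLASH; // router call failed: default to flash
  }
}
```

Routing with the cheap model keeps the extra round trip inexpensive, and the `enum`-constrained schema plus zero temperature makes the one-word answer effectively deterministic.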
   /**
    * Returns the chat history.
    *