feat: Add model selection logic (#1678)
This commit is contained in:
parent
121bba3464
commit
24ccc9c457
|
@ -14,6 +14,7 @@ import {
|
|||
} from '@google/genai';
|
||||
import { GeminiChat } from './geminiChat.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { setSimulate429 } from '../utils/testUtils.js';
|
||||
|
||||
// Mocks
|
||||
|
@ -38,11 +39,14 @@ describe('GeminiChat', () => {
|
|||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getContentGeneratorConfig: () => ({
|
||||
authType: 'oauth-personal',
|
||||
authType: AuthType.USE_GEMINI,
|
||||
model: 'test-model',
|
||||
}),
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
setModel: vi.fn(),
|
||||
getGeminiClient: vi.fn().mockReturnValue({
|
||||
generateJson: vi.fn().mockResolvedValue({ model: 'pro' }),
|
||||
}),
|
||||
flashFallbackHandler: undefined,
|
||||
} as unknown as Config;
|
||||
|
||||
|
@ -110,7 +114,7 @@ describe('GeminiChat', () => {
|
|||
await chat.sendMessageStream({ message: 'hello' });
|
||||
|
||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({
|
||||
model: 'gemini-pro',
|
||||
model: 'gemini-2.5-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||
config: {},
|
||||
});
|
||||
|
|
|
@ -34,7 +34,10 @@ import {
|
|||
ApiRequestEvent,
|
||||
ApiResponseEvent,
|
||||
} from '../telemetry/types.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
} from '../config/models.js';
|
||||
|
||||
/**
|
||||
* Returns true if the response is valid, false otherwise.
|
||||
|
@ -346,14 +349,20 @@ export class GeminiChat {
|
|||
await this.sendPromise;
|
||||
const userContent = createUserContent(params.message);
|
||||
const requestContents = this.getHistory(true).concat(userContent);
|
||||
this._logApiRequest(requestContents, this.config.getModel());
|
||||
|
||||
const model = await this._selectModel(
|
||||
requestContents,
|
||||
params.config?.abortSignal ?? new AbortController().signal,
|
||||
);
|
||||
|
||||
this._logApiRequest(requestContents, model);
|
||||
|
||||
const startTime = Date.now();
|
||||
|
||||
try {
|
||||
const apiCall = () =>
|
||||
this.contentGenerator.generateContentStream({
|
||||
model: this.config.getModel(),
|
||||
model,
|
||||
contents: requestContents,
|
||||
config: { ...this.generationConfig, ...params.config },
|
||||
});
|
||||
|
@ -397,6 +406,82 @@ export class GeminiChat {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Selects the model to use for the request.
|
||||
*
|
||||
* This is a placeholder for now.
|
||||
*/
|
||||
private async _selectModel(
|
||||
history: Content[],
|
||||
signal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const currentModel = this.config.getModel();
|
||||
if (currentModel === DEFAULT_GEMINI_FLASH_MODEL) {
|
||||
return DEFAULT_GEMINI_FLASH_MODEL;
|
||||
}
|
||||
|
||||
if (
|
||||
history.length < 5 &&
|
||||
this.config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI
|
||||
) {
|
||||
// There's currently a bug where for Gemini API key usage if we try and use flash as one of the first
|
||||
// requests in our sequence that it will return an empty token.
|
||||
return DEFAULT_GEMINI_MODEL;
|
||||
}
|
||||
|
||||
const flashIndicator = 'flash';
|
||||
const proIndicator = 'pro';
|
||||
const modelChoicePrompt = `You are a super-intelligent router that decides which model to use for a given request. You have two models to choose from: "${flashIndicator}" and "${proIndicator}". "${flashIndicator}" is a smaller and faster model that is good for simple or well defined requests. "${proIndicator}" is a larger and slower model that is good for complex or undefined requests.
|
||||
|
||||
Based on the user request, which model should be used? Respond with a JSON object that contains a single field, \`model\`, whose value is the name of the model to be used.
|
||||
|
||||
For example, if you think "${flashIndicator}" should be used, respond with: { "model": "${flashIndicator}" }`;
|
||||
const modelChoiceContent: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: modelChoicePrompt }],
|
||||
},
|
||||
];
|
||||
|
||||
const client = this.config.getGeminiClient();
|
||||
try {
|
||||
const choice = await client.generateJson(
|
||||
[...history, ...modelChoiceContent],
|
||||
{
|
||||
type: 'object',
|
||||
properties: {
|
||||
model: {
|
||||
type: 'string',
|
||||
enum: [flashIndicator, proIndicator],
|
||||
},
|
||||
},
|
||||
required: ['model'],
|
||||
},
|
||||
signal,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
{
|
||||
temperature: 0,
|
||||
maxOutputTokens: 25,
|
||||
thinkingConfig: {
|
||||
thinkingBudget: 0,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
switch (choice.model) {
|
||||
case flashIndicator:
|
||||
return DEFAULT_GEMINI_FLASH_MODEL;
|
||||
case proIndicator:
|
||||
return DEFAULT_GEMINI_MODEL;
|
||||
default:
|
||||
return currentModel;
|
||||
}
|
||||
} catch (_e) {
|
||||
// If the model selection fails, just use the default flash model.
|
||||
return DEFAULT_GEMINI_FLASH_MODEL;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the chat history.
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue