feat: Add model selection logic (#1678)
This commit is contained in:
parent
121bba3464
commit
24ccc9c457
|
@ -14,6 +14,7 @@ import {
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { GeminiChat } from './geminiChat.js';
|
import { GeminiChat } from './geminiChat.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
|
import { AuthType } from '../core/contentGenerator.js';
|
||||||
import { setSimulate429 } from '../utils/testUtils.js';
|
import { setSimulate429 } from '../utils/testUtils.js';
|
||||||
|
|
||||||
// Mocks
|
// Mocks
|
||||||
|
@ -38,11 +39,14 @@ describe('GeminiChat', () => {
|
||||||
getUsageStatisticsEnabled: () => true,
|
getUsageStatisticsEnabled: () => true,
|
||||||
getDebugMode: () => false,
|
getDebugMode: () => false,
|
||||||
getContentGeneratorConfig: () => ({
|
getContentGeneratorConfig: () => ({
|
||||||
authType: 'oauth-personal',
|
authType: AuthType.USE_GEMINI,
|
||||||
model: 'test-model',
|
model: 'test-model',
|
||||||
}),
|
}),
|
||||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||||
setModel: vi.fn(),
|
setModel: vi.fn(),
|
||||||
|
getGeminiClient: vi.fn().mockReturnValue({
|
||||||
|
generateJson: vi.fn().mockResolvedValue({ model: 'pro' }),
|
||||||
|
}),
|
||||||
flashFallbackHandler: undefined,
|
flashFallbackHandler: undefined,
|
||||||
} as unknown as Config;
|
} as unknown as Config;
|
||||||
|
|
||||||
|
@ -110,7 +114,7 @@ describe('GeminiChat', () => {
|
||||||
await chat.sendMessageStream({ message: 'hello' });
|
await chat.sendMessageStream({ message: 'hello' });
|
||||||
|
|
||||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({
|
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({
|
||||||
model: 'gemini-pro',
|
model: 'gemini-2.5-pro',
|
||||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||||
config: {},
|
config: {},
|
||||||
});
|
});
|
||||||
|
|
|
@ -34,7 +34,10 @@ import {
|
||||||
ApiRequestEvent,
|
ApiRequestEvent,
|
||||||
ApiResponseEvent,
|
ApiResponseEvent,
|
||||||
} from '../telemetry/types.js';
|
} from '../telemetry/types.js';
|
||||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
import {
|
||||||
|
DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
|
DEFAULT_GEMINI_MODEL,
|
||||||
|
} from '../config/models.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns true if the response is valid, false otherwise.
|
* Returns true if the response is valid, false otherwise.
|
||||||
|
@ -346,14 +349,20 @@ export class GeminiChat {
|
||||||
await this.sendPromise;
|
await this.sendPromise;
|
||||||
const userContent = createUserContent(params.message);
|
const userContent = createUserContent(params.message);
|
||||||
const requestContents = this.getHistory(true).concat(userContent);
|
const requestContents = this.getHistory(true).concat(userContent);
|
||||||
this._logApiRequest(requestContents, this.config.getModel());
|
|
||||||
|
const model = await this._selectModel(
|
||||||
|
requestContents,
|
||||||
|
params.config?.abortSignal ?? new AbortController().signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
this._logApiRequest(requestContents, model);
|
||||||
|
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const apiCall = () =>
|
const apiCall = () =>
|
||||||
this.contentGenerator.generateContentStream({
|
this.contentGenerator.generateContentStream({
|
||||||
model: this.config.getModel(),
|
model,
|
||||||
contents: requestContents,
|
contents: requestContents,
|
||||||
config: { ...this.generationConfig, ...params.config },
|
config: { ...this.generationConfig, ...params.config },
|
||||||
});
|
});
|
||||||
|
@ -397,6 +406,82 @@ export class GeminiChat {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Selects the model to use for the request.
|
||||||
|
*
|
||||||
|
* This is a placeholder for now.
|
||||||
|
*/
|
||||||
|
private async _selectModel(
|
||||||
|
history: Content[],
|
||||||
|
signal: AbortSignal,
|
||||||
|
): Promise<string> {
|
||||||
|
const currentModel = this.config.getModel();
|
||||||
|
if (currentModel === DEFAULT_GEMINI_FLASH_MODEL) {
|
||||||
|
return DEFAULT_GEMINI_FLASH_MODEL;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
history.length < 5 &&
|
||||||
|
this.config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI
|
||||||
|
) {
|
||||||
|
// There's currently a bug where for Gemini API key usage if we try and use flash as one of the first
|
||||||
|
// requests in our sequence that it will return an empty token.
|
||||||
|
return DEFAULT_GEMINI_MODEL;
|
||||||
|
}
|
||||||
|
|
||||||
|
const flashIndicator = 'flash';
|
||||||
|
const proIndicator = 'pro';
|
||||||
|
const modelChoicePrompt = `You are a super-intelligent router that decides which model to use for a given request. You have two models to choose from: "${flashIndicator}" and "${proIndicator}". "${flashIndicator}" is a smaller and faster model that is good for simple or well defined requests. "${proIndicator}" is a larger and slower model that is good for complex or undefined requests.
|
||||||
|
|
||||||
|
Based on the user request, which model should be used? Respond with a JSON object that contains a single field, \`model\`, whose value is the name of the model to be used.
|
||||||
|
|
||||||
|
For example, if you think "${flashIndicator}" should be used, respond with: { "model": "${flashIndicator}" }`;
|
||||||
|
const modelChoiceContent: Content[] = [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: modelChoicePrompt }],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const client = this.config.getGeminiClient();
|
||||||
|
try {
|
||||||
|
const choice = await client.generateJson(
|
||||||
|
[...history, ...modelChoiceContent],
|
||||||
|
{
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
model: {
|
||||||
|
type: 'string',
|
||||||
|
enum: [flashIndicator, proIndicator],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required: ['model'],
|
||||||
|
},
|
||||||
|
signal,
|
||||||
|
DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
|
{
|
||||||
|
temperature: 0,
|
||||||
|
maxOutputTokens: 25,
|
||||||
|
thinkingConfig: {
|
||||||
|
thinkingBudget: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
switch (choice.model) {
|
||||||
|
case flashIndicator:
|
||||||
|
return DEFAULT_GEMINI_FLASH_MODEL;
|
||||||
|
case proIndicator:
|
||||||
|
return DEFAULT_GEMINI_MODEL;
|
||||||
|
default:
|
||||||
|
return currentModel;
|
||||||
|
}
|
||||||
|
} catch (_e) {
|
||||||
|
// If the model selection fails, just use the default flash model.
|
||||||
|
return DEFAULT_GEMINI_FLASH_MODEL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the chat history.
|
* Returns the chat history.
|
||||||
*
|
*
|
||||||
|
|
Loading…
Reference in New Issue