parent
c55b15f705
commit
267173c7e8
|
@ -14,7 +14,6 @@ import {
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { GeminiChat } from './geminiChat.js';
|
import { GeminiChat } from './geminiChat.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
import { AuthType } from '../core/contentGenerator.js';
|
|
||||||
import { setSimulate429 } from '../utils/testUtils.js';
|
import { setSimulate429 } from '../utils/testUtils.js';
|
||||||
|
|
||||||
// Mocks
|
// Mocks
|
||||||
|
@ -39,14 +38,11 @@ describe('GeminiChat', () => {
|
||||||
getUsageStatisticsEnabled: () => true,
|
getUsageStatisticsEnabled: () => true,
|
||||||
getDebugMode: () => false,
|
getDebugMode: () => false,
|
||||||
getContentGeneratorConfig: () => ({
|
getContentGeneratorConfig: () => ({
|
||||||
authType: AuthType.USE_GEMINI,
|
authType: 'oauth-personal',
|
||||||
model: 'test-model',
|
model: 'test-model',
|
||||||
}),
|
}),
|
||||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||||
setModel: vi.fn(),
|
setModel: vi.fn(),
|
||||||
getGeminiClient: vi.fn().mockReturnValue({
|
|
||||||
generateJson: vi.fn().mockResolvedValue({ model: 'pro' }),
|
|
||||||
}),
|
|
||||||
flashFallbackHandler: undefined,
|
flashFallbackHandler: undefined,
|
||||||
} as unknown as Config;
|
} as unknown as Config;
|
||||||
|
|
||||||
|
@ -114,7 +110,7 @@ describe('GeminiChat', () => {
|
||||||
await chat.sendMessageStream({ message: 'hello' });
|
await chat.sendMessageStream({ message: 'hello' });
|
||||||
|
|
||||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({
|
expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith({
|
||||||
model: 'gemini-2.5-pro',
|
model: 'gemini-pro',
|
||||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||||
config: {},
|
config: {},
|
||||||
});
|
});
|
||||||
|
|
|
@ -34,10 +34,7 @@ import {
|
||||||
ApiRequestEvent,
|
ApiRequestEvent,
|
||||||
ApiResponseEvent,
|
ApiResponseEvent,
|
||||||
} from '../telemetry/types.js';
|
} from '../telemetry/types.js';
|
||||||
import {
|
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||||
DEFAULT_GEMINI_FLASH_MODEL,
|
|
||||||
DEFAULT_GEMINI_MODEL,
|
|
||||||
} from '../config/models.js';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns true if the response is valid, false otherwise.
|
* Returns true if the response is valid, false otherwise.
|
||||||
|
@ -349,20 +346,14 @@ export class GeminiChat {
|
||||||
await this.sendPromise;
|
await this.sendPromise;
|
||||||
const userContent = createUserContent(params.message);
|
const userContent = createUserContent(params.message);
|
||||||
const requestContents = this.getHistory(true).concat(userContent);
|
const requestContents = this.getHistory(true).concat(userContent);
|
||||||
|
this._logApiRequest(requestContents, this.config.getModel());
|
||||||
const model = await this._selectModel(
|
|
||||||
requestContents,
|
|
||||||
params.config?.abortSignal ?? new AbortController().signal,
|
|
||||||
);
|
|
||||||
|
|
||||||
this._logApiRequest(requestContents, model);
|
|
||||||
|
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const apiCall = () =>
|
const apiCall = () =>
|
||||||
this.contentGenerator.generateContentStream({
|
this.contentGenerator.generateContentStream({
|
||||||
model,
|
model: this.config.getModel(),
|
||||||
contents: requestContents,
|
contents: requestContents,
|
||||||
config: { ...this.generationConfig, ...params.config },
|
config: { ...this.generationConfig, ...params.config },
|
||||||
});
|
});
|
||||||
|
@ -406,82 +397,6 @@ export class GeminiChat {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Selects the model to use for the request.
|
|
||||||
*
|
|
||||||
* This is a placeholder for now.
|
|
||||||
*/
|
|
||||||
private async _selectModel(
|
|
||||||
history: Content[],
|
|
||||||
signal: AbortSignal,
|
|
||||||
): Promise<string> {
|
|
||||||
const currentModel = this.config.getModel();
|
|
||||||
if (currentModel === DEFAULT_GEMINI_FLASH_MODEL) {
|
|
||||||
return DEFAULT_GEMINI_FLASH_MODEL;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (
|
|
||||||
history.length < 5 &&
|
|
||||||
this.config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI
|
|
||||||
) {
|
|
||||||
// There's currently a bug where for Gemini API key usage if we try and use flash as one of the first
|
|
||||||
// requests in our sequence that it will return an empty token.
|
|
||||||
return DEFAULT_GEMINI_MODEL;
|
|
||||||
}
|
|
||||||
|
|
||||||
const flashIndicator = 'flash';
|
|
||||||
const proIndicator = 'pro';
|
|
||||||
const modelChoicePrompt = `You are a super-intelligent router that decides which model to use for a given request. You have two models to choose from: "${flashIndicator}" and "${proIndicator}". "${flashIndicator}" is a smaller and faster model that is good for simple or well defined requests. "${proIndicator}" is a larger and slower model that is good for complex or undefined requests.
|
|
||||||
|
|
||||||
Based on the user request, which model should be used? Respond with a JSON object that contains a single field, \`model\`, whose value is the name of the model to be used.
|
|
||||||
|
|
||||||
For example, if you think "${flashIndicator}" should be used, respond with: { "model": "${flashIndicator}" }`;
|
|
||||||
const modelChoiceContent: Content[] = [
|
|
||||||
{
|
|
||||||
role: 'user',
|
|
||||||
parts: [{ text: modelChoicePrompt }],
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
const client = this.config.getGeminiClient();
|
|
||||||
try {
|
|
||||||
const choice = await client.generateJson(
|
|
||||||
[...history, ...modelChoiceContent],
|
|
||||||
{
|
|
||||||
type: 'object',
|
|
||||||
properties: {
|
|
||||||
model: {
|
|
||||||
type: 'string',
|
|
||||||
enum: [flashIndicator, proIndicator],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
required: ['model'],
|
|
||||||
},
|
|
||||||
signal,
|
|
||||||
DEFAULT_GEMINI_FLASH_MODEL,
|
|
||||||
{
|
|
||||||
temperature: 0,
|
|
||||||
maxOutputTokens: 25,
|
|
||||||
thinkingConfig: {
|
|
||||||
thinkingBudget: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
switch (choice.model) {
|
|
||||||
case flashIndicator:
|
|
||||||
return DEFAULT_GEMINI_FLASH_MODEL;
|
|
||||||
case proIndicator:
|
|
||||||
return DEFAULT_GEMINI_MODEL;
|
|
||||||
default:
|
|
||||||
return currentModel;
|
|
||||||
}
|
|
||||||
} catch (_e) {
|
|
||||||
// If the model selection fails, just use the default flash model.
|
|
||||||
return DEFAULT_GEMINI_FLASH_MODEL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the chat history.
|
* Returns the chat history.
|
||||||
*
|
*
|
||||||
|
|
Loading…
Reference in New Issue