fix(client): get model from config in flashFallbackHandler (#2118)

Co-authored-by: Jacob Richman <jacob314@gmail.com>
This commit is contained in:
SunskyXH 2025-07-04 04:43:48 +09:00 committed by GitHub
parent 64767c52fe
commit ab63a5f183
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 122 additions and 10 deletions

View File

@ -687,4 +687,114 @@ describe('Gemini Client (client.ts)', () => {
);
});
});
describe('generateContent', () => {
it('should use current model from config for content generation', async () => {
const initialModel = client['config'].getModel();
const contents = [{ role: 'user', parts: [{ text: 'test' }] }];
const currentModel = initialModel + '-changed';
vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
await client.generateContent(contents, {}, new AbortController().signal);
expect(mockGenerateContentFn).not.toHaveBeenCalledWith({
model: initialModel,
config: expect.any(Object),
contents,
});
expect(mockGenerateContentFn).toHaveBeenCalledWith({
model: currentModel,
config: expect.any(Object),
contents,
});
});
});
describe('tryCompressChat', () => {
it('should use current model from config for token counting after sendMessage', async () => {
const initialModel = client['config'].getModel();
const mockCountTokens = vi
.fn()
.mockResolvedValueOnce({ totalTokens: 100000 })
.mockResolvedValueOnce({ totalTokens: 5000 });
const mockSendMessage = vi.fn().mockResolvedValue({ text: 'Summary' });
const mockChatHistory = [
{ role: 'user', parts: [{ text: 'Long conversation' }] },
{ role: 'model', parts: [{ text: 'Long response' }] },
];
const mockChat: Partial<GeminiChat> = {
getHistory: vi.fn().mockReturnValue(mockChatHistory),
sendMessage: mockSendMessage,
};
const mockGenerator: Partial<ContentGenerator> = {
countTokens: mockCountTokens,
};
// mock the model has been changed between calls of `countTokens`
const firstCurrentModel = initialModel + '-changed-1';
const secondCurrentModel = initialModel + '-changed-2';
vi.spyOn(client['config'], 'getModel')
.mockReturnValueOnce(firstCurrentModel)
.mockReturnValueOnce(secondCurrentModel);
client['chat'] = mockChat as GeminiChat;
client['contentGenerator'] = mockGenerator as ContentGenerator;
client['startChat'] = vi.fn().mockResolvedValue(mockChat);
const result = await client.tryCompressChat(true);
expect(mockCountTokens).toHaveBeenCalledTimes(2);
expect(mockCountTokens).toHaveBeenNthCalledWith(1, {
model: firstCurrentModel,
contents: mockChatHistory,
});
expect(mockCountTokens).toHaveBeenNthCalledWith(2, {
model: secondCurrentModel,
contents: expect.any(Array),
});
expect(result).toEqual({
originalTokenCount: 100000,
newTokenCount: 5000,
});
});
});
describe('handleFlashFallback', () => {
it('should use current model from config when checking for fallback', async () => {
const initialModel = client['config'].getModel();
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
// mock config been changed
const currentModel = initialModel + '-changed';
vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
const mockFallbackHandler = vi.fn().mockResolvedValue(true);
client['config'].flashFallbackHandler = mockFallbackHandler;
client['config'].setModel = vi.fn();
const result = await client['handleFlashFallback'](
AuthType.LOGIN_WITH_GOOGLE,
);
expect(result).toBe(fallbackModel);
expect(mockFallbackHandler).toHaveBeenCalledWith(
currentModel,
fallbackModel,
);
});
});
});

View File

@ -48,7 +48,6 @@ function isThinkingSupported(model: string) {
export class GeminiClient {
private chat?: GeminiChat;
private contentGenerator?: ContentGenerator;
private model: string;
private embeddingModel: string;
private generateContentConfig: GenerateContentConfig = {
temperature: 0,
@ -62,7 +61,6 @@ export class GeminiClient {
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
}
this.model = config.getModel();
this.embeddingModel = config.getEmbeddingModel();
}
@ -187,7 +185,9 @@ export class GeminiClient {
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
const generateContentConfigWithThinking = isThinkingSupported(this.model)
const generateContentConfigWithThinking = isThinkingSupported(
this.config.getModel(),
)
? {
...this.generateContentConfig,
thinkingConfig: {
@ -345,7 +345,7 @@ export class GeminiClient {
generationConfig: GenerateContentConfig,
abortSignal: AbortSignal,
): Promise<GenerateContentResponse> {
const modelToUse = this.model;
const modelToUse = this.config.getModel();
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
...generationConfig,
@ -439,13 +439,15 @@ export class GeminiClient {
return null;
}
const model = this.config.getModel();
let { totalTokens: originalTokenCount } =
await this.getContentGenerator().countTokens({
model: this.model,
model,
contents: curatedHistory,
});
if (originalTokenCount === undefined) {
console.warn(`Could not determine token count for model ${this.model}.`);
console.warn(`Could not determine token count for model ${model}.`);
originalTokenCount = 0;
}
@ -453,7 +455,7 @@ export class GeminiClient {
if (
!force &&
originalTokenCount <
this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(this.model)
this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(model)
) {
return null;
}
@ -479,7 +481,8 @@ export class GeminiClient {
const { totalTokens: newTokenCount } =
await this.getContentGenerator().countTokens({
model: this.model,
// model might change after calling `sendMessage`, so we get the newest value from config
model: this.config.getModel(),
contents: this.getChat().getHistory(),
});
if (newTokenCount === undefined) {
@ -503,7 +506,7 @@ export class GeminiClient {
return null;
}
const currentModel = this.model;
const currentModel = this.config.getModel();
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
// Don't fallback if already using Flash model
@ -518,7 +521,6 @@ export class GeminiClient {
const accepted = await fallbackHandler(currentModel, fallbackModel);
if (accepted) {
this.config.setModel(fallbackModel);
this.model = fallbackModel;
return fallbackModel;
}
} catch (error) {