fix(client): get model from config in flashFallbackHandler (#2118)
Co-authored-by: Jacob Richman <jacob314@gmail.com>
This commit is contained in:
parent
64767c52fe
commit
ab63a5f183
|
@ -687,4 +687,114 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('generateContent', () => {
|
||||||
|
it('should use current model from config for content generation', async () => {
|
||||||
|
const initialModel = client['config'].getModel();
|
||||||
|
const contents = [{ role: 'user', parts: [{ text: 'test' }] }];
|
||||||
|
const currentModel = initialModel + '-changed';
|
||||||
|
|
||||||
|
vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
|
||||||
|
|
||||||
|
const mockGenerator: Partial<ContentGenerator> = {
|
||||||
|
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
|
||||||
|
generateContent: mockGenerateContentFn,
|
||||||
|
};
|
||||||
|
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||||
|
|
||||||
|
await client.generateContent(contents, {}, new AbortController().signal);
|
||||||
|
|
||||||
|
expect(mockGenerateContentFn).not.toHaveBeenCalledWith({
|
||||||
|
model: initialModel,
|
||||||
|
config: expect.any(Object),
|
||||||
|
contents,
|
||||||
|
});
|
||||||
|
expect(mockGenerateContentFn).toHaveBeenCalledWith({
|
||||||
|
model: currentModel,
|
||||||
|
config: expect.any(Object),
|
||||||
|
contents,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('tryCompressChat', () => {
|
||||||
|
it('should use current model from config for token counting after sendMessage', async () => {
|
||||||
|
const initialModel = client['config'].getModel();
|
||||||
|
|
||||||
|
const mockCountTokens = vi
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValueOnce({ totalTokens: 100000 })
|
||||||
|
.mockResolvedValueOnce({ totalTokens: 5000 });
|
||||||
|
|
||||||
|
const mockSendMessage = vi.fn().mockResolvedValue({ text: 'Summary' });
|
||||||
|
|
||||||
|
const mockChatHistory = [
|
||||||
|
{ role: 'user', parts: [{ text: 'Long conversation' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'Long response' }] },
|
||||||
|
];
|
||||||
|
|
||||||
|
const mockChat: Partial<GeminiChat> = {
|
||||||
|
getHistory: vi.fn().mockReturnValue(mockChatHistory),
|
||||||
|
sendMessage: mockSendMessage,
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockGenerator: Partial<ContentGenerator> = {
|
||||||
|
countTokens: mockCountTokens,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Simulate the model changing between the two `countTokens` calls
|
||||||
|
const firstCurrentModel = initialModel + '-changed-1';
|
||||||
|
const secondCurrentModel = initialModel + '-changed-2';
|
||||||
|
vi.spyOn(client['config'], 'getModel')
|
||||||
|
.mockReturnValueOnce(firstCurrentModel)
|
||||||
|
.mockReturnValueOnce(secondCurrentModel);
|
||||||
|
|
||||||
|
client['chat'] = mockChat as GeminiChat;
|
||||||
|
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||||
|
client['startChat'] = vi.fn().mockResolvedValue(mockChat);
|
||||||
|
|
||||||
|
const result = await client.tryCompressChat(true);
|
||||||
|
|
||||||
|
expect(mockCountTokens).toHaveBeenCalledTimes(2);
|
||||||
|
expect(mockCountTokens).toHaveBeenNthCalledWith(1, {
|
||||||
|
model: firstCurrentModel,
|
||||||
|
contents: mockChatHistory,
|
||||||
|
});
|
||||||
|
expect(mockCountTokens).toHaveBeenNthCalledWith(2, {
|
||||||
|
model: secondCurrentModel,
|
||||||
|
contents: expect.any(Array),
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
originalTokenCount: 100000,
|
||||||
|
newTokenCount: 5000,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('handleFlashFallback', () => {
|
||||||
|
it('should use current model from config when checking for fallback', async () => {
|
||||||
|
const initialModel = client['config'].getModel();
|
||||||
|
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||||
|
|
||||||
|
// Simulate the model in config having changed since the client was created
|
||||||
|
const currentModel = initialModel + '-changed';
|
||||||
|
vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
|
||||||
|
|
||||||
|
const mockFallbackHandler = vi.fn().mockResolvedValue(true);
|
||||||
|
client['config'].flashFallbackHandler = mockFallbackHandler;
|
||||||
|
client['config'].setModel = vi.fn();
|
||||||
|
|
||||||
|
const result = await client['handleFlashFallback'](
|
||||||
|
AuthType.LOGIN_WITH_GOOGLE,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result).toBe(fallbackModel);
|
||||||
|
|
||||||
|
expect(mockFallbackHandler).toHaveBeenCalledWith(
|
||||||
|
currentModel,
|
||||||
|
fallbackModel,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -48,7 +48,6 @@ function isThinkingSupported(model: string) {
|
||||||
export class GeminiClient {
|
export class GeminiClient {
|
||||||
private chat?: GeminiChat;
|
private chat?: GeminiChat;
|
||||||
private contentGenerator?: ContentGenerator;
|
private contentGenerator?: ContentGenerator;
|
||||||
private model: string;
|
|
||||||
private embeddingModel: string;
|
private embeddingModel: string;
|
||||||
private generateContentConfig: GenerateContentConfig = {
|
private generateContentConfig: GenerateContentConfig = {
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
|
@ -62,7 +61,6 @@ export class GeminiClient {
|
||||||
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
|
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
|
||||||
}
|
}
|
||||||
|
|
||||||
this.model = config.getModel();
|
|
||||||
this.embeddingModel = config.getEmbeddingModel();
|
this.embeddingModel = config.getEmbeddingModel();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,7 +185,9 @@ export class GeminiClient {
|
||||||
try {
|
try {
|
||||||
const userMemory = this.config.getUserMemory();
|
const userMemory = this.config.getUserMemory();
|
||||||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||||
const generateContentConfigWithThinking = isThinkingSupported(this.model)
|
const generateContentConfigWithThinking = isThinkingSupported(
|
||||||
|
this.config.getModel(),
|
||||||
|
)
|
||||||
? {
|
? {
|
||||||
...this.generateContentConfig,
|
...this.generateContentConfig,
|
||||||
thinkingConfig: {
|
thinkingConfig: {
|
||||||
|
@ -345,7 +345,7 @@ export class GeminiClient {
|
||||||
generationConfig: GenerateContentConfig,
|
generationConfig: GenerateContentConfig,
|
||||||
abortSignal: AbortSignal,
|
abortSignal: AbortSignal,
|
||||||
): Promise<GenerateContentResponse> {
|
): Promise<GenerateContentResponse> {
|
||||||
const modelToUse = this.model;
|
const modelToUse = this.config.getModel();
|
||||||
const configToUse: GenerateContentConfig = {
|
const configToUse: GenerateContentConfig = {
|
||||||
...this.generateContentConfig,
|
...this.generateContentConfig,
|
||||||
...generationConfig,
|
...generationConfig,
|
||||||
|
@ -439,13 +439,15 @@ export class GeminiClient {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const model = this.config.getModel();
|
||||||
|
|
||||||
let { totalTokens: originalTokenCount } =
|
let { totalTokens: originalTokenCount } =
|
||||||
await this.getContentGenerator().countTokens({
|
await this.getContentGenerator().countTokens({
|
||||||
model: this.model,
|
model,
|
||||||
contents: curatedHistory,
|
contents: curatedHistory,
|
||||||
});
|
});
|
||||||
if (originalTokenCount === undefined) {
|
if (originalTokenCount === undefined) {
|
||||||
console.warn(`Could not determine token count for model ${this.model}.`);
|
console.warn(`Could not determine token count for model ${model}.`);
|
||||||
originalTokenCount = 0;
|
originalTokenCount = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -453,7 +455,7 @@ export class GeminiClient {
|
||||||
if (
|
if (
|
||||||
!force &&
|
!force &&
|
||||||
originalTokenCount <
|
originalTokenCount <
|
||||||
this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(this.model)
|
this.TOKEN_THRESHOLD_FOR_SUMMARIZATION * tokenLimit(model)
|
||||||
) {
|
) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -479,7 +481,8 @@ export class GeminiClient {
|
||||||
|
|
||||||
const { totalTokens: newTokenCount } =
|
const { totalTokens: newTokenCount } =
|
||||||
await this.getContentGenerator().countTokens({
|
await this.getContentGenerator().countTokens({
|
||||||
model: this.model,
|
// model might change after calling `sendMessage`, so we get the newest value from config
|
||||||
|
model: this.config.getModel(),
|
||||||
contents: this.getChat().getHistory(),
|
contents: this.getChat().getHistory(),
|
||||||
});
|
});
|
||||||
if (newTokenCount === undefined) {
|
if (newTokenCount === undefined) {
|
||||||
|
@ -503,7 +506,7 @@ export class GeminiClient {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
const currentModel = this.model;
|
const currentModel = this.config.getModel();
|
||||||
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||||
|
|
||||||
// Don't fallback if already using Flash model
|
// Don't fallback if already using Flash model
|
||||||
|
@ -518,7 +521,6 @@ export class GeminiClient {
|
||||||
const accepted = await fallbackHandler(currentModel, fallbackModel);
|
const accepted = await fallbackHandler(currentModel, fallbackModel);
|
||||||
if (accepted) {
|
if (accepted) {
|
||||||
this.config.setModel(fallbackModel);
|
this.config.setModel(fallbackModel);
|
||||||
this.model = fallbackModel;
|
|
||||||
return fallbackModel;
|
return fallbackModel;
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|
Loading…
Reference in New Issue