This commit is contained in:
Bryan Morgan 2025-06-25 21:45:38 -04:00 committed by GitHub
parent b6b9923dc3
commit bb797ded7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 72 additions and 50 deletions

View File

@ -196,7 +196,6 @@ export class GeminiClient {
return new GeminiChat( return new GeminiChat(
this.config, this.config,
this.getContentGenerator(), this.getContentGenerator(),
this.model,
{ {
systemInstruction, systemInstruction,
...generateContentConfigWithThinking, ...generateContentConfigWithThinking,

View File

@ -25,7 +25,14 @@ const mockModelsModule = {
batchEmbedContents: vi.fn(), batchEmbedContents: vi.fn(),
} as unknown as Models; } as unknown as Models;
const mockConfig = { describe('GeminiChat', () => {
let chat: GeminiChat;
let mockConfig: Config;
const config: GenerateContentConfig = {};
beforeEach(() => {
vi.clearAllMocks();
mockConfig = {
getSessionId: () => 'test-session-id', getSessionId: () => 'test-session-id',
getTelemetryLogPromptsEnabled: () => true, getTelemetryLogPromptsEnabled: () => true,
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
@ -34,25 +41,20 @@ const mockConfig = {
authType: 'oauth-personal', authType: 'oauth-personal',
model: 'test-model', model: 'test-model',
}), }),
getModel: vi.fn().mockReturnValue('gemini-pro'),
setModel: vi.fn(), setModel: vi.fn(),
flashFallbackHandler: undefined, flashFallbackHandler: undefined,
} as unknown as Config; } as unknown as Config;
describe('GeminiChat', () => {
let chat: GeminiChat;
const model = 'gemini-pro';
const config: GenerateContentConfig = {};
beforeEach(() => {
vi.clearAllMocks();
// Disable 429 simulation for tests // Disable 429 simulation for tests
setSimulate429(false); setSimulate429(false);
// Reset history for each test by creating a new instance // Reset history for each test by creating a new instance
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []); chat = new GeminiChat(mockConfig, mockModelsModule, config, []);
}); });
afterEach(() => { afterEach(() => {
vi.restoreAllMocks(); vi.restoreAllMocks();
vi.resetAllMocks();
}); });
describe('sendMessage', () => { describe('sendMessage', () => {
@ -203,7 +205,7 @@ describe('GeminiChat', () => {
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
// Reset and set up a more realistic scenario for merging with existing history // Reset and set up a more realistic scenario for merging with existing history
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []); chat = new GeminiChat(mockConfig, mockModelsModule, config, []);
const firstUserInput: Content = { const firstUserInput: Content = {
role: 'user', role: 'user',
parts: [{ text: 'First user input' }], parts: [{ text: 'First user input' }],
@ -246,7 +248,7 @@ describe('GeminiChat', () => {
role: 'model', role: 'model',
parts: [{ text: 'Initial model answer.' }], parts: [{ text: 'Initial model answer.' }],
}; };
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, [ chat = new GeminiChat(mockConfig, mockModelsModule, config, [
initialUser, initialUser,
initialModel, initialModel,
]); ]);

View File

@ -138,7 +138,6 @@ export class GeminiChat {
constructor( constructor(
private readonly config: Config, private readonly config: Config,
private readonly contentGenerator: ContentGenerator, private readonly contentGenerator: ContentGenerator,
private readonly model: string,
private readonly generationConfig: GenerateContentConfig = {}, private readonly generationConfig: GenerateContentConfig = {},
private history: Content[] = [], private history: Content[] = [],
) { ) {
@ -168,7 +167,12 @@ export class GeminiChat {
): Promise<void> { ): Promise<void> {
logApiResponse( logApiResponse(
this.config, this.config,
new ApiResponseEvent(this.model, durationMs, usageMetadata, responseText), new ApiResponseEvent(
this.config.getModel(),
durationMs,
usageMetadata,
responseText,
),
); );
} }
@ -178,7 +182,12 @@ export class GeminiChat {
logApiError( logApiError(
this.config, this.config,
new ApiErrorEvent(this.model, errorMessage, durationMs, errorType), new ApiErrorEvent(
this.config.getModel(),
errorMessage,
durationMs,
errorType,
),
); );
} }
@ -192,7 +201,7 @@ export class GeminiChat {
return null; return null;
} }
const currentModel = this.model; const currentModel = this.config.getModel();
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
// Don't fallback if already using Flash model // Don't fallback if already using Flash model
@ -244,7 +253,7 @@ export class GeminiChat {
const userContent = createUserContent(params.message); const userContent = createUserContent(params.message);
const requestContents = this.getHistory(true).concat(userContent); const requestContents = this.getHistory(true).concat(userContent);
this._logApiRequest(requestContents, this.model); this._logApiRequest(requestContents, this.config.getModel());
const startTime = Date.now(); const startTime = Date.now();
let response: GenerateContentResponse; let response: GenerateContentResponse;
@ -252,12 +261,23 @@ export class GeminiChat {
try { try {
const apiCall = () => const apiCall = () =>
this.contentGenerator.generateContent({ this.contentGenerator.generateContent({
model: this.model, model: this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL,
contents: requestContents, contents: requestContents,
config: { ...this.generationConfig, ...params.config }, config: { ...this.generationConfig, ...params.config },
}); });
response = await retryWithBackoff(apiCall); response = await retryWithBackoff(apiCall, {
shouldRetry: (error: Error) => {
if (error && error.message) {
if (error.message.includes('429')) return true;
if (error.message.match(/5\d{2}/)) return true;
}
return false;
},
onPersistent429: async (authType?: string) =>
await this.handleFlashFallback(authType),
authType: this.config.getContentGeneratorConfig()?.authType,
});
const durationMs = Date.now() - startTime; const durationMs = Date.now() - startTime;
await this._logApiResponse( await this._logApiResponse(
durationMs, durationMs,
@ -326,14 +346,14 @@ export class GeminiChat {
await this.sendPromise; await this.sendPromise;
const userContent = createUserContent(params.message); const userContent = createUserContent(params.message);
const requestContents = this.getHistory(true).concat(userContent); const requestContents = this.getHistory(true).concat(userContent);
this._logApiRequest(requestContents, this.model); this._logApiRequest(requestContents, this.config.getModel());
const startTime = Date.now(); const startTime = Date.now();
try { try {
const apiCall = () => const apiCall = () =>
this.contentGenerator.generateContentStream({ this.contentGenerator.generateContentStream({
model: this.model, model: this.config.getModel(),
contents: requestContents, contents: requestContents,
config: { ...this.generationConfig, ...params.config }, config: { ...this.generationConfig, ...params.config },
}); });

View File

@ -71,7 +71,6 @@ describe('checkNextSpeaker', () => {
chatInstance = new GeminiChat( chatInstance = new GeminiChat(
mockConfigInstance, mockConfigInstance,
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
'gemini-pro', // model name
{}, {},
[], // initial history [], // initial history
); );

View File

@ -272,7 +272,7 @@ describe('retryWithBackoff', () => {
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal'); expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
// Should retry again after fallback // Should retry again after fallback
expect(mockFn).toHaveBeenCalledTimes(4); // 3 initial attempts + 1 after fallback expect(mockFn).toHaveBeenCalledTimes(3); // 2 initial attempts + 1 after fallback
}); });
it('should NOT trigger fallback for API key users', async () => { it('should NOT trigger fallback for API key users', async () => {

View File

@ -67,9 +67,9 @@ export async function retryWithBackoff<T>(
maxAttempts, maxAttempts,
initialDelayMs, initialDelayMs,
maxDelayMs, maxDelayMs,
shouldRetry,
onPersistent429, onPersistent429,
authType, authType,
shouldRetry,
} = { } = {
...DEFAULT_RETRY_OPTIONS, ...DEFAULT_RETRY_OPTIONS,
...options, ...options,
@ -93,8 +93,6 @@ export async function retryWithBackoff<T>(
consecutive429Count = 0; consecutive429Count = 0;
} }
// Check if we've exhausted retries or shouldn't retry
if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
// If we have persistent 429s and a fallback callback for OAuth // If we have persistent 429s and a fallback callback for OAuth
if ( if (
consecutive429Count >= 2 && consecutive429Count >= 2 &&
@ -108,6 +106,7 @@ export async function retryWithBackoff<T>(
attempt = 0; attempt = 0;
consecutive429Count = 0; consecutive429Count = 0;
currentDelay = initialDelayMs; currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue; continue;
} }
} catch (fallbackError) { } catch (fallbackError) {
@ -115,6 +114,9 @@ export async function retryWithBackoff<T>(
console.warn('Fallback to Flash model failed:', fallbackError); console.warn('Fallback to Flash model failed:', fallbackError);
} }
} }
// Check if we've exhausted retries or shouldn't retry
if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
throw error; throw error;
} }