This commit is contained in:
Bryan Morgan 2025-06-25 21:45:38 -04:00 committed by GitHub
parent b6b9923dc3
commit bb797ded7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 72 additions and 50 deletions

View File

@ -196,7 +196,6 @@ export class GeminiClient {
return new GeminiChat(
this.config,
this.getContentGenerator(),
this.model,
{
systemInstruction,
...generateContentConfigWithThinking,

View File

@ -25,34 +25,36 @@ const mockModelsModule = {
batchEmbedContents: vi.fn(),
} as unknown as Models;
const mockConfig = {
getSessionId: () => 'test-session-id',
getTelemetryLogPromptsEnabled: () => true,
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getContentGeneratorConfig: () => ({
authType: 'oauth-personal',
model: 'test-model',
}),
setModel: vi.fn(),
flashFallbackHandler: undefined,
} as unknown as Config;
describe('GeminiChat', () => {
let chat: GeminiChat;
const model = 'gemini-pro';
let mockConfig: Config;
const config: GenerateContentConfig = {};
beforeEach(() => {
vi.clearAllMocks();
mockConfig = {
getSessionId: () => 'test-session-id',
getTelemetryLogPromptsEnabled: () => true,
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getContentGeneratorConfig: () => ({
authType: 'oauth-personal',
model: 'test-model',
}),
getModel: vi.fn().mockReturnValue('gemini-pro'),
setModel: vi.fn(),
flashFallbackHandler: undefined,
} as unknown as Config;
// Disable 429 simulation for tests
setSimulate429(false);
// Reset history for each test by creating a new instance
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []);
chat = new GeminiChat(mockConfig, mockModelsModule, config, []);
});
afterEach(() => {
vi.restoreAllMocks();
vi.resetAllMocks();
});
describe('sendMessage', () => {
@ -203,7 +205,7 @@ describe('GeminiChat', () => {
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
// Reset and set up a more realistic scenario for merging with existing history
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []);
chat = new GeminiChat(mockConfig, mockModelsModule, config, []);
const firstUserInput: Content = {
role: 'user',
parts: [{ text: 'First user input' }],
@ -246,7 +248,7 @@ describe('GeminiChat', () => {
role: 'model',
parts: [{ text: 'Initial model answer.' }],
};
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, [
chat = new GeminiChat(mockConfig, mockModelsModule, config, [
initialUser,
initialModel,
]);

View File

@ -138,7 +138,6 @@ export class GeminiChat {
constructor(
private readonly config: Config,
private readonly contentGenerator: ContentGenerator,
private readonly model: string,
private readonly generationConfig: GenerateContentConfig = {},
private history: Content[] = [],
) {
@ -168,7 +167,12 @@ export class GeminiChat {
): Promise<void> {
logApiResponse(
this.config,
new ApiResponseEvent(this.model, durationMs, usageMetadata, responseText),
new ApiResponseEvent(
this.config.getModel(),
durationMs,
usageMetadata,
responseText,
),
);
}
@ -178,7 +182,12 @@ export class GeminiChat {
logApiError(
this.config,
new ApiErrorEvent(this.model, errorMessage, durationMs, errorType),
new ApiErrorEvent(
this.config.getModel(),
errorMessage,
durationMs,
errorType,
),
);
}
@ -192,7 +201,7 @@ export class GeminiChat {
return null;
}
const currentModel = this.model;
const currentModel = this.config.getModel();
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
// Don't fallback if already using Flash model
@ -244,7 +253,7 @@ export class GeminiChat {
const userContent = createUserContent(params.message);
const requestContents = this.getHistory(true).concat(userContent);
this._logApiRequest(requestContents, this.model);
this._logApiRequest(requestContents, this.config.getModel());
const startTime = Date.now();
let response: GenerateContentResponse;
@ -252,12 +261,23 @@ export class GeminiChat {
try {
const apiCall = () =>
this.contentGenerator.generateContent({
model: this.model,
model: this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL,
contents: requestContents,
config: { ...this.generationConfig, ...params.config },
});
response = await retryWithBackoff(apiCall);
response = await retryWithBackoff(apiCall, {
/**
 * Retry predicate handed to retryWithBackoff for this request.
 *
 * Returns true only when the error message carries an HTTP 429 (rate
 * limit) or 5xx (server error) status code. The code must stand alone as
 * a number: digits embedded in a longer number (e.g. "1500ms" or
 * "14290 tokens") do not count, so unrelated failures are not retried.
 *
 * @param error - The error thrown by the API call.
 * @returns Whether the call should be attempted again.
 */
shouldRetry: (error: Error) => {
  const message = error?.message;
  if (!message) {
    return false;
  }
  // Anchor on non-digit boundaries so "429"/"5xx" are matched as whole
  // numbers rather than as substrings of larger ones.
  const isRateLimited = /(^|\D)429(\D|$)/.test(message);
  const isServerError = /(^|\D)5\d{2}(\D|$)/.test(message);
  return isRateLimited || isServerError;
},
onPersistent429: async (authType?: string) =>
await this.handleFlashFallback(authType),
authType: this.config.getContentGeneratorConfig()?.authType,
});
const durationMs = Date.now() - startTime;
await this._logApiResponse(
durationMs,
@ -326,14 +346,14 @@ export class GeminiChat {
await this.sendPromise;
const userContent = createUserContent(params.message);
const requestContents = this.getHistory(true).concat(userContent);
this._logApiRequest(requestContents, this.model);
this._logApiRequest(requestContents, this.config.getModel());
const startTime = Date.now();
try {
const apiCall = () =>
this.contentGenerator.generateContentStream({
model: this.model,
model: this.config.getModel(),
contents: requestContents,
config: { ...this.generationConfig, ...params.config },
});

View File

@ -71,7 +71,6 @@ describe('checkNextSpeaker', () => {
chatInstance = new GeminiChat(
mockConfigInstance,
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
'gemini-pro', // model name
{},
[], // initial history
);

View File

@ -272,7 +272,7 @@ describe('retryWithBackoff', () => {
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
// Should retry again after fallback
expect(mockFn).toHaveBeenCalledTimes(4); // 3 initial attempts + 1 after fallback
expect(mockFn).toHaveBeenCalledTimes(3); // 2 initial attempts + 1 after fallback
});
it('should NOT trigger fallback for API key users', async () => {

View File

@ -67,9 +67,9 @@ export async function retryWithBackoff<T>(
maxAttempts,
initialDelayMs,
maxDelayMs,
shouldRetry,
onPersistent429,
authType,
shouldRetry,
} = {
...DEFAULT_RETRY_OPTIONS,
...options,
@ -93,28 +93,30 @@ export async function retryWithBackoff<T>(
consecutive429Count = 0;
}
// If we have persistent 429s and a fallback callback for OAuth
if (
consecutive429Count >= 2 &&
onPersistent429 &&
authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL
) {
try {
const fallbackModel = await onPersistent429(authType);
if (fallbackModel) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
// With the model updated, we continue to the next attempt
continue;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
// Check if we've exhausted retries or shouldn't retry
if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
// If we have persistent 429s and a fallback callback for OAuth
if (
consecutive429Count >= 2 &&
onPersistent429 &&
authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL
) {
try {
const fallbackModel = await onPersistent429(authType);
if (fallbackModel) {
// Reset attempt counter and try with new model
attempt = 0;
consecutive429Count = 0;
currentDelay = initialDelayMs;
continue;
}
} catch (fallbackError) {
// If fallback fails, continue with original error
console.warn('Fallback to Flash model failed:', fallbackError);
}
}
throw error;
}