429 fix (#1668)
This commit is contained in:
parent
b6b9923dc3
commit
bb797ded7d
|
@ -196,7 +196,6 @@ export class GeminiClient {
|
||||||
return new GeminiChat(
|
return new GeminiChat(
|
||||||
this.config,
|
this.config,
|
||||||
this.getContentGenerator(),
|
this.getContentGenerator(),
|
||||||
this.model,
|
|
||||||
{
|
{
|
||||||
systemInstruction,
|
systemInstruction,
|
||||||
...generateContentConfigWithThinking,
|
...generateContentConfigWithThinking,
|
||||||
|
|
|
@ -25,34 +25,36 @@ const mockModelsModule = {
|
||||||
batchEmbedContents: vi.fn(),
|
batchEmbedContents: vi.fn(),
|
||||||
} as unknown as Models;
|
} as unknown as Models;
|
||||||
|
|
||||||
const mockConfig = {
|
|
||||||
getSessionId: () => 'test-session-id',
|
|
||||||
getTelemetryLogPromptsEnabled: () => true,
|
|
||||||
getUsageStatisticsEnabled: () => true,
|
|
||||||
getDebugMode: () => false,
|
|
||||||
getContentGeneratorConfig: () => ({
|
|
||||||
authType: 'oauth-personal',
|
|
||||||
model: 'test-model',
|
|
||||||
}),
|
|
||||||
setModel: vi.fn(),
|
|
||||||
flashFallbackHandler: undefined,
|
|
||||||
} as unknown as Config;
|
|
||||||
|
|
||||||
describe('GeminiChat', () => {
|
describe('GeminiChat', () => {
|
||||||
let chat: GeminiChat;
|
let chat: GeminiChat;
|
||||||
const model = 'gemini-pro';
|
let mockConfig: Config;
|
||||||
const config: GenerateContentConfig = {};
|
const config: GenerateContentConfig = {};
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
|
mockConfig = {
|
||||||
|
getSessionId: () => 'test-session-id',
|
||||||
|
getTelemetryLogPromptsEnabled: () => true,
|
||||||
|
getUsageStatisticsEnabled: () => true,
|
||||||
|
getDebugMode: () => false,
|
||||||
|
getContentGeneratorConfig: () => ({
|
||||||
|
authType: 'oauth-personal',
|
||||||
|
model: 'test-model',
|
||||||
|
}),
|
||||||
|
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||||
|
setModel: vi.fn(),
|
||||||
|
flashFallbackHandler: undefined,
|
||||||
|
} as unknown as Config;
|
||||||
|
|
||||||
// Disable 429 simulation for tests
|
// Disable 429 simulation for tests
|
||||||
setSimulate429(false);
|
setSimulate429(false);
|
||||||
// Reset history for each test by creating a new instance
|
// Reset history for each test by creating a new instance
|
||||||
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []);
|
chat = new GeminiChat(mockConfig, mockModelsModule, config, []);
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
vi.restoreAllMocks();
|
vi.restoreAllMocks();
|
||||||
|
vi.resetAllMocks();
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('sendMessage', () => {
|
describe('sendMessage', () => {
|
||||||
|
@ -203,7 +205,7 @@ describe('GeminiChat', () => {
|
||||||
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
|
chat.recordHistory(userInput, newModelOutput); // userInput here is for the *next* turn, but history is already primed
|
||||||
|
|
||||||
// Reset and set up a more realistic scenario for merging with existing history
|
// Reset and set up a more realistic scenario for merging with existing history
|
||||||
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, []);
|
chat = new GeminiChat(mockConfig, mockModelsModule, config, []);
|
||||||
const firstUserInput: Content = {
|
const firstUserInput: Content = {
|
||||||
role: 'user',
|
role: 'user',
|
||||||
parts: [{ text: 'First user input' }],
|
parts: [{ text: 'First user input' }],
|
||||||
|
@ -246,7 +248,7 @@ describe('GeminiChat', () => {
|
||||||
role: 'model',
|
role: 'model',
|
||||||
parts: [{ text: 'Initial model answer.' }],
|
parts: [{ text: 'Initial model answer.' }],
|
||||||
};
|
};
|
||||||
chat = new GeminiChat(mockConfig, mockModelsModule, model, config, [
|
chat = new GeminiChat(mockConfig, mockModelsModule, config, [
|
||||||
initialUser,
|
initialUser,
|
||||||
initialModel,
|
initialModel,
|
||||||
]);
|
]);
|
||||||
|
|
|
@ -138,7 +138,6 @@ export class GeminiChat {
|
||||||
constructor(
|
constructor(
|
||||||
private readonly config: Config,
|
private readonly config: Config,
|
||||||
private readonly contentGenerator: ContentGenerator,
|
private readonly contentGenerator: ContentGenerator,
|
||||||
private readonly model: string,
|
|
||||||
private readonly generationConfig: GenerateContentConfig = {},
|
private readonly generationConfig: GenerateContentConfig = {},
|
||||||
private history: Content[] = [],
|
private history: Content[] = [],
|
||||||
) {
|
) {
|
||||||
|
@ -168,7 +167,12 @@ export class GeminiChat {
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
logApiResponse(
|
logApiResponse(
|
||||||
this.config,
|
this.config,
|
||||||
new ApiResponseEvent(this.model, durationMs, usageMetadata, responseText),
|
new ApiResponseEvent(
|
||||||
|
this.config.getModel(),
|
||||||
|
durationMs,
|
||||||
|
usageMetadata,
|
||||||
|
responseText,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,7 +182,12 @@ export class GeminiChat {
|
||||||
|
|
||||||
logApiError(
|
logApiError(
|
||||||
this.config,
|
this.config,
|
||||||
new ApiErrorEvent(this.model, errorMessage, durationMs, errorType),
|
new ApiErrorEvent(
|
||||||
|
this.config.getModel(),
|
||||||
|
errorMessage,
|
||||||
|
durationMs,
|
||||||
|
errorType,
|
||||||
|
),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -192,7 +201,7 @@ export class GeminiChat {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
const currentModel = this.model;
|
const currentModel = this.config.getModel();
|
||||||
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||||
|
|
||||||
// Don't fallback if already using Flash model
|
// Don't fallback if already using Flash model
|
||||||
|
@ -244,7 +253,7 @@ export class GeminiChat {
|
||||||
const userContent = createUserContent(params.message);
|
const userContent = createUserContent(params.message);
|
||||||
const requestContents = this.getHistory(true).concat(userContent);
|
const requestContents = this.getHistory(true).concat(userContent);
|
||||||
|
|
||||||
this._logApiRequest(requestContents, this.model);
|
this._logApiRequest(requestContents, this.config.getModel());
|
||||||
|
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
let response: GenerateContentResponse;
|
let response: GenerateContentResponse;
|
||||||
|
@ -252,12 +261,23 @@ export class GeminiChat {
|
||||||
try {
|
try {
|
||||||
const apiCall = () =>
|
const apiCall = () =>
|
||||||
this.contentGenerator.generateContent({
|
this.contentGenerator.generateContent({
|
||||||
model: this.model,
|
model: this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
contents: requestContents,
|
contents: requestContents,
|
||||||
config: { ...this.generationConfig, ...params.config },
|
config: { ...this.generationConfig, ...params.config },
|
||||||
});
|
});
|
||||||
|
|
||||||
response = await retryWithBackoff(apiCall);
|
response = await retryWithBackoff(apiCall, {
|
||||||
|
shouldRetry: (error: Error) => {
|
||||||
|
if (error && error.message) {
|
||||||
|
if (error.message.includes('429')) return true;
|
||||||
|
if (error.message.match(/5\d{2}/)) return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
onPersistent429: async (authType?: string) =>
|
||||||
|
await this.handleFlashFallback(authType),
|
||||||
|
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||||
|
});
|
||||||
const durationMs = Date.now() - startTime;
|
const durationMs = Date.now() - startTime;
|
||||||
await this._logApiResponse(
|
await this._logApiResponse(
|
||||||
durationMs,
|
durationMs,
|
||||||
|
@ -326,14 +346,14 @@ export class GeminiChat {
|
||||||
await this.sendPromise;
|
await this.sendPromise;
|
||||||
const userContent = createUserContent(params.message);
|
const userContent = createUserContent(params.message);
|
||||||
const requestContents = this.getHistory(true).concat(userContent);
|
const requestContents = this.getHistory(true).concat(userContent);
|
||||||
this._logApiRequest(requestContents, this.model);
|
this._logApiRequest(requestContents, this.config.getModel());
|
||||||
|
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const apiCall = () =>
|
const apiCall = () =>
|
||||||
this.contentGenerator.generateContentStream({
|
this.contentGenerator.generateContentStream({
|
||||||
model: this.model,
|
model: this.config.getModel(),
|
||||||
contents: requestContents,
|
contents: requestContents,
|
||||||
config: { ...this.generationConfig, ...params.config },
|
config: { ...this.generationConfig, ...params.config },
|
||||||
});
|
});
|
||||||
|
|
|
@ -71,7 +71,6 @@ describe('checkNextSpeaker', () => {
|
||||||
chatInstance = new GeminiChat(
|
chatInstance = new GeminiChat(
|
||||||
mockConfigInstance,
|
mockConfigInstance,
|
||||||
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
|
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
|
||||||
'gemini-pro', // model name
|
|
||||||
{},
|
{},
|
||||||
[], // initial history
|
[], // initial history
|
||||||
);
|
);
|
||||||
|
|
|
@ -272,7 +272,7 @@ describe('retryWithBackoff', () => {
|
||||||
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
|
expect(fallbackCallback).toHaveBeenCalledWith('oauth-personal');
|
||||||
|
|
||||||
// Should retry again after fallback
|
// Should retry again after fallback
|
||||||
expect(mockFn).toHaveBeenCalledTimes(4); // 3 initial attempts + 1 after fallback
|
expect(mockFn).toHaveBeenCalledTimes(3); // 2 initial attempts + 1 after fallback
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should NOT trigger fallback for API key users', async () => {
|
it('should NOT trigger fallback for API key users', async () => {
|
||||||
|
|
|
@ -67,9 +67,9 @@ export async function retryWithBackoff<T>(
|
||||||
maxAttempts,
|
maxAttempts,
|
||||||
initialDelayMs,
|
initialDelayMs,
|
||||||
maxDelayMs,
|
maxDelayMs,
|
||||||
shouldRetry,
|
|
||||||
onPersistent429,
|
onPersistent429,
|
||||||
authType,
|
authType,
|
||||||
|
shouldRetry,
|
||||||
} = {
|
} = {
|
||||||
...DEFAULT_RETRY_OPTIONS,
|
...DEFAULT_RETRY_OPTIONS,
|
||||||
...options,
|
...options,
|
||||||
|
@ -93,28 +93,30 @@ export async function retryWithBackoff<T>(
|
||||||
consecutive429Count = 0;
|
consecutive429Count = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we have persistent 429s and a fallback callback for OAuth
|
||||||
|
if (
|
||||||
|
consecutive429Count >= 2 &&
|
||||||
|
onPersistent429 &&
|
||||||
|
authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL
|
||||||
|
) {
|
||||||
|
try {
|
||||||
|
const fallbackModel = await onPersistent429(authType);
|
||||||
|
if (fallbackModel) {
|
||||||
|
// Reset attempt counter and try with new model
|
||||||
|
attempt = 0;
|
||||||
|
consecutive429Count = 0;
|
||||||
|
currentDelay = initialDelayMs;
|
||||||
|
// With the model updated, we continue to the next attempt
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} catch (fallbackError) {
|
||||||
|
// If fallback fails, continue with original error
|
||||||
|
console.warn('Fallback to Flash model failed:', fallbackError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Check if we've exhausted retries or shouldn't retry
|
// Check if we've exhausted retries or shouldn't retry
|
||||||
if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
|
if (attempt >= maxAttempts || !shouldRetry(error as Error)) {
|
||||||
// If we have persistent 429s and a fallback callback for OAuth
|
|
||||||
if (
|
|
||||||
consecutive429Count >= 2 &&
|
|
||||||
onPersistent429 &&
|
|
||||||
authType === AuthType.LOGIN_WITH_GOOGLE_PERSONAL
|
|
||||||
) {
|
|
||||||
try {
|
|
||||||
const fallbackModel = await onPersistent429(authType);
|
|
||||||
if (fallbackModel) {
|
|
||||||
// Reset attempt counter and try with new model
|
|
||||||
attempt = 0;
|
|
||||||
consecutive429Count = 0;
|
|
||||||
currentDelay = initialDelayMs;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
} catch (fallbackError) {
|
|
||||||
// If fallback fails, continue with original error
|
|
||||||
console.warn('Fallback to Flash model failed:', fallbackError);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
throw error;
|
throw error;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue