From 8c46108a852128d1d0792c149746631d83fc58cf Mon Sep 17 00:00:00 2001 From: "N. Taylor Mullen" Date: Fri, 30 May 2025 10:57:00 -0700 Subject: [PATCH] feat: Implement retry with backoff for API calls (#613) --- packages/server/src/core/client.ts | 25 ++- packages/server/src/core/geminiChat.ts | 48 +++-- packages/server/src/tools/web-fetch.ts | 26 +-- packages/server/src/tools/web-search.ts | 18 +- packages/server/src/utils/retry.test.ts | 238 ++++++++++++++++++++++++ packages/server/src/utils/retry.ts | 227 ++++++++++++++++++++++ 6 files changed, 542 insertions(+), 40 deletions(-) create mode 100644 packages/server/src/utils/retry.test.ts create mode 100644 packages/server/src/utils/retry.ts diff --git a/packages/server/src/core/client.ts b/packages/server/src/core/client.ts index 69b815ab..9006c675 100644 --- a/packages/server/src/core/client.ts +++ b/packages/server/src/core/client.ts @@ -23,6 +23,7 @@ import { getResponseText } from '../utils/generateContentResponseUtilities.js'; import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js'; import { reportError } from '../utils/errorReporting.js'; import { GeminiChat } from './geminiChat.js'; +import { retryWithBackoff } from '../utils/retry.js'; export class GeminiClient { private client: GoogleGenAI; @@ -194,16 +195,20 @@ export class GeminiClient { ...config, }; - const result = await this.client.models.generateContent({ - model, - config: { - ...requestConfig, - systemInstruction, - responseSchema: schema, - responseMimeType: 'application/json', - }, - contents, - }); + const apiCall = () => + this.client.models.generateContent({ + model, + config: { + ...requestConfig, + systemInstruction, + responseSchema: schema, + responseMimeType: 'application/json', + }, + contents, + }); + + const result = await retryWithBackoff(apiCall); + const text = getResponseText(result); if (!text) { const error = new Error( diff --git a/packages/server/src/core/geminiChat.ts b/packages/server/src/core/geminiChat.ts index 877d0825..b34b6f35 100644 --- a/packages/server/src/core/geminiChat.ts +++ b/packages/server/src/core/geminiChat.ts @@ -16,6 +16,7 @@ import { GoogleGenAI, createUserContent, } from '@google/genai'; +import { retryWithBackoff } from '../utils/retry.js'; import { isFunctionResponse } from '../utils/messageInspectors.js'; /** @@ -152,11 +153,16 @@ export class GeminiChat { ): Promise { await this.sendPromise; const userContent = createUserContent(params.message); - const responsePromise = this.modelsModule.generateContent({ - model: this.model, - contents: this.getHistory(true).concat(userContent), - config: { ...this.config, ...params.config }, - }); + + const apiCall = () => + this.modelsModule.generateContent({ + model: this.model, + contents: this.getHistory(true).concat(userContent), + config: { ...this.config, ...params.config }, + }); + + const responsePromise = retryWithBackoff(apiCall); + this.sendPromise = (async () => { const response = await responsePromise; const outputContent = response.candidates?.[0]?.content; @@ -216,19 +222,37 @@ export class GeminiChat { ): Promise> { await this.sendPromise; const userContent = createUserContent(params.message); - const streamResponse = this.modelsModule.generateContentStream({ - model: this.model, - contents: this.getHistory(true).concat(userContent), - config: { ...this.config, ...params.config }, + + const apiCall = () => + this.modelsModule.generateContentStream({ + model: this.model, + contents: this.getHistory(true).concat(userContent), + config: { ...this.config, ...params.config }, + }); + + // Note: Retrying streams can be complex. If generateContentStream itself doesn't handle retries + // for transient issues internally before yielding the async generator, this retry will re-initiate + // the stream. For simple 429/500 errors on initial call, this is fine. + // If errors occur mid-stream, this setup won't resume the stream; it will restart it. + const streamResponse = await retryWithBackoff(apiCall, { + shouldRetry: (error: Error) => { + // Check error messages for status codes, or specific error names if known + if (error && error.message) { + if (error.message.includes('429')) return true; + if (error.message.match(/5\d{2}/)) return true; + } + return false; // Don't retry other errors by default + }, }); + // Resolve the internal tracking of send completion promise - `sendPromise` // for both success and failure response. The actual failure is still // propagated by the `await streamResponse`. - this.sendPromise = streamResponse + this.sendPromise = Promise.resolve(streamResponse) .then(() => undefined) .catch(() => undefined); - const response = await streamResponse; - const result = this.processStreamResponse(response, userContent); + + const result = this.processStreamResponse(streamResponse, userContent); return result; } diff --git a/packages/server/src/tools/web-fetch.ts b/packages/server/src/tools/web-fetch.ts index 7a8a1515..24617902 100644 --- a/packages/server/src/tools/web-fetch.ts +++ b/packages/server/src/tools/web-fetch.ts @@ -10,6 +10,7 @@ import { BaseTool, ToolResult } from './tools.js'; import { getErrorMessage } from '../utils/errors.js'; import { Config } from '../config/config.js'; import { getResponseText } from '../utils/generateContentResponseUtilities.js'; +import { retryWithBackoff } from '../utils/retry.js'; // Interfaces for grounding metadata (similar to web-search.ts) interface GroundingChunkWeb { @@ -121,18 +122,21 @@ export class WebFetchTool extends BaseTool { const userPrompt = params.prompt; try { - const response = await this.ai.models.generateContent({ - model: this.modelName, - contents: [ - { - role: 'user', - parts: [{ text: userPrompt }], + const apiCall = () => + this.ai.models.generateContent({ + model: this.modelName, + contents: [ + { + role: 'user', + parts: [{ text: userPrompt }], + }, + ], + config: { + tools: [{ urlContext: {} }], }, - ], - config: { - tools: [{ urlContext: {} }], - }, - }); + }); + + const response = await retryWithBackoff(apiCall); console.debug( `[WebFetchTool] Full response for prompt "${userPrompt.substring(0, 50)}...":`, diff --git a/packages/server/src/tools/web-search.ts b/packages/server/src/tools/web-search.ts index b690146d..ed2f341f 100644 --- a/packages/server/src/tools/web-search.ts +++ b/packages/server/src/tools/web-search.ts @@ -11,6 +11,7 @@ import { SchemaValidator } from '../utils/schemaValidator.js'; import { getErrorMessage } from '../utils/errors.js'; import { Config } from '../config/config.js'; import { getResponseText } from '../utils/generateContentResponseUtilities.js'; +import { retryWithBackoff } from '../utils/retry.js'; interface GroundingChunkWeb { uri?: string; @@ -121,13 +122,16 @@ export class WebSearchTool extends BaseTool< } try { - const response = await this.ai.models.generateContent({ - model: this.modelName, - contents: [{ role: 'user', parts: [{ text: params.query }] }], - config: { - tools: [{ googleSearch: {} }], - }, - }); + const apiCall = () => + this.ai.models.generateContent({ + model: this.modelName, + contents: [{ role: 'user', parts: [{ text: params.query }] }], + config: { + tools: [{ googleSearch: {} }], + }, + }); + + const response = await retryWithBackoff(apiCall); const responseText = getResponseText(response); const groundingMetadata = response.candidates?.[0]?.groundingMetadata; diff --git a/packages/server/src/utils/retry.test.ts b/packages/server/src/utils/retry.test.ts new file mode 100644 index 00000000..ea344d60 --- /dev/null +++ b/packages/server/src/utils/retry.test.ts @@ -0,0 +1,238 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { retryWithBackoff } from './retry.js'; + +// Define an interface for the error with a status property +interface HttpError extends Error { + status?: number; +} + +// Helper to create a mock function that fails a certain number of times +const createFailingFunction = ( + failures: number, + successValue: string = 'success', +) => { + let attempts = 0; + return vi.fn(async () => { + attempts++; + if (attempts <= failures) { + // Simulate a retryable error + const error: HttpError = new Error(`Simulated error attempt ${attempts}`); + error.status = 500; // Simulate a server error + throw error; + } + return successValue; + }); +}; + +// Custom error for testing non-retryable conditions +class NonRetryableError extends Error { + constructor(message: string) { + super(message); + this.name = 'NonRetryableError'; + } +} + +describe('retryWithBackoff', () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should return the result on the first attempt if successful', async () => { + const mockFn = createFailingFunction(0); + const result = await retryWithBackoff(mockFn); + expect(result).toBe('success'); + expect(mockFn).toHaveBeenCalledTimes(1); + }); + + it('should retry and succeed if failures are within maxAttempts', async () => { + const mockFn = createFailingFunction(2); + const promise = retryWithBackoff(mockFn, { + maxAttempts: 3, + initialDelayMs: 10, + }); + + await vi.runAllTimersAsync(); // Ensure all delays and retries complete + + const result = await promise; + expect(result).toBe('success'); + expect(mockFn).toHaveBeenCalledTimes(3); + }); + + it('should throw an error if all attempts fail', async () => { + const mockFn = createFailingFunction(3); + + // 1. Start the retryable operation, which returns a promise. + const promise = retryWithBackoff(mockFn, { + maxAttempts: 3, + initialDelayMs: 10, + }); + + // 2. IMPORTANT: Attach the rejection expectation to the promise *immediately*. + // This ensures a 'catch' handler is present before the promise can reject. + // The result is a new promise that resolves when the assertion is met. + const assertionPromise = expect(promise).rejects.toThrow( + 'Simulated error attempt 3', + ); + + // 3. Now, advance the timers. This will trigger the retries and the + // eventual rejection. The handler attached in step 2 will catch it. + await vi.runAllTimersAsync(); + + // 4. Await the assertion promise itself to ensure the test was successful. + await assertionPromise; + + // 5. Finally, assert the number of calls. + expect(mockFn).toHaveBeenCalledTimes(3); + }); + + it('should not retry if shouldRetry returns false', async () => { + const mockFn = vi.fn(async () => { + throw new NonRetryableError('Non-retryable error'); + }); + const shouldRetry = (error: Error) => !(error instanceof NonRetryableError); + + const promise = retryWithBackoff(mockFn, { + shouldRetry, + initialDelayMs: 10, + }); + + await expect(promise).rejects.toThrow('Non-retryable error'); + expect(mockFn).toHaveBeenCalledTimes(1); + }); + + it('should use default shouldRetry if not provided, retrying on 429', async () => { + const mockFn = vi.fn(async () => { + const error = new Error('Too Many Requests') as any; + error.status = 429; + throw error; + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 10, + }); + + // Attach the rejection expectation *before* running timers + const assertionPromise = + expect(promise).rejects.toThrow('Too Many Requests'); + + // Run timers to trigger retries and eventual rejection + await vi.runAllTimersAsync(); + + // Await the assertion + await assertionPromise; + + expect(mockFn).toHaveBeenCalledTimes(2); + }); + + it('should use default shouldRetry if not provided, not retrying on 400', async () => { + const mockFn = vi.fn(async () => { + const error = new Error('Bad Request') as any; + error.status = 400; + throw error; + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 10, + }); + await expect(promise).rejects.toThrow('Bad Request'); + expect(mockFn).toHaveBeenCalledTimes(1); + }); + + it('should respect maxDelayMs', async () => { + const mockFn = createFailingFunction(3); + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 4, + initialDelayMs: 100, + maxDelayMs: 250, // Max delay is less than 100 * 2 * 2 = 400 + }); + + await vi.advanceTimersByTimeAsync(1000); // Advance well past all delays + await promise; + + const delays = setTimeoutSpy.mock.calls.map((call) => call[1] as number); + + // Delays should be around initial, initial*2, maxDelay (due to cap) + // Jitter makes exact assertion hard, so we check ranges / caps + expect(delays.length).toBe(3); + expect(delays[0]).toBeGreaterThanOrEqual(100 * 0.7); + expect(delays[0]).toBeLessThanOrEqual(100 * 1.3); + expect(delays[1]).toBeGreaterThanOrEqual(200 * 0.7); + expect(delays[1]).toBeLessThanOrEqual(200 * 1.3); + // The third delay should be capped by maxDelayMs (250ms), accounting for jitter + expect(delays[2]).toBeGreaterThanOrEqual(250 * 0.7); + expect(delays[2]).toBeLessThanOrEqual(250 * 1.3); + + setTimeoutSpy.mockRestore(); + }); + + it('should handle jitter correctly, ensuring varied delays', async () => { + let mockFn = createFailingFunction(5); + const setTimeoutSpy = vi.spyOn(global, 'setTimeout'); + + // Run retryWithBackoff multiple times to observe jitter + const runRetry = () => + retryWithBackoff(mockFn, { + maxAttempts: 2, // Only one retry, so one delay + initialDelayMs: 100, + maxDelayMs: 1000, + }); + + // We expect rejections as mockFn fails 5 times + const promise1 = runRetry(); + // Attach the rejection expectation *before* running timers + const assertionPromise1 = expect(promise1).rejects.toThrow(); + await vi.runAllTimersAsync(); // Advance for the delay in the first runRetry + await assertionPromise1; + + const firstDelaySet = setTimeoutSpy.mock.calls.map( + (call) => call[1] as number, + ); + setTimeoutSpy.mockClear(); // Clear calls for the next run + + // Reset mockFn to reset its internal attempt counter for the next run + mockFn = createFailingFunction(5); // Re-initialize with 5 failures + + const promise2 = runRetry(); + // Attach the rejection expectation *before* running timers + const assertionPromise2 = expect(promise2).rejects.toThrow(); + await vi.runAllTimersAsync(); // Advance for the delay in the second runRetry + await assertionPromise2; + + const secondDelaySet = setTimeoutSpy.mock.calls.map( + (call) => call[1] as number, + ); + + // Check that the delays are not exactly the same due to jitter + // This is a probabilistic test, but with +/-30% jitter, it's highly likely they differ. + if (firstDelaySet.length > 0 && secondDelaySet.length > 0) { + // Check the first delay of each set + expect(firstDelaySet[0]).not.toBe(secondDelaySet[0]); + } else { + // If somehow no delays were captured (e.g. test setup issue), fail explicitly + throw new Error('Delays were not captured for jitter test'); + } + + // Ensure delays are within the expected jitter range [70, 130] for initialDelayMs = 100 + [...firstDelaySet, ...secondDelaySet].forEach((d) => { + expect(d).toBeGreaterThanOrEqual(100 * 0.7); + expect(d).toBeLessThanOrEqual(100 * 1.3); + }); + + setTimeoutSpy.mockRestore(); + }); +}); diff --git a/packages/server/src/utils/retry.ts b/packages/server/src/utils/retry.ts new file mode 100644 index 00000000..1e7d5bcb --- /dev/null +++ b/packages/server/src/utils/retry.ts @@ -0,0 +1,227 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export interface RetryOptions { + maxAttempts: number; + initialDelayMs: number; + maxDelayMs: number; + shouldRetry: (error: Error) => boolean; +} + +const DEFAULT_RETRY_OPTIONS: RetryOptions = { + maxAttempts: 5, + initialDelayMs: 5000, + maxDelayMs: 30000, // 30 seconds + shouldRetry: defaultShouldRetry, +}; + +/** + * Default predicate function to determine if a retry should be attempted. + * Retries on 429 (Too Many Requests) and 5xx server errors. + * @param error The error object. + * @returns True if the error is a transient error, false otherwise. + */ +function defaultShouldRetry(error: Error | unknown): boolean { + // Check for common transient error status codes either in message or a status property + if (error && typeof (error as { status?: number }).status === 'number') { + const status = (error as { status: number }).status; + if (status === 429 || (status >= 500 && status < 600)) { + return true; + } + } + if (error instanceof Error && error.message) { + if (error.message.includes('429')) return true; + if (error.message.match(/5\d{2}/)) return true; + } + return false; +} + +/** + * Delays execution for a specified number of milliseconds. + * @param ms The number of milliseconds to delay. + * @returns A promise that resolves after the delay. + */ +function delay(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +/** + * Retries a function with exponential backoff and jitter. + * @param fn The asynchronous function to retry. + * @param options Optional retry configuration. + * @returns A promise that resolves with the result of the function if successful. + * @throws The last error encountered if all attempts fail. + */ +export async function retryWithBackoff( + fn: () => Promise, + options?: Partial, +): Promise { + const { maxAttempts, initialDelayMs, maxDelayMs, shouldRetry } = { + ...DEFAULT_RETRY_OPTIONS, + ...options, + }; + + let attempt = 0; + let currentDelay = initialDelayMs; + + while (attempt < maxAttempts) { + attempt++; + try { + return await fn(); + } catch (error) { + if (attempt >= maxAttempts || !shouldRetry(error as Error)) { + throw error; + } + + const { delayDurationMs, errorStatus } = getDelayDurationAndStatus(error); + + if (delayDurationMs > 0) { + // Respect Retry-After header if present and parsed + console.warn( + `Attempt ${attempt} failed with status ${errorStatus ?? 'unknown'}. Retrying after explicit delay of ${delayDurationMs}ms...`, + error, + ); + await delay(delayDurationMs); + // Reset currentDelay for next potential non-429 error, or if Retry-After is not present next time + currentDelay = initialDelayMs; + } else { + // Fallback to exponential backoff with jitter + logRetryAttempt(attempt, error, errorStatus); + // Add jitter: +/- 30% of currentDelay + const jitter = currentDelay * 0.3 * (Math.random() * 2 - 1); + const delayWithJitter = Math.max(0, currentDelay + jitter); + await delay(delayWithJitter); + currentDelay = Math.min(maxDelayMs, currentDelay * 2); + } + } + } + // This line should theoretically be unreachable due to the throw in the catch block. + // Added for type safety and to satisfy the compiler that a promise is always returned. + throw new Error('Retry attempts exhausted'); +} + +/** + * Extracts the HTTP status code from an error object. + * @param error The error object. + * @returns The HTTP status code, or undefined if not found. + */ +function getErrorStatus(error: unknown): number | undefined { + if (typeof error === 'object' && error !== null) { + if ('status' in error && typeof error.status === 'number') { + return error.status; + } + // Check for error.response.status (common in axios errors) + if ( + 'response' in error && + typeof (error as { response?: unknown }).response === 'object' && + (error as { response?: unknown }).response !== null + ) { + const response = ( + error as { response: { status?: unknown; headers?: unknown } } + ).response; + if ('status' in response && typeof response.status === 'number') { + return response.status; + } + } + } + return undefined; +} + +/** + * Extracts the Retry-After delay from an error object's headers. + * @param error The error object. + * @returns The delay in milliseconds, or 0 if not found or invalid. + */ +function getRetryAfterDelayMs(error: unknown): number { + if (typeof error === 'object' && error !== null) { + // Check for error.response.headers (common in axios errors) + if ( + 'response' in error && + typeof (error as { response?: unknown }).response === 'object' && + (error as { response?: unknown }).response !== null + ) { + const response = (error as { response: { headers?: unknown } }).response; + if ( + 'headers' in response && + typeof response.headers === 'object' && + response.headers !== null + ) { + const headers = response.headers as { 'retry-after'?: unknown }; + const retryAfterHeader = headers['retry-after']; + if (typeof retryAfterHeader === 'string') { + const retryAfterSeconds = parseInt(retryAfterHeader, 10); + if (!isNaN(retryAfterSeconds)) { + return retryAfterSeconds * 1000; + } + // It might be an HTTP date + const retryAfterDate = new Date(retryAfterHeader); + if (!isNaN(retryAfterDate.getTime())) { + return Math.max(0, retryAfterDate.getTime() - Date.now()); + } + } + } + } + } + return 0; +} + +/** + * Determines the delay duration based on the error, prioritizing Retry-After header. + * @param error The error object. + * @returns An object containing the delay duration in milliseconds and the error status. + */ +function getDelayDurationAndStatus(error: unknown): { + delayDurationMs: number; + errorStatus: number | undefined; +} { + const errorStatus = getErrorStatus(error); + let delayDurationMs = 0; + + if (errorStatus === 429) { + delayDurationMs = getRetryAfterDelayMs(error); + } + return { delayDurationMs, errorStatus }; +} + +/** + * Logs a message for a retry attempt when using exponential backoff. + * @param attempt The current attempt number. + * @param error The error that caused the retry. + * @param errorStatus The HTTP status code of the error, if available. + */ +function logRetryAttempt( + attempt: number, + error: unknown, + errorStatus?: number, +): void { + let message = `Attempt ${attempt} failed. Retrying with backoff...`; + if (errorStatus) { + message = `Attempt ${attempt} failed with status ${errorStatus}. Retrying with backoff...`; + } + + if (errorStatus === 429) { + console.warn(message, error); + } else if (errorStatus && errorStatus >= 500 && errorStatus < 600) { + console.error(message, error); + } else if (error instanceof Error) { + // Fallback for errors that might not have a status but have a message + if (error.message.includes('429')) { + console.warn( + `Attempt ${attempt} failed with 429 error (no Retry-After header). Retrying with backoff...`, + error, + ); + } else if (error.message.match(/5\d{2}/)) { + console.error( + `Attempt ${attempt} failed with 5xx error. Retrying with backoff...`, + error, + ); + } else { + console.warn(message, error); // Default to warn for other errors + } + } else { + console.warn(message, error); // Default to warn if error type is unknown + } +}