feat(core): Add infinite loop protection to client (#2793)
This commit is contained in:
parent
3a995305c0
commit
e94decea39
|
@ -402,5 +402,183 @@ describe('Gemini Client (client.ts)', () => {
|
||||||
// Assert
|
// Assert
|
||||||
expect(finalResult).toBeInstanceOf(Turn);
|
expect(finalResult).toBeInstanceOf(Turn);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should stop infinite loop after MAX_TURNS when nextSpeaker always returns model', async () => {
|
||||||
|
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
|
||||||
|
const { checkNextSpeaker } = await import(
|
||||||
|
'../utils/nextSpeakerChecker.js'
|
||||||
|
);
|
||||||
|
const mockCheckNextSpeaker = vi.mocked(checkNextSpeaker);
|
||||||
|
mockCheckNextSpeaker.mockResolvedValue({
|
||||||
|
next_speaker: 'model',
|
||||||
|
reasoning: 'Test case - always continue',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock Turn to have no pending tool calls (which would allow nextSpeaker check)
|
||||||
|
const mockStream = (async function* () {
|
||||||
|
yield { type: 'content', value: 'Continue...' };
|
||||||
|
})();
|
||||||
|
mockTurnRunFn.mockReturnValue(mockStream);
|
||||||
|
|
||||||
|
const mockChat: Partial<GeminiChat> = {
|
||||||
|
addHistory: vi.fn(),
|
||||||
|
getHistory: vi.fn().mockReturnValue([]),
|
||||||
|
};
|
||||||
|
client['chat'] = mockChat as GeminiChat;
|
||||||
|
|
||||||
|
const mockGenerator: Partial<ContentGenerator> = {
|
||||||
|
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||||
|
};
|
||||||
|
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||||
|
|
||||||
|
// Use a signal that never gets aborted
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const signal = abortController.signal;
|
||||||
|
|
||||||
|
// Act - Start the stream that should loop
|
||||||
|
const stream = client.sendMessageStream(
|
||||||
|
[{ text: 'Start conversation' }],
|
||||||
|
signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Count how many stream events we get
|
||||||
|
let eventCount = 0;
|
||||||
|
let finalResult: Turn | undefined;
|
||||||
|
|
||||||
|
// Consume the stream and count iterations
|
||||||
|
while (true) {
|
||||||
|
const result = await stream.next();
|
||||||
|
if (result.done) {
|
||||||
|
finalResult = result.value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
eventCount++;
|
||||||
|
|
||||||
|
// Safety check to prevent actual infinite loop in test
|
||||||
|
if (eventCount > 200) {
|
||||||
|
abortController.abort();
|
||||||
|
throw new Error(
|
||||||
|
'Test exceeded expected event limit - possible actual infinite loop',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
expect(finalResult).toBeInstanceOf(Turn);
|
||||||
|
|
||||||
|
// Debug: Check how many times checkNextSpeaker was called
|
||||||
|
const callCount = mockCheckNextSpeaker.mock.calls.length;
|
||||||
|
|
||||||
|
// If infinite loop protection is working, checkNextSpeaker should be called many times
|
||||||
|
// but stop at MAX_TURNS (100). Since each recursive call should trigger checkNextSpeaker,
|
||||||
|
// we expect it to be called multiple times before hitting the limit
|
||||||
|
expect(mockCheckNextSpeaker).toHaveBeenCalled();
|
||||||
|
|
||||||
|
// The test should demonstrate that the infinite loop protection works:
|
||||||
|
// - If checkNextSpeaker is called many times (close to MAX_TURNS), it shows the loop was happening
|
||||||
|
// - If it's only called once, the recursive behavior might not be triggered
|
||||||
|
if (callCount === 0) {
|
||||||
|
throw new Error(
|
||||||
|
'checkNextSpeaker was never called - the recursive condition was not met',
|
||||||
|
);
|
||||||
|
} else if (callCount === 1) {
|
||||||
|
// This might be expected behavior if the turn has pending tool calls or other conditions prevent recursion
|
||||||
|
console.log(
|
||||||
|
'checkNextSpeaker called only once - no infinite loop occurred',
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
console.log(
|
||||||
|
`checkNextSpeaker called ${callCount} times - infinite loop protection worked`,
|
||||||
|
);
|
||||||
|
// If called multiple times, we expect it to be stopped before MAX_TURNS
|
||||||
|
expect(callCount).toBeLessThanOrEqual(100); // Should not exceed MAX_TURNS
|
||||||
|
}
|
||||||
|
|
||||||
|
// The stream should produce events and eventually terminate
|
||||||
|
expect(eventCount).toBeGreaterThanOrEqual(1);
|
||||||
|
expect(eventCount).toBeLessThan(200); // Should not exceed our safety limit
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should respect MAX_TURNS limit even when turns parameter is set to a large value', async () => {
|
||||||
|
// This test verifies that the infinite loop protection works even when
|
||||||
|
// someone tries to bypass it by calling with a very large turns value
|
||||||
|
|
||||||
|
// Get the mocked checkNextSpeaker function and configure it to trigger infinite loop
|
||||||
|
const { checkNextSpeaker } = await import(
|
||||||
|
'../utils/nextSpeakerChecker.js'
|
||||||
|
);
|
||||||
|
const mockCheckNextSpeaker = vi.mocked(checkNextSpeaker);
|
||||||
|
mockCheckNextSpeaker.mockResolvedValue({
|
||||||
|
next_speaker: 'model',
|
||||||
|
reasoning: 'Test case - always continue',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock Turn to have no pending tool calls (which would allow nextSpeaker check)
|
||||||
|
const mockStream = (async function* () {
|
||||||
|
yield { type: 'content', value: 'Continue...' };
|
||||||
|
})();
|
||||||
|
mockTurnRunFn.mockReturnValue(mockStream);
|
||||||
|
|
||||||
|
const mockChat: Partial<GeminiChat> = {
|
||||||
|
addHistory: vi.fn(),
|
||||||
|
getHistory: vi.fn().mockReturnValue([]),
|
||||||
|
};
|
||||||
|
client['chat'] = mockChat as GeminiChat;
|
||||||
|
|
||||||
|
const mockGenerator: Partial<ContentGenerator> = {
|
||||||
|
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||||
|
};
|
||||||
|
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||||
|
|
||||||
|
// Use a signal that never gets aborted
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const signal = abortController.signal;
|
||||||
|
|
||||||
|
// Act - Start the stream with an extremely high turns value
|
||||||
|
// This simulates a case where the turns protection is bypassed
|
||||||
|
const stream = client.sendMessageStream(
|
||||||
|
[{ text: 'Start conversation' }],
|
||||||
|
signal,
|
||||||
|
Number.MAX_SAFE_INTEGER, // Bypass the MAX_TURNS protection
|
||||||
|
);
|
||||||
|
|
||||||
|
// Count how many stream events we get
|
||||||
|
let eventCount = 0;
|
||||||
|
const maxTestIterations = 1000; // Higher limit to show the loop continues
|
||||||
|
|
||||||
|
// Consume the stream and count iterations
|
||||||
|
try {
|
||||||
|
while (true) {
|
||||||
|
const result = await stream.next();
|
||||||
|
if (result.done) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
eventCount++;
|
||||||
|
|
||||||
|
// This test should hit this limit, demonstrating the infinite loop
|
||||||
|
if (eventCount > maxTestIterations) {
|
||||||
|
abortController.abort();
|
||||||
|
// This is the expected behavior - we hit the infinite loop
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the test framework times out, that also demonstrates the infinite loop
|
||||||
|
console.error('Test timed out or errored:', error);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that the fix works - the loop should stop at MAX_TURNS
|
||||||
|
const callCount = mockCheckNextSpeaker.mock.calls.length;
|
||||||
|
|
||||||
|
// With the fix: even when turns is set to a very high value,
|
||||||
|
// the loop should stop at MAX_TURNS (100)
|
||||||
|
expect(callCount).toBeLessThanOrEqual(100); // Should not exceed MAX_TURNS
|
||||||
|
expect(eventCount).toBeLessThanOrEqual(200); // Should have reasonable number of events
|
||||||
|
|
||||||
|
console.log(
|
||||||
|
`Infinite loop protection working: checkNextSpeaker called ${callCount} times, ` +
|
||||||
|
`${eventCount} events generated (properly bounded by MAX_TURNS)`,
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
|
@ -219,7 +219,9 @@ export class GeminiClient {
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
turns: number = this.MAX_TURNS,
|
turns: number = this.MAX_TURNS,
|
||||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||||
if (!turns) {
|
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops
|
||||||
|
const boundedTurns = Math.min(turns, this.MAX_TURNS);
|
||||||
|
if (!boundedTurns) {
|
||||||
return new Turn(this.getChat());
|
return new Turn(this.getChat());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -242,7 +244,7 @@ export class GeminiClient {
|
||||||
const nextRequest = [{ text: 'Please continue.' }];
|
const nextRequest = [{ text: 'Please continue.' }];
|
||||||
// This recursive call's events will be yielded out, but the final
|
// This recursive call's events will be yielded out, but the final
|
||||||
// turn object will be from the top-level call.
|
// turn object will be from the top-level call.
|
||||||
yield* this.sendMessageStream(nextRequest, signal, turns - 1);
|
yield* this.sendMessageStream(nextRequest, signal, boundedTurns - 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return turn;
|
return turn;
|
||||||
|
|
Loading…
Reference in New Issue