Introduce loop detection service that breaks simple loop (#3919)

Co-authored-by: Scott Densmore <scottdensmore@mac.com>
Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
Sandy Tao 2025-07-14 20:25:16 -07:00 committed by GitHub
parent 7ffe8038ef
commit 734da8b9d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 456 additions and 1 deletions

View File

@ -141,6 +141,8 @@ export const useGeminiStream = (
[toolCalls], [toolCalls],
); );
const loopDetectedRef = useRef(false);
const onExec = useCallback(async (done: Promise<void>) => { const onExec = useCallback(async (done: Promise<void>) => {
setIsResponding(true); setIsResponding(true);
await done; await done;
@ -450,6 +452,16 @@ export const useGeminiStream = (
[addItem, config], [addItem, config],
); );
// Appends an informational history entry telling the user that the request
// was stopped because loop detection fired.
const handleLoopDetectedEvent = useCallback(() => {
  const loopNotice = {
    type: 'info' as const,
    text: `A potential loop was detected. This can happen due to repetitive tool calls or other model behavior. The request has been halted.`,
  };
  addItem(loopNotice, Date.now());
}, [addItem]);
const processGeminiStreamEvents = useCallback( const processGeminiStreamEvents = useCallback(
async ( async (
stream: AsyncIterable<GeminiEvent>, stream: AsyncIterable<GeminiEvent>,
@ -489,6 +501,11 @@ export const useGeminiStream = (
case ServerGeminiEventType.MaxSessionTurns: case ServerGeminiEventType.MaxSessionTurns:
handleMaxSessionTurnsEvent(); handleMaxSessionTurnsEvent();
break; break;
case ServerGeminiEventType.LoopDetected:
// handle later because we want to move pending history to history
// before we add loop detected message to history
loopDetectedRef.current = true;
break;
default: { default: {
// enforces exhaustive switch-case // enforces exhaustive switch-case
const unreachable: never = event; const unreachable: never = event;
@ -579,6 +596,10 @@ export const useGeminiStream = (
addItem(pendingHistoryItemRef.current, userMessageTimestamp); addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null); setPendingHistoryItem(null);
} }
if (loopDetectedRef.current) {
loopDetectedRef.current = false;
handleLoopDetectedEvent();
}
} catch (error: unknown) { } catch (error: unknown) {
if (error instanceof UnauthorizedError) { if (error instanceof UnauthorizedError) {
onAuthError(); onAuthError();
@ -616,6 +637,7 @@ export const useGeminiStream = (
config, config,
startNewPrompt, startNewPrompt,
getPromptCount, getPromptCount,
handleLoopDetectedEvent,
], ],
); );

View File

@ -40,6 +40,7 @@ import {
} from './contentGenerator.js'; } from './contentGenerator.js';
import { ProxyAgent, setGlobalDispatcher } from 'undici'; import { ProxyAgent, setGlobalDispatcher } from 'undici';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { LoopDetectionService } from '../services/loopDetectionService.js';
function isThinkingSupported(model: string) { function isThinkingSupported(model: string) {
if (model.startsWith('gemini-2.5')) return true; if (model.startsWith('gemini-2.5')) return true;
@ -100,6 +101,9 @@ export class GeminiClient {
*/ */
private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3; private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3;
private readonly loopDetector = new LoopDetectionService();
private lastPromptId?: string;
constructor(private config: Config) { constructor(private config: Config) {
if (config.getProxy()) { if (config.getProxy()) {
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string)); setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
@ -272,6 +276,10 @@ export class GeminiClient {
turns: number = this.MAX_TURNS, turns: number = this.MAX_TURNS,
originalModel?: string, originalModel?: string,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> { ): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (this.lastPromptId !== prompt_id) {
this.loopDetector.reset();
this.lastPromptId = prompt_id;
}
this.sessionTurnCount++; this.sessionTurnCount++;
if ( if (
this.config.getMaxSessionTurns() > 0 && this.config.getMaxSessionTurns() > 0 &&
@ -297,6 +305,10 @@ export class GeminiClient {
const turn = new Turn(this.getChat(), prompt_id); const turn = new Turn(this.getChat(), prompt_id);
const resultStream = turn.run(request, signal); const resultStream = turn.run(request, signal);
for await (const event of resultStream) { for await (const event of resultStream) {
if (this.loopDetector.addAndCheck(event)) {
yield { type: GeminiEventType.LoopDetected };
return turn;
}
yield event; yield event;
} }
if (!turn.pendingToolCalls.length && signal && !signal.aborted) { if (!turn.pendingToolCalls.length && signal && !signal.aborted) {

View File

@ -49,6 +49,7 @@ export enum GeminiEventType {
ChatCompressed = 'chat_compressed', ChatCompressed = 'chat_compressed',
Thought = 'thought', Thought = 'thought',
MaxSessionTurns = 'max_session_turns', MaxSessionTurns = 'max_session_turns',
LoopDetected = 'loop_detected',
} }
export interface StructuredError { export interface StructuredError {
@ -133,6 +134,10 @@ export type ServerGeminiMaxSessionTurnsEvent = {
type: GeminiEventType.MaxSessionTurns; type: GeminiEventType.MaxSessionTurns;
}; };
/**
 * Stream event signalling that loop detection fired. It carries no payload;
 * the `type` tag alone identifies it.
 */
export type ServerGeminiLoopDetectedEvent = {
type: GeminiEventType.LoopDetected;
};
// The original union type, now composed of the individual types // The original union type, now composed of the individual types
export type ServerGeminiStreamEvent = export type ServerGeminiStreamEvent =
| ServerGeminiContentEvent | ServerGeminiContentEvent
@ -143,7 +148,8 @@ export type ServerGeminiStreamEvent =
| ServerGeminiErrorEvent | ServerGeminiErrorEvent
| ServerGeminiChatCompressedEvent | ServerGeminiChatCompressedEvent
| ServerGeminiThoughtEvent | ServerGeminiThoughtEvent
| ServerGeminiMaxSessionTurnsEvent; | ServerGeminiMaxSessionTurnsEvent
| ServerGeminiLoopDetectedEvent;
// A turn manages the agentic loop turn within the server context. // A turn manages the agentic loop turn within the server context.
export class Turn { export class Turn {

View File

@ -0,0 +1,294 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach } from 'vitest';
import { LoopDetectionService } from './loopDetectionService.js';
import {
GeminiEventType,
ServerGeminiContentEvent,
ServerGeminiToolCallRequestEvent,
} from '../core/turn.js';
import { ServerGeminiStreamEvent } from '../core/turn.js';
// Local copies of the thresholds the tests count against. They are declared
// here rather than imported, so they must be kept equal to the values used by
// the service under test.
const TOOL_CALL_LOOP_THRESHOLD = 5;
const CONTENT_LOOP_THRESHOLD = 10;
describe('LoopDetectionService', () => {
// Service under test. A new instance is created before every test so that
// repetition counters and buffered content never carry over between cases.
let service: LoopDetectionService;
beforeEach(() => {
service = new LoopDetectionService();
});
/**
 * Builds a ToolCallRequest stream event for the given tool name and
 * arguments. The remaining fields are fixed placeholder values.
 */
const createToolCallRequestEvent = (
  name: string,
  args: Record<string, unknown>,
): ServerGeminiToolCallRequestEvent => {
  const value = {
    name,
    args,
    callId: 'test-id',
    isClientInitiated: false,
    prompt_id: 'test-prompt-id',
  };
  return { type: GeminiEventType.ToolCallRequest, value };
};
/** Wraps a text chunk in a Content stream event. */
const createContentEvent = (content: string): ServerGeminiContentEvent => {
  const contentEvent: ServerGeminiContentEvent = {
    type: GeminiEventType.Content,
    value: content,
  };
  return contentEvent;
};
describe('Tool Call Loop Detection', () => {
  // Every "identical call" case replays this exact tool name / argument pair.
  const makeRepeatedCall = () =>
    createToolCallRequestEvent('testTool', { param: 'value' });

  it(`should not detect a loop for fewer than TOOL_CALL_LOOP_THRESHOLD identical calls`, () => {
    const repeatedCall = makeRepeatedCall();
    let remaining = TOOL_CALL_LOOP_THRESHOLD - 1;
    while (remaining > 0) {
      expect(service.addAndCheck(repeatedCall)).toBe(false);
      remaining -= 1;
    }
  });

  it(`should detect a loop on the TOOL_CALL_LOOP_THRESHOLD-th identical call`, () => {
    const repeatedCall = makeRepeatedCall();
    // Prime the service with one call fewer than the threshold.
    let remaining = TOOL_CALL_LOOP_THRESHOLD - 1;
    while (remaining > 0) {
      service.addAndCheck(repeatedCall);
      remaining -= 1;
    }
    const loopDetected = service.addAndCheck(repeatedCall);
    expect(loopDetected).toBe(true);
  });

  it('should detect a loop on subsequent identical calls', () => {
    const repeatedCall = makeRepeatedCall();
    // Reach the threshold first, then check that the next call still reports it.
    let remaining = TOOL_CALL_LOOP_THRESHOLD;
    while (remaining > 0) {
      service.addAndCheck(repeatedCall);
      remaining -= 1;
    }
    const loopDetected = service.addAndCheck(repeatedCall);
    expect(loopDetected).toBe(true);
  });

  it('should not detect a loop for different tool calls', () => {
    // Three calls that differ from their neighbours by argument or tool name.
    const rotation = [
      createToolCallRequestEvent('testTool', { param: 'value1' }),
      createToolCallRequestEvent('testTool', { param: 'value2' }),
      createToolCallRequestEvent('anotherTool', { param: 'value1' }),
    ];
    for (let round = 0; round < TOOL_CALL_LOOP_THRESHOLD - 2; round++) {
      for (const call of rotation) {
        expect(service.addAndCheck(call)).toBe(false);
      }
    }
  });
});
describe('Content Loop Detection', () => {
  // Sentence replayed by the identical-content cases.
  const REPEATED_TEXT = 'This is a test sentence.';

  it(`should not detect a loop for fewer than CONTENT_LOOP_THRESHOLD identical content strings`, () => {
    const repeated = createContentEvent(REPEATED_TEXT);
    let remaining = CONTENT_LOOP_THRESHOLD - 1;
    while (remaining > 0) {
      expect(service.addAndCheck(repeated)).toBe(false);
      remaining -= 1;
    }
  });

  it(`should detect a loop on the CONTENT_LOOP_THRESHOLD-th identical content string`, () => {
    const repeated = createContentEvent(REPEATED_TEXT);
    // Prime the service with one occurrence fewer than the threshold.
    let remaining = CONTENT_LOOP_THRESHOLD - 1;
    while (remaining > 0) {
      service.addAndCheck(repeated);
      remaining -= 1;
    }
    const loopDetected = service.addAndCheck(repeated);
    expect(loopDetected).toBe(true);
  });

  it('should not detect a loop for different content strings', () => {
    // Alternating two different chunks must never report a loop.
    const alternating = [
      createContentEvent('Sentence A'),
      createContentEvent('Sentence B'),
    ];
    for (let round = 0; round < CONTENT_LOOP_THRESHOLD - 2; round++) {
      for (const chunk of alternating) {
        expect(service.addAndCheck(chunk)).toBe(false);
      }
    }
  });
});
describe('Sentence Extraction and Punctuation', () => {
  /** Feeds `text` to the service `times` times and asserts no loop is reported. */
  const expectNoLoop = (text: string, times: number): void => {
    for (let i = 0; i < times; i++) {
      expect(service.addAndCheck(createContentEvent(text))).toBe(false);
    }
  };

  it('should not check for loops when content has no sentence-ending punctuation', () => {
    expectNoLoop('This has no punctuation', 1);
    expectNoLoop('This has punctuation!', 1);
  });

  it('should not treat function calls or method calls as sentence endings', () => {
    // None of these end with ., ! or ? followed by whitespace or end of input,
    // so no sentence is ever extracted, however often they repeat.
    const codeSnippets = [
      'console.log()',
      'obj.method()',
      'arr.filter().map()',
      'if (condition) { return true; }',
    ];
    for (const snippet of codeSnippets) {
      service.reset();
      expectNoLoop(snippet, CONTENT_LOOP_THRESHOLD + 2);
    }
  });

  it('should correctly identify actual sentence endings and trigger loop detection', () => {
    // Terminator at end of input, followed by a space, and followed by a newline.
    const sentences = [
      'This is a sentence.',
      'Is this a question? ',
      'What excitement!\n',
    ];
    for (const sentence of sentences) {
      service.reset();
      expectNoLoop(sentence, CONTENT_LOOP_THRESHOLD - 1);
      expect(service.addAndCheck(createContentEvent(sentence))).toBe(true);
    }
  });

  it('should handle content with mixed punctuation', () => {
    expectNoLoop('Question?', 1);
    expectNoLoop('Exclamation!', 1);
    // First occurrence of the sentence that will be repeated.
    expectNoLoop('Period.', 1);
    // Occurrences 2 .. threshold-1 must stay quiet.
    expectNoLoop('Period.', CONTENT_LOOP_THRESHOLD - 2);
    // The threshold-th occurrence is the one that trips detection.
    expect(service.addAndCheck(createContentEvent('Period.'))).toBe(true);
  });

  it('should handle empty sentences after trimming', () => {
    service.addAndCheck(createContentEvent(' .'));
    expectNoLoop('Normal sentence.', 1);
  });

  it('should not detect a loop while a repeated sentence stays below the threshold', () => {
    // The service has no minimum-sentence rule; what keeps this quiet is that
    // the repetition count never reaches CONTENT_LOOP_THRESHOLD.
    expectNoLoop('Only one sentence.', CONTENT_LOOP_THRESHOLD - 1);
  });
});
describe('Content Buffering', () => {
  // These cases exercise how streamed fragments are accumulated until a
  // sentence terminator arrives.

  it('should buffer fragments until sentence-ending punctuation arrives', () => {
    expect(service.addAndCheck(createContentEvent('First sentence.'))).toBe(
      false,
    );
    expect(service.addAndCheck(createContentEvent('Second sentence.'))).toBe(
      false,
    );
    // Fragments without a terminator only accumulate.
    for (let i = 0; i < 10; i++) {
      expect(service.addAndCheck(createContentEvent('X'))).toBe(false);
    }
    // The terminator completes the buffered fragment into a single sentence.
    expect(service.addAndCheck(createContentEvent('.'))).toBe(false);
    expect(service.addAndCheck(createContentEvent('Test.'))).toBe(false);
  });

  it('should detect a repeated sentence that is split across chunks', () => {
    for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
      expect(service.addAndCheck(createContentEvent('Split '))).toBe(false);
      expect(service.addAndCheck(createContentEvent('sentence.'))).toBe(false);
    }
    expect(service.addAndCheck(createContentEvent('Split '))).toBe(false);
    expect(service.addAndCheck(createContentEvent('sentence.'))).toBe(true);
  });

  it('should handle a sentence longer than 100 characters without a false positive', () => {
    expect(service.addAndCheck(createContentEvent('Initial sentence.'))).toBe(
      false,
    );
    const longContent = 'X'.repeat(101);
    expect(service.addAndCheck(createContentEvent(longContent + '.'))).toBe(
      false,
    );
    expect(service.addAndCheck(createContentEvent('Test.'))).toBe(false);
  });

  it('should report a loop once a repeated sentence reaches the threshold', () => {
    const repeatedSentence = 'This is a repeated sentence.';
    for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
      expect(service.addAndCheck(createContentEvent(repeatedSentence))).toBe(
        false,
      );
    }
    expect(service.addAndCheck(createContentEvent(repeatedSentence))).toBe(
      true,
    );
  });
});
describe('Edge Cases', () => {
  it('should handle empty content', () => {
    // An empty chunk adds nothing to the buffer and must not report a loop.
    const loopDetected = service.addAndCheck(createContentEvent(''));
    expect(loopDetected).toBe(false);
  });
});
describe('Reset Functionality', () => {
  it('tool call should reset content count', () => {
    const contentEvent = createContentEvent('Some content.');
    const toolEvent = createToolCallRequestEvent('testTool', {
      param: 'value',
    });
    // Bring the sentence to one repetition short of the threshold.
    for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
      expect(service.addAndCheck(contentEvent)).toBe(false);
    }
    expect(service.addAndCheck(toolEvent)).toBe(false);
    // Replaying the SAME sentence proves the counter was cleared: without the
    // reset this would be the threshold-th occurrence and would report a loop.
    expect(service.addAndCheck(contentEvent)).toBe(false);
    expect(service.addAndCheck(createContentEvent('Fresh content.'))).toBe(
      false,
    );
  });

  it('reset() should clear tool call and content counts', () => {
    const contentEvent = createContentEvent('Some content.');
    for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
      service.addAndCheck(contentEvent);
    }
    service.reset();
    expect(service.addAndCheck(contentEvent)).toBe(false);

    service.reset();
    const toolEvent = createToolCallRequestEvent('testTool', {
      param: 'value',
    });
    for (let i = 0; i < TOOL_CALL_LOOP_THRESHOLD - 1; i++) {
      service.addAndCheck(toolEvent);
    }
    service.reset();
    expect(service.addAndCheck(toolEvent)).toBe(false);
  });
});
describe('General Behavior', () => {
  it('should return false for unhandled event types', () => {
    // An event whose type the service has no dedicated branch for.
    const otherEvent = {
      type: 'unhandled_event',
    } as unknown as ServerGeminiStreamEvent;
    // Sending it twice must report no loop either time.
    const firstVerdict = service.addAndCheck(otherEvent);
    expect(firstVerdict).toBe(false);
    const secondVerdict = service.addAndCheck(otherEvent);
    expect(secondVerdict).toBe(false);
  });
});
});

View File

@ -0,0 +1,121 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { createHash } from 'crypto';
import { GeminiEventType, ServerGeminiStreamEvent } from '../core/turn.js';
/** Consecutive identical tool calls at which a loop is reported. */
const TOOL_CALL_LOOP_THRESHOLD = 5;
/** Consecutive identical sentences at which a loop is reported. */
const CONTENT_LOOP_THRESHOLD = 10;
/**
 * Matches one complete sentence: a run of non-terminator characters followed
 * by one or more of `.`, `!`, `?`, where the terminator run is itself followed
 * by whitespace or the end of input. The lookahead keeps text such as
 * `obj.method()` or `a.b` from being treated as a sentence end.
 */
const COMPLETE_SENTENCE_REGEX = /[^.!?]+[.!?]+(?=\s|$)/g;

/**
 * Service for detecting and preventing infinite loops in AI responses.
 * Monitors tool call repetitions and content sentence repetitions.
 *
 * Two independent counters are kept: one for consecutive identical tool calls
 * (same name and same serialized arguments) and one for consecutive identical
 * sentences extracted from streamed content.
 */
export class LoopDetectionService {
  // Tool call tracking: hash of the previous call and how many times in a row
  // it has been seen.
  private lastToolCallKey: string | null = null;
  private toolCallRepetitionCount = 0;

  // Content tracking: the sentence currently being counted, its run length,
  // and streamed text that has not yet formed a complete sentence.
  private lastRepeatedSentence = '';
  private sentenceRepetitionCount = 0;
  private partialContent = '';

  /**
   * Derives a fixed-length identity for a tool call from its name and its
   * JSON-serialized arguments.
   * @param toolCall - The tool name and arguments to fingerprint
   * @returns Hex-encoded SHA-256 digest of `name:argsJson`
   */
  private getToolCallKey(toolCall: { name: string; args: object }): string {
    const argsString = JSON.stringify(toolCall.args);
    const keyString = `${toolCall.name}:${argsString}`;
    return createHash('sha256').update(keyString).digest('hex');
  }

  /**
   * Processes a stream event and checks for loop conditions.
   * @param event - The stream event to process
   * @returns true if a loop is detected, false otherwise
   */
  addAndCheck(event: ServerGeminiStreamEvent): boolean {
    switch (event.type) {
      case GeminiEventType.ToolCallRequest:
        // content chanting only happens in one single stream, reset if there
        // is a tool call in between
        this.resetSentenceCount();
        return this.checkToolCallLoop(event.value);
      case GeminiEventType.Content:
        return this.checkContentLoop(event.value);
      default:
        // Other event kinds carry no repetition signal. They are ignored
        // rather than used as a reset trigger: clearing the counters here
        // would let any such event arriving between identical tool calls hide
        // a tool-call loop indefinitely. Callers start a fresh detection
        // window explicitly through reset().
        return false;
    }
  }

  /**
   * Updates the consecutive-tool-call counter with `toolCall`.
   * @param toolCall - The tool call that was just requested
   * @returns true once the same call has been seen
   *   TOOL_CALL_LOOP_THRESHOLD times in a row
   */
  private checkToolCallLoop(toolCall: { name: string; args: object }): boolean {
    const key = this.getToolCallKey(toolCall);
    if (this.lastToolCallKey === key) {
      this.toolCallRepetitionCount++;
    } else {
      this.lastToolCallKey = key;
      this.toolCallRepetitionCount = 1;
    }
    return this.toolCallRepetitionCount >= TOOL_CALL_LOOP_THRESHOLD;
  }

  /**
   * Appends a streamed chunk to the pending buffer, extracts every complete
   * sentence now available, and updates the consecutive-sentence counter.
   * @param content - The newly streamed text chunk
   * @returns true once the same trimmed sentence has been seen
   *   CONTENT_LOOP_THRESHOLD times in a row
   */
  private checkContentLoop(content: string): boolean {
    this.partialContent += content;
    const buffer = this.partialContent;

    const completeSentences: string[] = [];
    // Offset just past the last complete sentence; everything after it is
    // kept as the unfinished tail for the next chunk.
    let consumedUpTo = 0;
    // The pattern is shared and global, so start each scan from the beginning.
    COMPLETE_SENTENCE_REGEX.lastIndex = 0;
    for (
      let match = COMPLETE_SENTENCE_REGEX.exec(buffer);
      match !== null;
      match = COMPLETE_SENTENCE_REGEX.exec(buffer)
    ) {
      completeSentences.push(match[0]);
      consumedUpTo = COMPLETE_SENTENCE_REGEX.lastIndex;
    }

    if (completeSentences.length === 0) {
      return false;
    }

    this.partialContent = buffer.slice(consumedUpTo);

    for (const sentence of completeSentences) {
      // A match always contains at least one terminator character, so the
      // trimmed sentence is never empty.
      const trimmedSentence = sentence.trim();

      if (this.lastRepeatedSentence === trimmedSentence) {
        this.sentenceRepetitionCount++;
      } else {
        this.lastRepeatedSentence = trimmedSentence;
        this.sentenceRepetitionCount = 1;
      }

      if (this.sentenceRepetitionCount >= CONTENT_LOOP_THRESHOLD) {
        return true;
      }
    }
    return false;
  }

  /**
   * Resets all loop detection state.
   */
  reset(): void {
    this.resetToolCallCount();
    this.resetSentenceCount();
  }

  /** Forgets the previous tool call and its repetition count. */
  private resetToolCallCount(): void {
    this.lastToolCallKey = null;
    this.toolCallRepetitionCount = 0;
  }

  /** Forgets the counted sentence, its run length, and any buffered text. */
  private resetSentenceCount(): void {
    this.lastRepeatedSentence = '';
    this.sentenceRepetitionCount = 0;
    this.partialContent = '';
  }
}