Introduce loop detection service that breaks simple loop (#3919)
Co-authored-by: Scott Densmore <scottdensmore@mac.com> Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
parent
7ffe8038ef
commit
734da8b9d2
|
@ -141,6 +141,8 @@ export const useGeminiStream = (
|
|||
[toolCalls],
|
||||
);
|
||||
|
||||
const loopDetectedRef = useRef(false);
|
||||
|
||||
const onExec = useCallback(async (done: Promise<void>) => {
|
||||
setIsResponding(true);
|
||||
await done;
|
||||
|
@ -450,6 +452,16 @@ export const useGeminiStream = (
|
|||
[addItem, config],
|
||||
);
|
||||
|
||||
// Appends an informational history item telling the user that a potential
// loop was detected and that the request has been halted. The item is
// timestamped with the current time. The stream handler defers this call
// until after the pending history item has been committed, so the notice
// is added after the content that triggered it.
const handleLoopDetectedEvent = useCallback(() => {
  addItem(
    {
      type: 'info',
      text: `A potential loop was detected. This can happen due to repetitive tool calls or other model behavior. The request has been halted.`,
    },
    Date.now(),
  );
}, [addItem]);
|
||||
|
||||
const processGeminiStreamEvents = useCallback(
|
||||
async (
|
||||
stream: AsyncIterable<GeminiEvent>,
|
||||
|
@ -489,6 +501,11 @@ export const useGeminiStream = (
|
|||
case ServerGeminiEventType.MaxSessionTurns:
|
||||
handleMaxSessionTurnsEvent();
|
||||
break;
|
||||
case ServerGeminiEventType.LoopDetected:
|
||||
// handle later because we want to move pending history to history
|
||||
// before we add loop detected message to history
|
||||
loopDetectedRef.current = true;
|
||||
break;
|
||||
default: {
|
||||
// enforces exhaustive switch-case
|
||||
const unreachable: never = event;
|
||||
|
@ -579,6 +596,10 @@ export const useGeminiStream = (
|
|||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
if (loopDetectedRef.current) {
|
||||
loopDetectedRef.current = false;
|
||||
handleLoopDetectedEvent();
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof UnauthorizedError) {
|
||||
onAuthError();
|
||||
|
@ -616,6 +637,7 @@ export const useGeminiStream = (
|
|||
config,
|
||||
startNewPrompt,
|
||||
getPromptCount,
|
||||
handleLoopDetectedEvent,
|
||||
],
|
||||
);
|
||||
|
||||
|
|
|
@ -40,6 +40,7 @@ import {
|
|||
} from './contentGenerator.js';
|
||||
import { ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { LoopDetectionService } from '../services/loopDetectionService.js';
|
||||
|
||||
function isThinkingSupported(model: string) {
|
||||
if (model.startsWith('gemini-2.5')) return true;
|
||||
|
@ -100,6 +101,9 @@ export class GeminiClient {
|
|||
*/
|
||||
private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3;
|
||||
|
||||
private readonly loopDetector = new LoopDetectionService();
|
||||
private lastPromptId?: string;
|
||||
|
||||
constructor(private config: Config) {
|
||||
if (config.getProxy()) {
|
||||
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
|
||||
|
@ -272,6 +276,10 @@ export class GeminiClient {
|
|||
turns: number = this.MAX_TURNS,
|
||||
originalModel?: string,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (this.lastPromptId !== prompt_id) {
|
||||
this.loopDetector.reset();
|
||||
this.lastPromptId = prompt_id;
|
||||
}
|
||||
this.sessionTurnCount++;
|
||||
if (
|
||||
this.config.getMaxSessionTurns() > 0 &&
|
||||
|
@ -297,6 +305,10 @@ export class GeminiClient {
|
|||
const turn = new Turn(this.getChat(), prompt_id);
|
||||
const resultStream = turn.run(request, signal);
|
||||
for await (const event of resultStream) {
|
||||
if (this.loopDetector.addAndCheck(event)) {
|
||||
yield { type: GeminiEventType.LoopDetected };
|
||||
return turn;
|
||||
}
|
||||
yield event;
|
||||
}
|
||||
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||
|
|
|
@ -49,6 +49,7 @@ export enum GeminiEventType {
|
|||
ChatCompressed = 'chat_compressed',
|
||||
Thought = 'thought',
|
||||
MaxSessionTurns = 'max_session_turns',
|
||||
LoopDetected = 'loop_detected',
|
||||
}
|
||||
|
||||
export interface StructuredError {
|
||||
|
@ -133,6 +134,10 @@ export type ServerGeminiMaxSessionTurnsEvent = {
|
|||
type: GeminiEventType.MaxSessionTurns;
|
||||
};
|
||||
|
||||
/**
 * Stream event signalling that loop detection fired. It carries no payload
 * beyond its `type` discriminant.
 */
export type ServerGeminiLoopDetectedEvent = {
  type: GeminiEventType.LoopDetected;
};
|
||||
|
||||
// The original union type, now composed of the individual types
|
||||
export type ServerGeminiStreamEvent =
|
||||
| ServerGeminiContentEvent
|
||||
|
@ -143,7 +148,8 @@ export type ServerGeminiStreamEvent =
|
|||
| ServerGeminiErrorEvent
|
||||
| ServerGeminiChatCompressedEvent
|
||||
| ServerGeminiThoughtEvent
|
||||
| ServerGeminiMaxSessionTurnsEvent;
|
||||
| ServerGeminiMaxSessionTurnsEvent
|
||||
| ServerGeminiLoopDetectedEvent;
|
||||
|
||||
// A turn manages the agentic loop turn within the server context.
|
||||
export class Turn {
|
||||
|
|
|
@ -0,0 +1,294 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach } from 'vitest';
|
||||
import { LoopDetectionService } from './loopDetectionService.js';
|
||||
import {
|
||||
GeminiEventType,
|
||||
ServerGeminiContentEvent,
|
||||
ServerGeminiToolCallRequestEvent,
|
||||
} from '../core/turn.js';
|
||||
import { ServerGeminiStreamEvent } from '../core/turn.js';
|
||||
|
||||
// Mirrors of the thresholds used by the service under test. They are kept
// local because the service does not export them.
const TOOL_CALL_LOOP_THRESHOLD = 5;
const CONTENT_LOOP_THRESHOLD = 10;

describe('LoopDetectionService', () => {
  let service: LoopDetectionService;

  beforeEach(() => {
    service = new LoopDetectionService();
  });

  /** Builds a tool-call request event with fixed ids and the given name/args. */
  const createToolCallRequestEvent = (
    name: string,
    args: Record<string, unknown>,
  ): ServerGeminiToolCallRequestEvent => ({
    type: GeminiEventType.ToolCallRequest,
    value: {
      name,
      args,
      callId: 'test-id',
      isClientInitiated: false,
      prompt_id: 'test-prompt-id',
    },
  });

  /** Builds a content event carrying the given text chunk. */
  const createContentEvent = (content: string): ServerGeminiContentEvent => ({
    type: GeminiEventType.Content,
    value: content,
  });

  describe('Tool Call Loop Detection', () => {
    it(`should not detect a loop for fewer than TOOL_CALL_LOOP_THRESHOLD identical calls`, () => {
      const event = createToolCallRequestEvent('testTool', { param: 'value' });
      for (let i = 0; i < TOOL_CALL_LOOP_THRESHOLD - 1; i++) {
        expect(service.addAndCheck(event)).toBe(false);
      }
    });

    it(`should detect a loop on the TOOL_CALL_LOOP_THRESHOLD-th identical call`, () => {
      const event = createToolCallRequestEvent('testTool', { param: 'value' });
      for (let i = 0; i < TOOL_CALL_LOOP_THRESHOLD - 1; i++) {
        service.addAndCheck(event);
      }
      expect(service.addAndCheck(event)).toBe(true);
    });

    it('should detect a loop on subsequent identical calls', () => {
      const event = createToolCallRequestEvent('testTool', { param: 'value' });
      for (let i = 0; i < TOOL_CALL_LOOP_THRESHOLD; i++) {
        service.addAndCheck(event);
      }
      expect(service.addAndCheck(event)).toBe(true);
    });

    it('should not detect a loop for different tool calls', () => {
      const event1 = createToolCallRequestEvent('testTool', {
        param: 'value1',
      });
      const event2 = createToolCallRequestEvent('testTool', {
        param: 'value2',
      });
      const event3 = createToolCallRequestEvent('anotherTool', {
        param: 'value1',
      });

      for (let i = 0; i < TOOL_CALL_LOOP_THRESHOLD - 2; i++) {
        expect(service.addAndCheck(event1)).toBe(false);
        expect(service.addAndCheck(event2)).toBe(false);
        expect(service.addAndCheck(event3)).toBe(false);
      }
    });
  });

  describe('Content Loop Detection', () => {
    it(`should not detect a loop for fewer than CONTENT_LOOP_THRESHOLD identical content strings`, () => {
      const event = createContentEvent('This is a test sentence.');
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
        expect(service.addAndCheck(event)).toBe(false);
      }
    });

    it(`should detect a loop on the CONTENT_LOOP_THRESHOLD-th identical content string`, () => {
      const event = createContentEvent('This is a test sentence.');
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
        service.addAndCheck(event);
      }
      expect(service.addAndCheck(event)).toBe(true);
    });

    it('should not detect a loop for different content strings', () => {
      const event1 = createContentEvent('Sentence A');
      const event2 = createContentEvent('Sentence B');
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 2; i++) {
        expect(service.addAndCheck(event1)).toBe(false);
        expect(service.addAndCheck(event2)).toBe(false);
      }
    });
  });

  describe('Sentence Extraction and Punctuation', () => {
    it('should not check for loops when content has no sentence-ending punctuation', () => {
      const eventNoPunct = createContentEvent('This has no punctuation');
      expect(service.addAndCheck(eventNoPunct)).toBe(false);

      const eventWithPunct = createContentEvent('This has punctuation!');
      expect(service.addAndCheck(eventWithPunct)).toBe(false);
    });

    it('should not treat function calls or method calls as sentence endings', () => {
      // These should not trigger sentence detection, so repeating them many times should never cause a loop
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD + 2; i++) {
        expect(service.addAndCheck(createContentEvent('console.log()'))).toBe(
          false,
        );
      }

      service.reset();
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD + 2; i++) {
        expect(service.addAndCheck(createContentEvent('obj.method()'))).toBe(
          false,
        );
      }

      service.reset();
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD + 2; i++) {
        expect(
          service.addAndCheck(createContentEvent('arr.filter().map()')),
        ).toBe(false);
      }

      service.reset();
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD + 2; i++) {
        expect(
          service.addAndCheck(
            createContentEvent('if (condition) { return true; }'),
          ),
        ).toBe(false);
      }
    });

    it('should correctly identify actual sentence endings and trigger loop detection', () => {
      // These should trigger sentence detection, so repeating them should eventually cause a loop
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
        expect(
          service.addAndCheck(createContentEvent('This is a sentence.')),
        ).toBe(false);
      }
      expect(
        service.addAndCheck(createContentEvent('This is a sentence.')),
      ).toBe(true);

      service.reset();
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
        expect(
          service.addAndCheck(createContentEvent('Is this a question? ')),
        ).toBe(false);
      }
      expect(
        service.addAndCheck(createContentEvent('Is this a question? ')),
      ).toBe(true);

      service.reset();
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
        expect(
          service.addAndCheck(createContentEvent('What excitement!\n')),
        ).toBe(false);
      }
      expect(
        service.addAndCheck(createContentEvent('What excitement!\n')),
      ).toBe(true);
    });

    it('should handle content with mixed punctuation', () => {
      service.addAndCheck(createContentEvent('Question?'));
      service.addAndCheck(createContentEvent('Exclamation!'));
      service.addAndCheck(createContentEvent('Period.'));

      // Repeat one of them multiple times
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
        service.addAndCheck(createContentEvent('Period.'));
      }
      expect(service.addAndCheck(createContentEvent('Period.'))).toBe(true);
    });

    it('should handle empty sentences after trimming', () => {
      service.addAndCheck(createContentEvent(' .'));
      expect(service.addAndCheck(createContentEvent('Normal sentence.'))).toBe(
        false,
      );
    });

    it('should not detect a loop when a sentence repeats fewer than CONTENT_LOOP_THRESHOLD times', () => {
      const event = createContentEvent('Only one sentence.');
      expect(service.addAndCheck(event)).toBe(false);

      // Six occurrences in total is still below the threshold, so no loop yet.
      for (let i = 0; i < 5; i++) {
        expect(service.addAndCheck(event)).toBe(false);
      }
    });
  });

  describe('Chunked Content Handling', () => {
    it('should handle a sentence assembled from many small chunks', () => {
      // Add initial content
      service.addAndCheck(createContentEvent('First sentence.'));
      service.addAndCheck(createContentEvent('Second sentence.'));

      // Stream a sentence one character at a time, then terminate it
      for (let i = 0; i < 10; i++) {
        service.addAndCheck(createContentEvent('X'));
      }
      service.addAndCheck(createContentEvent('.'));

      // Should still work correctly
      expect(service.addAndCheck(createContentEvent('Test.'))).toBe(false);
    });

    it('should handle a single chunk longer than 100 characters', () => {
      service.addAndCheck(createContentEvent('Initial sentence.'));

      // One long sentence delivered in a single chunk
      const longContent = 'X'.repeat(101);
      service.addAndCheck(createContentEvent(longContent + '.'));

      // Should work correctly afterwards
      expect(service.addAndCheck(createContentEvent('Test.'))).toBe(false);
    });

    it('should detect a loop when the same sentence arrives in consecutive chunks', () => {
      const repeatedSentence = 'This is a repeated sentence.';

      // Build up content with the sentence repeated
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
        service.addAndCheck(createContentEvent(repeatedSentence));
      }

      // The threshold should be reached
      expect(service.addAndCheck(createContentEvent(repeatedSentence))).toBe(
        true,
      );
    });
  });

  describe('Edge Cases', () => {
    it('should handle empty content', () => {
      const event = createContentEvent('');
      expect(service.addAndCheck(event)).toBe(false);
    });
  });

  describe('Reset Functionality', () => {
    it('tool call should reset content count', () => {
      const contentEvent = createContentEvent('Some content.');
      const toolEvent = createToolCallRequestEvent('testTool', {
        param: 'value',
      });
      for (let i = 0; i < CONTENT_LOOP_THRESHOLD - 1; i++) {
        service.addAndCheck(contentEvent);
      }

      service.addAndCheck(toolEvent);

      // The same sentence again would be the threshold-th occurrence if the
      // tool call had not cleared the count; a different sentence could not
      // distinguish "reset" from "not reset".
      expect(service.addAndCheck(contentEvent)).toBe(false);

      // Should start fresh
      expect(service.addAndCheck(createContentEvent('Fresh content.'))).toBe(
        false,
      );
    });
  });

  describe('General Behavior', () => {
    it('should return false for unhandled event types', () => {
      const otherEvent = {
        type: 'unhandled_event',
      } as unknown as ServerGeminiStreamEvent;
      expect(service.addAndCheck(otherEvent)).toBe(false);
      expect(service.addAndCheck(otherEvent)).toBe(false);
    });
  });
});
|
|
@ -0,0 +1,121 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { createHash } from 'crypto';
|
||||
import { GeminiEventType, ServerGeminiStreamEvent } from '../core/turn.js';
|
||||
|
||||
/** Consecutive identical tool calls at which a loop is reported. */
const TOOL_CALL_LOOP_THRESHOLD = 5;
/** Consecutive identical sentences at which a loop is reported. */
const CONTENT_LOOP_THRESHOLD = 10;
/** Matches a run of `.`, `!` or `?` that is followed by whitespace or end of input. */
const SENTENCE_ENDING_PUNCTUATION_REGEX = /[.!?]+(?=\s|$)/;
/**
 * Matches every complete sentence: non-punctuation text followed by a
 * sentence-ending punctuation run. It is only used with `String.prototype.match`,
 * which resets `lastIndex` for global regexes, so sharing one instance is safe.
 */
const COMPLETE_SENTENCE_REGEX = /[^.!?]+[.!?]+(?=\s|$)/g;

/**
 * Service for detecting and preventing infinite loops in AI responses.
 * Monitors tool call repetitions and content sentence repetitions.
 */
export class LoopDetectionService {
  // Tool call tracking
  private lastToolCallKey: string | null = null;
  private toolCallRepetitionCount = 0;

  // Content streaming tracking
  private lastRepeatedSentence = '';
  private sentenceRepetitionCount = 0;
  // Streamed text that has not yet formed a complete sentence.
  private partialContent = '';

  /**
   * Derives a comparison key for a tool call: the SHA-256 hex digest of
   * `name:JSON(args)`. Two calls share a key only when the name and the
   * serialized arguments are identical.
   */
  private getToolCallKey(toolCall: { name: string; args: object }): string {
    const argsString = JSON.stringify(toolCall.args);
    const keyString = `${toolCall.name}:${argsString}`;
    return createHash('sha256').update(keyString).digest('hex');
  }

  /**
   * Processes a stream event and checks for loop conditions.
   * @param event - The stream event to process
   * @returns true if a loop is detected, false otherwise
   */
  addAndCheck(event: ServerGeminiStreamEvent): boolean {
    switch (event.type) {
      case GeminiEventType.ToolCallRequest:
        // content chanting only happens in one single stream, reset if there
        // is a tool call in between
        this.resetSentenceCount();
        return this.checkToolCallLoop(event.value);
      case GeminiEventType.Content:
        return this.checkContentLoop(event.value);
      default:
        // Other event types carry no loop signal of their own. They still
        // clear sentence tracking, but they must NOT clear the tool-call
        // counter: that counter tracks consecutive identical tool calls, and
        // wiping it whenever some other event (for example a thought) is
        // interleaved between those calls would keep the count from ever
        // reaching the threshold. Callers that want a full reset use reset().
        this.resetSentenceCount();
        return false;
    }
  }

  /**
   * Counts consecutive tool calls that share a key and reports a loop once
   * the count reaches TOOL_CALL_LOOP_THRESHOLD. A call with a different key
   * restarts the count at 1.
   */
  private checkToolCallLoop(toolCall: { name: string; args: object }): boolean {
    const key = this.getToolCallKey(toolCall);
    if (this.lastToolCallKey === key) {
      this.toolCallRepetitionCount++;
    } else {
      this.lastToolCallKey = key;
      this.toolCallRepetitionCount = 1;
    }
    return this.toolCallRepetitionCount >= TOOL_CALL_LOOP_THRESHOLD;
  }

  /**
   * Appends a streamed chunk to the buffer, extracts every complete sentence
   * and counts consecutive identical (trimmed) sentences. Reports a loop once
   * the same sentence has been seen CONTENT_LOOP_THRESHOLD times in a row.
   */
  private checkContentLoop(content: string): boolean {
    this.partialContent += content;

    // Nothing to analyse until some sentence-ending punctuation has arrived.
    if (!SENTENCE_ENDING_PUNCTUATION_REGEX.test(this.partialContent)) {
      return false;
    }

    // The punctuation may have no sentence text before it (e.g. a lone "."),
    // in which case there is still no complete sentence.
    const completeSentences =
      this.partialContent.match(COMPLETE_SENTENCE_REGEX) || [];
    if (completeSentences.length === 0) {
      return false;
    }

    // Keep only the text after the last complete sentence for the next chunk.
    const lastSentence = completeSentences[completeSentences.length - 1];
    const lastCompleteIndex = this.partialContent.lastIndexOf(lastSentence);
    const endOfLastSentence = lastCompleteIndex + lastSentence.length;
    this.partialContent = this.partialContent.slice(endOfLastSentence);

    for (const sentence of completeSentences) {
      const trimmedSentence = sentence.trim();
      if (trimmedSentence === '') {
        continue;
      }

      if (this.lastRepeatedSentence === trimmedSentence) {
        this.sentenceRepetitionCount++;
      } else {
        this.lastRepeatedSentence = trimmedSentence;
        this.sentenceRepetitionCount = 1;
      }

      if (this.sentenceRepetitionCount >= CONTENT_LOOP_THRESHOLD) {
        return true;
      }
    }
    return false;
  }

  /**
   * Resets all loop detection state.
   */
  reset(): void {
    this.resetToolCallCount();
    this.resetSentenceCount();
  }

  /** Clears the remembered tool-call key and its repetition count. */
  private resetToolCallCount(): void {
    this.lastToolCallKey = null;
    this.toolCallRepetitionCount = 0;
  }

  /** Clears the remembered sentence, its repetition count and the chunk buffer. */
  private resetSentenceCount(): void {
    this.lastRepeatedSentence = '';
    this.sentenceRepetitionCount = 0;
    this.partialContent = '';
  }
}
|
Loading…
Reference in New Issue