fix: Add warning message for token limit truncation (#2260)
Co-authored-by: Sandy Tao <sandytao520@icloud.com>
parent dc2ac144b7
commit 4c3532d2b3
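In brief: the core `Turn.run` loop now inspects `finishReason` on each streamed response and emits a `Finished` event, and the CLI's `useGeminiStream` hook maps that reason to a one-line info warning (abnormal reasons such as `MAX_TOKENS` produce a message; `STOP` and `FINISH_REASON_UNSPECIFIED` stay silent). A minimal sketch of that flow, not part of the diff below — `FinishReason` comes from `@google/genai`, while the event shape and the `warningFor` / `onStreamEvent` / `addInfoItem` names are illustrative only:

```ts
import { FinishReason } from '@google/genai';

// Illustrative event shape mirroring the ServerGeminiFinishedEvent type added below.
type FinishedEvent = { type: 'finished'; value: FinishReason };

// A trimmed-down version of the mapping in handleFinishedEvent (subset of reasons).
function warningFor(reason: FinishReason): string | undefined {
  switch (reason) {
    case FinishReason.MAX_TOKENS:
      return 'Response truncated due to token limits.';
    case FinishReason.SAFETY:
      return 'Response stopped due to safety reasons.';
    case FinishReason.STOP:
    case FinishReason.FINISH_REASON_UNSPECIFIED:
      return undefined; // normal completion: nothing to show
    default:
      return 'Response stopped for other reasons.';
  }
}

// Hypothetical sink: the real hook calls addItem({ type: 'info', text: `⚠️ ${message}` }, ts).
function onStreamEvent(event: FinishedEvent, addInfoItem: (text: string) => void): void {
  const message = warningFor(event.value);
  if (message) {
    addInfoItem(`⚠️ ${message}`);
  }
}
```

The actual mapping in the useGeminiStream.ts hunk below covers every `FinishReason` member explicitly, and `Turn.run` only emits the event when the response carries a `finishReason`.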
@@ -16,7 +16,12 @@ import {
   TrackedExecutingToolCall,
   TrackedCancelledToolCall,
 } from './useReactToolScheduler.js';
-import { Config, EditorType, AuthType } from '@google/gemini-cli-core';
+import {
+  Config,
+  EditorType,
+  AuthType,
+  GeminiEventType as ServerGeminiEventType,
+} from '@google/gemini-cli-core';
 import { Part, PartListUnion } from '@google/genai';
 import { UseHistoryManagerReturn } from './useHistoryManager.js';
 import {
@@ -1178,4 +1183,235 @@ describe('useGeminiStream', () => {
       });
     });
   });
+
+  describe('handleFinishedEvent', () => {
+    it('should add info message for MAX_TOKENS finish reason', async () => {
+      // Setup mock to return a stream with MAX_TOKENS finish reason
+      mockSendMessageStream.mockReturnValue(
+        (async function* () {
+          yield {
+            type: ServerGeminiEventType.Content,
+            value: 'This is a truncated response...',
+          };
+          yield { type: ServerGeminiEventType.Finished, value: 'MAX_TOKENS' };
+        })(),
+      );
+
+      const { result } = renderHook(() =>
+        useGeminiStream(
+          new MockedGeminiClientClass(mockConfig),
+          [],
+          mockAddItem,
+          mockSetShowHelp,
+          mockConfig,
+          mockOnDebugMessage,
+          mockHandleSlashCommand,
+          false,
+          () => 'vscode' as EditorType,
+          () => {},
+          () => Promise.resolve(),
+          false,
+          () => {},
+        ),
+      );
+
+      // Submit a query
+      await act(async () => {
+        await result.current.submitQuery('Generate long text');
+      });
+
+      // Check that the info message was added
+      await waitFor(() => {
+        expect(mockAddItem).toHaveBeenCalledWith(
+          {
+            type: 'info',
+            text: '⚠️ Response truncated due to token limits.',
+          },
+          expect.any(Number),
+        );
+      });
+    });
+
+    it('should not add message for STOP finish reason', async () => {
+      // Setup mock to return a stream with STOP finish reason
+      mockSendMessageStream.mockReturnValue(
+        (async function* () {
+          yield {
+            type: ServerGeminiEventType.Content,
+            value: 'Complete response',
+          };
+          yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
+        })(),
+      );
+
+      const { result } = renderHook(() =>
+        useGeminiStream(
+          new MockedGeminiClientClass(mockConfig),
+          [],
+          mockAddItem,
+          mockSetShowHelp,
+          mockConfig,
+          mockOnDebugMessage,
+          mockHandleSlashCommand,
+          false,
+          () => 'vscode' as EditorType,
+          () => {},
+          () => Promise.resolve(),
+          false,
+          () => {},
+        ),
+      );
+
+      // Submit a query
+      await act(async () => {
+        await result.current.submitQuery('Test normal completion');
+      });
+
+      // Wait a bit to ensure no message is added
+      await new Promise((resolve) => setTimeout(resolve, 100));
+
+      // Check that no info message was added for STOP
+      const infoMessages = mockAddItem.mock.calls.filter(
+        (call) => call[0].type === 'info',
+      );
+      expect(infoMessages).toHaveLength(0);
+    });
+
+    it('should not add message for FINISH_REASON_UNSPECIFIED', async () => {
+      // Setup mock to return a stream with FINISH_REASON_UNSPECIFIED
+      mockSendMessageStream.mockReturnValue(
+        (async function* () {
+          yield {
+            type: ServerGeminiEventType.Content,
+            value: 'Response with unspecified finish',
+          };
+          yield {
+            type: ServerGeminiEventType.Finished,
+            value: 'FINISH_REASON_UNSPECIFIED',
+          };
+        })(),
+      );
+
+      const { result } = renderHook(() =>
+        useGeminiStream(
+          new MockedGeminiClientClass(mockConfig),
+          [],
+          mockAddItem,
+          mockSetShowHelp,
+          mockConfig,
+          mockOnDebugMessage,
+          mockHandleSlashCommand,
+          false,
+          () => 'vscode' as EditorType,
+          () => {},
+          () => Promise.resolve(),
+          false,
+          () => {},
+        ),
+      );
+
+      // Submit a query
+      await act(async () => {
+        await result.current.submitQuery('Test unspecified finish');
+      });
+
+      // Wait a bit to ensure no message is added
+      await new Promise((resolve) => setTimeout(resolve, 100));
+
+      // Check that no info message was added
+      const infoMessages = mockAddItem.mock.calls.filter(
+        (call) => call[0].type === 'info',
+      );
+      expect(infoMessages).toHaveLength(0);
+    });
+
+    it('should add appropriate messages for other finish reasons', async () => {
+      const testCases = [
+        {
+          reason: 'SAFETY',
+          message: '⚠️ Response stopped due to safety reasons.',
+        },
+        {
+          reason: 'RECITATION',
+          message: '⚠️ Response stopped due to recitation policy.',
+        },
+        {
+          reason: 'LANGUAGE',
+          message: '⚠️ Response stopped due to unsupported language.',
+        },
+        {
+          reason: 'BLOCKLIST',
+          message: '⚠️ Response stopped due to forbidden terms.',
+        },
+        {
+          reason: 'PROHIBITED_CONTENT',
+          message: '⚠️ Response stopped due to prohibited content.',
+        },
+        {
+          reason: 'SPII',
+          message:
+            '⚠️ Response stopped due to sensitive personally identifiable information.',
+        },
+        { reason: 'OTHER', message: '⚠️ Response stopped for other reasons.' },
+        {
+          reason: 'MALFORMED_FUNCTION_CALL',
+          message: '⚠️ Response stopped due to malformed function call.',
+        },
+        {
+          reason: 'IMAGE_SAFETY',
+          message: '⚠️ Response stopped due to image safety violations.',
+        },
+        {
+          reason: 'UNEXPECTED_TOOL_CALL',
+          message: '⚠️ Response stopped due to unexpected tool call.',
+        },
+      ];
+
+      for (const { reason, message } of testCases) {
+        // Reset mocks for each test case
+        mockAddItem.mockClear();
+        mockSendMessageStream.mockReturnValue(
+          (async function* () {
+            yield {
+              type: ServerGeminiEventType.Content,
+              value: `Response for ${reason}`,
+            };
+            yield { type: ServerGeminiEventType.Finished, value: reason };
+          })(),
+        );
+
+        const { result } = renderHook(() =>
+          useGeminiStream(
+            new MockedGeminiClientClass(mockConfig),
+            [],
+            mockAddItem,
+            mockSetShowHelp,
+            mockConfig,
+            mockOnDebugMessage,
+            mockHandleSlashCommand,
+            false,
+            () => 'vscode' as EditorType,
+            () => {},
+            () => Promise.resolve(),
+            false,
+            () => {},
+          ),
+        );
+
+        await act(async () => {
+          await result.current.submitQuery(`Test ${reason}`);
+        });
+
+        await waitFor(() => {
+          expect(mockAddItem).toHaveBeenCalledWith(
+            {
+              type: 'info',
+              text: message,
+            },
+            expect.any(Number),
+          );
+        });
+      }
+    });
+  });
 });

@@ -14,6 +14,7 @@ import {
   ServerGeminiContentEvent as ContentEvent,
   ServerGeminiErrorEvent as ErrorEvent,
   ServerGeminiChatCompressedEvent,
+  ServerGeminiFinishedEvent,
   getErrorMessage,
   isNodeError,
   MessageSenderType,
@@ -26,7 +27,7 @@ import {
   UserPromptEvent,
   DEFAULT_GEMINI_FLASH_MODEL,
 } from '@google/gemini-cli-core';
-import { type Part, type PartListUnion } from '@google/genai';
+import { type Part, type PartListUnion, FinishReason } from '@google/genai';
 import {
   StreamingState,
   HistoryItem,
@@ -422,6 +423,46 @@ export const useGeminiStream = (
     [addItem, pendingHistoryItemRef, setPendingHistoryItem, config],
   );

+  const handleFinishedEvent = useCallback(
+    (event: ServerGeminiFinishedEvent, userMessageTimestamp: number) => {
+      const finishReason = event.value;
+
+      const finishReasonMessages: Record<FinishReason, string | undefined> = {
+        [FinishReason.FINISH_REASON_UNSPECIFIED]: undefined,
+        [FinishReason.STOP]: undefined,
+        [FinishReason.MAX_TOKENS]: 'Response truncated due to token limits.',
+        [FinishReason.SAFETY]: 'Response stopped due to safety reasons.',
+        [FinishReason.RECITATION]: 'Response stopped due to recitation policy.',
+        [FinishReason.LANGUAGE]:
+          'Response stopped due to unsupported language.',
+        [FinishReason.BLOCKLIST]: 'Response stopped due to forbidden terms.',
+        [FinishReason.PROHIBITED_CONTENT]:
+          'Response stopped due to prohibited content.',
+        [FinishReason.SPII]:
+          'Response stopped due to sensitive personally identifiable information.',
+        [FinishReason.OTHER]: 'Response stopped for other reasons.',
+        [FinishReason.MALFORMED_FUNCTION_CALL]:
+          'Response stopped due to malformed function call.',
+        [FinishReason.IMAGE_SAFETY]:
+          'Response stopped due to image safety violations.',
+        [FinishReason.UNEXPECTED_TOOL_CALL]:
+          'Response stopped due to unexpected tool call.',
+      };
+
+      const message = finishReasonMessages[finishReason];
+      if (message) {
+        addItem(
+          {
+            type: 'info',
+            text: `⚠️ ${message}`,
+          },
+          userMessageTimestamp,
+        );
+      }
+    },
+    [addItem],
+  );
+
   const handleChatCompressionEvent = useCallback(
     (eventValue: ServerGeminiChatCompressedEvent['value']) =>
       addItem(
@@ -501,6 +542,12 @@ export const useGeminiStream = (
         case ServerGeminiEventType.MaxSessionTurns:
           handleMaxSessionTurnsEvent();
           break;
+        case ServerGeminiEventType.Finished:
+          handleFinishedEvent(
+            event as ServerGeminiFinishedEvent,
+            userMessageTimestamp,
+          );
+          break;
         case ServerGeminiEventType.LoopDetected:
           // handle later because we want to move pending history to history
           // before we add loop detected message to history
@@ -524,6 +571,7 @@ export const useGeminiStream = (
       handleErrorEvent,
       scheduleToolCalls,
       handleChatCompressionEvent,
+      handleFinishedEvent,
       handleMaxSessionTurnsEvent,
     ],
   );

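A note on the lookup table in `handleFinishedEvent` above: annotating the literal as `Record<FinishReason, string | undefined>` makes TypeScript require an entry for every `FinishReason` member, so a new finish reason added to `@google/genai` surfaces as a type error rather than a silently unhandled case. A minimal sketch of the same pattern with a stand-in enum (the enum and messages here are illustrative, not taken from the commit):

```ts
enum Reason {
  Stop = 'STOP',
  MaxTokens = 'MAX_TOKENS',
  Safety = 'SAFETY',
}

// Omitting any Reason key from this literal is a compile error,
// which is what keeps the mapping exhaustive as the enum grows.
const messages: Record<Reason, string | undefined> = {
  [Reason.Stop]: undefined, // normal completion: no warning
  [Reason.MaxTokens]: 'Response truncated due to token limits.',
  [Reason.Safety]: 'Response stopped due to safety reasons.',
};

export function messageFor(reason: Reason): string | undefined {
  return messages[reason];
}
```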
@@ -282,6 +282,165 @@ describe('Turn', () => {
       expect(turn.pendingToolCalls[2]).toEqual(event3.value);
       expect(turn.getDebugResponses().length).toBe(1);
     });
+
+    it('should yield finished event when response has finish reason', async () => {
+      const mockResponseStream = (async function* () {
+        yield {
+          candidates: [
+            {
+              content: { parts: [{ text: 'Partial response' }] },
+              finishReason: 'STOP',
+            },
+          ],
+        } as unknown as GenerateContentResponse;
+      })();
+      mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+      const events = [];
+      const reqParts: Part[] = [{ text: 'Test finish reason' }];
+      for await (const event of turn.run(
+        reqParts,
+        new AbortController().signal,
+      )) {
+        events.push(event);
+      }
+
+      expect(events).toEqual([
+        { type: GeminiEventType.Content, value: 'Partial response' },
+        { type: GeminiEventType.Finished, value: 'STOP' },
+      ]);
+    });
+
+    it('should yield finished event for MAX_TOKENS finish reason', async () => {
+      const mockResponseStream = (async function* () {
+        yield {
+          candidates: [
+            {
+              content: {
+                parts: [
+                  { text: 'This is a long response that was cut off...' },
+                ],
+              },
+              finishReason: 'MAX_TOKENS',
+            },
+          ],
+        } as unknown as GenerateContentResponse;
+      })();
+      mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+      const events = [];
+      const reqParts: Part[] = [{ text: 'Generate long text' }];
+      for await (const event of turn.run(
+        reqParts,
+        new AbortController().signal,
+      )) {
+        events.push(event);
+      }
+
+      expect(events).toEqual([
+        {
+          type: GeminiEventType.Content,
+          value: 'This is a long response that was cut off...',
+        },
+        { type: GeminiEventType.Finished, value: 'MAX_TOKENS' },
+      ]);
+    });
+
+    it('should yield finished event for SAFETY finish reason', async () => {
+      const mockResponseStream = (async function* () {
+        yield {
+          candidates: [
+            {
+              content: { parts: [{ text: 'Content blocked' }] },
+              finishReason: 'SAFETY',
+            },
+          ],
+        } as unknown as GenerateContentResponse;
+      })();
+      mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+      const events = [];
+      const reqParts: Part[] = [{ text: 'Test safety' }];
+      for await (const event of turn.run(
+        reqParts,
+        new AbortController().signal,
+      )) {
+        events.push(event);
+      }
+
+      expect(events).toEqual([
+        { type: GeminiEventType.Content, value: 'Content blocked' },
+        { type: GeminiEventType.Finished, value: 'SAFETY' },
+      ]);
+    });
+
+    it('should not yield finished event when there is no finish reason', async () => {
+      const mockResponseStream = (async function* () {
+        yield {
+          candidates: [
+            {
+              content: { parts: [{ text: 'Response without finish reason' }] },
+              // No finishReason property
+            },
+          ],
+        } as unknown as GenerateContentResponse;
+      })();
+      mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+      const events = [];
+      const reqParts: Part[] = [{ text: 'Test no finish reason' }];
+      for await (const event of turn.run(
+        reqParts,
+        new AbortController().signal,
+      )) {
+        events.push(event);
+      }
+
+      expect(events).toEqual([
+        {
+          type: GeminiEventType.Content,
+          value: 'Response without finish reason',
+        },
+      ]);
+      // No Finished event should be emitted
+    });
+
+    it('should handle multiple responses with different finish reasons', async () => {
+      const mockResponseStream = (async function* () {
+        yield {
+          candidates: [
+            {
+              content: { parts: [{ text: 'First part' }] },
+              // No finish reason on first response
+            },
+          ],
+        } as unknown as GenerateContentResponse;
+        yield {
+          candidates: [
+            {
+              content: { parts: [{ text: 'Second part' }] },
+              finishReason: 'OTHER',
+            },
+          ],
+        } as unknown as GenerateContentResponse;
+      })();
+      mockSendMessageStream.mockResolvedValue(mockResponseStream);
+
+      const events = [];
+      const reqParts: Part[] = [{ text: 'Test multiple responses' }];
+      for await (const event of turn.run(
+        reqParts,
+        new AbortController().signal,
+      )) {
+        events.push(event);
+      }
+
+      expect(events).toEqual([
+        { type: GeminiEventType.Content, value: 'First part' },
+        { type: GeminiEventType.Content, value: 'Second part' },
+        { type: GeminiEventType.Finished, value: 'OTHER' },
+      ]);
+    });
   });

   describe('getDebugResponses', () => {

@@ -9,6 +9,7 @@ import {
   GenerateContentResponse,
   FunctionCall,
   FunctionDeclaration,
+  FinishReason,
 } from '@google/genai';
 import {
   ToolCallConfirmationDetails,
@@ -49,6 +50,7 @@ export enum GeminiEventType {
   ChatCompressed = 'chat_compressed',
   Thought = 'thought',
   MaxSessionTurns = 'max_session_turns',
+  Finished = 'finished',
   LoopDetected = 'loop_detected',
 }

@@ -134,6 +136,11 @@ export type ServerGeminiMaxSessionTurnsEvent = {
   type: GeminiEventType.MaxSessionTurns;
 };

+export type ServerGeminiFinishedEvent = {
+  type: GeminiEventType.Finished;
+  value: FinishReason;
+};
+
 export type ServerGeminiLoopDetectedEvent = {
   type: GeminiEventType.LoopDetected;
 };
@@ -149,6 +156,7 @@ export type ServerGeminiStreamEvent =
   | ServerGeminiChatCompressedEvent
   | ServerGeminiThoughtEvent
   | ServerGeminiMaxSessionTurnsEvent
+  | ServerGeminiFinishedEvent
   | ServerGeminiLoopDetectedEvent;

 // A turn manages the agentic loop turn within the server context.
@@ -222,6 +230,16 @@ export class Turn {
             yield event;
           }
         }
+
+        // Check if response was truncated or stopped for various reasons
+        const finishReason = resp.candidates?.[0]?.finishReason;
+
+        if (finishReason) {
+          yield {
+            type: GeminiEventType.Finished,
+            value: finishReason as FinishReason,
+          };
+        }
       }
     } catch (e) {
       const error = toFriendlyError(e);
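For orientation, a sketch of how a caller of the core package can consume the new event from `Turn.run`'s async generator; `GeminiEventType.Finished` and the event value come from the hunks above, while the `collectResponse` helper, its `notify` callback, and the simplified event shapes are hypothetical:

```ts
import { FinishReason } from '@google/genai';

// Simplified stand-ins for the stream event types defined in turn.ts above.
type StreamEvent =
  | { type: 'content'; value: string }
  | { type: 'finished'; value: FinishReason };

// Hypothetical helper: accumulate streamed text and report truncation.
async function collectResponse(
  events: AsyncIterable<StreamEvent>,
  notify: (warning: string) => void,
): Promise<string> {
  let text = '';
  for await (const event of events) {
    if (event.type === 'content') {
      text += event.value;
    } else if (
      event.type === 'finished' &&
      event.value === FinishReason.MAX_TOKENS
    ) {
      notify('⚠️ Response truncated due to token limits.');
    }
  }
  return text;
}
```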