diff --git a/docs/cli/configuration.md b/docs/cli/configuration.md index 79a2ffc3..d175aa4f 100644 --- a/docs/cli/configuration.md +++ b/docs/cli/configuration.md @@ -189,6 +189,14 @@ In addition to a project settings file, a project's `.gemini` directory can cont "hideTips": true ``` +- **`maxSessionTurns`** (number): + - **Description:** Sets the maximum number of turns for a session. If the session exceeds this limit, the CLI will stop processing and start a new chat. + - **Default:** `-1` (unlimited) + - **Example:** + ```json + "maxSessionTurns": 10 + ``` + ### Example `settings.json`: ```json @@ -213,7 +221,8 @@ In addition to a project settings file, a project's `.gemini` directory can cont "logPrompts": true }, "usageStatisticsEnabled": true, - "hideTips": false + "hideTips": false, + "maxSessionTurns": 10 } ``` diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index b80b6dd0..b685f090 100644 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -312,6 +312,7 @@ export async function loadCliConfig( bugCommand: settings.bugCommand, model: argv.model!, extensionContextFilePaths, + maxSessionTurns: settings.maxSessionTurns ?? -1, listExtensions: argv.listExtensions || false, activeExtensions: activeExtensions.map((e) => ({ name: e.config.name, diff --git a/packages/cli/src/config/settings.ts b/packages/cli/src/config/settings.ts index 133701f5..2abe8cd8 100644 --- a/packages/cli/src/config/settings.ts +++ b/packages/cli/src/config/settings.ts @@ -80,6 +80,9 @@ export interface Settings { hideWindowTitle?: boolean; hideTips?: boolean; + // Setting for setting maximum number of user/model/tool turns in a session. + maxSessionTurns?: number; + // Add other settings here. 
} diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index 14352f53..6cbb630d 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -53,6 +53,7 @@ describe('runNonInteractive', () => { getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), getContentGeneratorConfig: vi.fn().mockReturnValue({}), + getMaxSessionTurns: vi.fn().mockReturnValue(10), initialize: vi.fn(), } as unknown as Config; @@ -294,4 +295,50 @@ describe('runNonInteractive', () => { 'Unfortunately the tool does not exist.', ); }); + + it('should exit when max session turns are exceeded', async () => { + const functionCall: FunctionCall = { + id: 'fcLoop', + name: 'loopTool', + args: {}, + }; + const toolResponsePart: Part = { + functionResponse: { + name: 'loopTool', + id: 'fcLoop', + response: { result: 'still looping' }, + }, + }; + + // Config with a max turn of 1 + vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(1); + + const { executeToolCall: mockCoreExecuteToolCall } = await import( + '@google/gemini-cli-core' + ); + vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({ + callId: 'fcLoop', + responseParts: [toolResponsePart], + resultDisplay: 'Still looping', + error: undefined, + }); + + const stream = (async function* () { + yield { functionCalls: [functionCall] } as GenerateContentResponse; + })(); + + mockChat.sendMessageStream.mockResolvedValue(stream); + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + + await runNonInteractive(mockConfig, 'Trigger loop'); + + expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1); + expect(consoleErrorSpy).toHaveBeenCalledWith( + ` + Reached max session turns for this session. 
Increase the number of turns by specifying maxSessionTurns in settings.json.`, + ); + expect(mockProcessExit).not.toHaveBeenCalled(); + }); }); diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index b8b8ac3f..2db28eba 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -63,9 +63,19 @@ export async function runNonInteractive( const chat = await geminiClient.getChat(); const abortController = new AbortController(); let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }]; - + let turnCount = 0; try { while (true) { + turnCount++; + if ( + config.getMaxSessionTurns() > 0 && + turnCount > config.getMaxSessionTurns() + ) { + console.error( + '\n Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', + ); + return; + } const functionCalls: FunctionCall[] = []; const responseStream = await chat.sendMessageStream( diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index b82b0cb2..a9326528 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -431,6 +431,20 @@ export const useGeminiStream = ( [addItem, config], ); + const handleMaxSessionTurnsEvent = useCallback( + () => + addItem( + { + type: 'info', + text: + `The session has reached the maximum number of turns: ${config.getMaxSessionTurns()}. 
` + `Please update this limit in your settings.json file.`, }, Date.now(), ), [addItem, config], ); + const processGeminiStreamEvents = useCallback( async ( stream: AsyncIterable, @@ -467,6 +481,9 @@ export const useGeminiStream = ( case ServerGeminiEventType.ToolCallResponse: // do nothing break; + case ServerGeminiEventType.MaxSessionTurns: + handleMaxSessionTurnsEvent(); + break; default: { // enforces exhaustive switch-case const unreachable: never = event; @@ -485,6 +502,7 @@ export const useGeminiStream = ( handleErrorEvent, scheduleToolCalls, handleChatCompressionEvent, + handleMaxSessionTurnsEvent, ], ); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 15e9e73b..12767133 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -139,6 +139,7 @@ export interface ConfigParameters { bugCommand?: BugCommandSettings; model: string; extensionContextFilePaths?: string[]; + maxSessionTurns?: number; listExtensions?: boolean; activeExtensions?: ActiveExtension[]; noBrowser?: boolean; @@ -182,6 +183,7 @@ export class Config { private readonly extensionContextFilePaths: string[]; private readonly noBrowser: boolean; private modelSwitchedDuringSession: boolean = false; + private readonly maxSessionTurns: number; private readonly listExtensions: boolean; private readonly _activeExtensions: ActiveExtension[]; flashFallbackHandler?: FlashFallbackHandler; @@ -227,6 +229,7 @@ export class Config { this.bugCommand = params.bugCommand; this.model = params.model; this.extensionContextFilePaths = params.extensionContextFilePaths ?? []; + this.maxSessionTurns = params.maxSessionTurns ?? -1; this.listExtensions = params.listExtensions ?? false; this._activeExtensions = params.activeExtensions ?? []; this.noBrowser = params.noBrowser ?? 
false; @@ -308,6 +311,10 @@ export class Config { this.flashFallbackHandler = handler; } + getMaxSessionTurns(): number { + return this.maxSessionTurns; + } + setQuotaErrorOccurred(value: boolean): void { this.quotaErrorOccurred = value; } diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 2769e1b0..bbcb549b 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -17,7 +17,7 @@ import { findIndexAfterFraction, GeminiClient } from './client.js'; import { AuthType, ContentGenerator } from './contentGenerator.js'; import { GeminiChat } from './geminiChat.js'; import { Config } from '../config/config.js'; -import { Turn } from './turn.js'; +import { GeminiEventType, Turn } from './turn.js'; import { getCoreSystemPrompt } from './prompts.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; @@ -43,7 +43,13 @@ vi.mock('./turn', () => { } } // Export the mock class as 'Turn' - return { Turn: MockTurn }; + return { + Turn: MockTurn, + GeminiEventType: { + MaxSessionTurns: 'MaxSessionTurns', + ChatCompressed: 'ChatCompressed', + }, + }; }); vi.mock('../config/config.js'); @@ -68,12 +74,13 @@ vi.mock('../telemetry/index.js', () => ({ describe('findIndexAfterFraction', () => { const history: Content[] = [ - { role: 'user', parts: [{ text: 'This is the first message.' }] }, - { role: 'model', parts: [{ text: 'This is the second message.' }] }, - { role: 'user', parts: [{ text: 'This is the third message.' }] }, - { role: 'model', parts: [{ text: 'This is the fourth message.' }] }, - { role: 'user', parts: [{ text: 'This is the fifth message.' }] }, + { role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 + { role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 + { role: 'user', parts: [{ text: 'This is the third message.' 
}] }, // JSON length: 66 + { role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 + { role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65 ]; + // Total length: 333 it('should throw an error for non-positive numbers', () => { expect(() => findIndexAfterFraction(history, 0)).toThrow( @@ -88,14 +95,23 @@ describe('findIndexAfterFraction', () => { }); it('should handle a fraction in the middle', () => { - // Total length is 257. 257 * 0.5 = 128.5 - // 0: 53 - // 1: 53 + 54 = 107 - // 2: 107 + 53 = 160 - // 160 >= 128.5, so index is 2 + // 333 * 0.5 = 166.5 + // 0: 66 + // 1: 66 + 68 = 134 + // 2: 134 + 66 = 200 + // 200 >= 166.5, so index is 2 expect(findIndexAfterFraction(history, 0.5)).toBe(2); }); + it('should handle a fraction that results in the last index', () => { + // 333 * 0.9 = 299.7 + // ... + // 3: 200 + 68 = 268 + // 4: 268 + 65 = 333 + // 333 >= 299.7, so index is 4 + expect(findIndexAfterFraction(history, 0.9)).toBe(4); + }); + it('should handle an empty history', () => { expect(findIndexAfterFraction([], 0.5)).toBe(0); }); @@ -178,6 +194,7 @@ describe('Gemini Client (client.ts)', () => { getProxy: vi.fn().mockReturnValue(undefined), getWorkingDir: vi.fn().mockReturnValue('/test/dir'), getFileService: vi.fn().mockReturnValue(fileService), + getMaxSessionTurns: vi.fn().mockReturnValue(0), getQuotaErrorOccurred: vi.fn().mockReturnValue(false), setQuotaErrorOccurred: vi.fn(), getNoBrowser: vi.fn().mockReturnValue(false), @@ -366,6 +383,42 @@ describe('Gemini Client (client.ts)', () => { contents, }); }); + + it('should allow overriding model and config', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const schema = { type: 'string' }; + const abortSignal = new AbortController().signal; + const customModel = 'custom-json-model'; + const customConfig = { temperature: 0.9, topK: 20 }; + + const mockGenerator: Partial = { + countTokens: 
vi.fn().mockResolvedValue({ totalTokens: 1 }), + generateContent: mockGenerateContentFn, + }; + client['contentGenerator'] = mockGenerator as ContentGenerator; + + await client.generateJson( + contents, + schema, + abortSignal, + customModel, + customConfig, + ); + + expect(mockGenerateContentFn).toHaveBeenCalledWith({ + model: customModel, + config: { + abortSignal, + systemInstruction: getCoreSystemPrompt(''), + temperature: 0.9, + topP: 1, // from default + topK: 20, + responseSchema: schema, + responseMimeType: 'application/json', + }, + contents, + }); + }); }); describe('addHistory', () => { @@ -660,6 +713,59 @@ describe('Gemini Client (client.ts)', () => { expect(eventCount).toBeLessThan(200); // Should not exceed our safety limit }); + it('should yield MaxSessionTurns and stop when session turn limit is reached', async () => { + // Arrange + const MAX_SESSION_TURNS = 5; + vi.spyOn(client['config'], 'getMaxSessionTurns').mockReturnValue( + MAX_SESSION_TURNS, + ); + + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const mockGenerator: Partial = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), + }; + client['contentGenerator'] = mockGenerator as ContentGenerator; + + // Act & Assert + // Run up to the limit + for (let i = 0; i < MAX_SESSION_TURNS; i++) { + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-id-4', + ); + // consume stream + for await (const _event of stream) { + // do nothing + } + } + + // This call should exceed the limit + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-id-5', + ); + + const events = []; + for await (const event of stream) { + events.push(event); + } + + 
expect(events).toEqual([{ type: GeminiEventType.MaxSessionTurns }]); + expect(mockTurnRunFn).toHaveBeenCalledTimes(MAX_SESSION_TURNS); + }); + it('should respect MAX_TURNS limit even when turns parameter is set to a large value', async () => { // This test verifies that the infinite loop protection works even when // someone tries to bypass it by calling with a very large turns value diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 5d9ac0cb..0ff8026b 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -86,6 +86,7 @@ export class GeminiClient { temperature: 0, topP: 1, }; + private sessionTurnCount = 0; private readonly MAX_TURNS = 100; /** * Threshold for compression token count as a fraction of the model's token limit. @@ -266,6 +267,14 @@ export class GeminiClient { turns: number = this.MAX_TURNS, originalModel?: string, ): AsyncGenerator { + this.sessionTurnCount++; + if ( + this.config.getMaxSessionTurns() > 0 && + this.sessionTurnCount > this.config.getMaxSessionTurns() + ) { + yield { type: GeminiEventType.MaxSessionTurns }; + return new Turn(this.getChat(), prompt_id); + } // Ensure turns never exceeds MAX_TURNS to prevent infinite loops const boundedTurns = Math.min(turns, this.MAX_TURNS); if (!boundedTurns) { diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index aeeaa889..6135b1f6 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -48,6 +48,7 @@ export enum GeminiEventType { Error = 'error', ChatCompressed = 'chat_compressed', Thought = 'thought', + MaxSessionTurns = 'max_session_turns', } export interface StructuredError { @@ -128,6 +129,10 @@ export type ServerGeminiChatCompressedEvent = { value: ChatCompressionInfo | null; }; +export type ServerGeminiMaxSessionTurnsEvent = { + type: GeminiEventType.MaxSessionTurns; +}; + // The original union type, now composed of the individual types export type ServerGeminiStreamEvent 
= | ServerGeminiContentEvent @@ -137,7 +142,8 @@ export type ServerGeminiStreamEvent = | ServerGeminiUserCancelledEvent | ServerGeminiErrorEvent | ServerGeminiChatCompressedEvent - | ServerGeminiThoughtEvent; + | ServerGeminiThoughtEvent + | ServerGeminiMaxSessionTurnsEvent; // A turn manages the agentic loop turn within the server context. export class Turn {