Add support for specifying maxSessionTurns via the settings configuration (#3507)

This commit is contained in:
anj-s 2025-07-11 07:55:03 -07:00 committed by GitHub
parent 0151a9e1a3
commit c9e1e6d3bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 231 additions and 15 deletions

View File

@ -189,6 +189,14 @@ In addition to a project settings file, a project's `.gemini` directory can cont
"hideTips": true "hideTips": true
``` ```
- **`maxSessionTurns`** (number):
- **Description:** Sets the maximum number of turns for a session. If the session exceeds this limit, the CLI will stop processing and start a new chat.
- **Default:** `-1` (unlimited)
- **Example:**
```json
"maxSessionTurns": 10
```
### Example `settings.json`: ### Example `settings.json`:
```json ```json
@ -213,7 +221,8 @@ In addition to a project settings file, a project's `.gemini` directory can cont
"logPrompts": true "logPrompts": true
}, },
"usageStatisticsEnabled": true, "usageStatisticsEnabled": true,
"hideTips": false "hideTips": false,
"maxSessionTurns": 10
} }
``` ```

View File

@ -312,6 +312,7 @@ export async function loadCliConfig(
bugCommand: settings.bugCommand, bugCommand: settings.bugCommand,
model: argv.model!, model: argv.model!,
extensionContextFilePaths, extensionContextFilePaths,
maxSessionTurns: settings.maxSessionTurns ?? -1,
listExtensions: argv.listExtensions || false, listExtensions: argv.listExtensions || false,
activeExtensions: activeExtensions.map((e) => ({ activeExtensions: activeExtensions.map((e) => ({
name: e.config.name, name: e.config.name,

View File

@ -80,6 +80,9 @@ export interface Settings {
hideWindowTitle?: boolean; hideWindowTitle?: boolean;
hideTips?: boolean; hideTips?: boolean;
// Maximum number of user/model/tool turns allowed in a session; unset or non-positive means unlimited.
maxSessionTurns?: number;
// Add other settings here. // Add other settings here.
} }

View File

@ -53,6 +53,7 @@ describe('runNonInteractive', () => {
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
getContentGeneratorConfig: vi.fn().mockReturnValue({}), getContentGeneratorConfig: vi.fn().mockReturnValue({}),
getMaxSessionTurns: vi.fn().mockReturnValue(10),
initialize: vi.fn(), initialize: vi.fn(),
} as unknown as Config; } as unknown as Config;
@ -294,4 +295,50 @@ describe('runNonInteractive', () => {
'Unfortunately the tool does not exist.', 'Unfortunately the tool does not exist.',
); );
}); });
// Verifies that the non-interactive loop stops with an explanatory message,
// rather than calling process.exit, once the configured turn limit is passed.
it('should exit when max session turns are exceeded', async () => {
  // A tool call that the model requests, forcing the loop to attempt another turn.
  const functionCall: FunctionCall = {
    id: 'fcLoop',
    name: 'loopTool',
    args: {},
  };
  const toolResponsePart: Part = {
    functionResponse: {
      name: 'loopTool',
      id: 'fcLoop',
      response: { result: 'still looping' },
    },
  };

  // Config with a max turn of 1
  vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(1);

  const { executeToolCall: mockCoreExecuteToolCall } = await import(
    '@google/gemini-cli-core'
  );
  vi.mocked(mockCoreExecuteToolCall).mockResolvedValue({
    callId: 'fcLoop',
    responseParts: [toolResponsePart],
    resultDisplay: 'Still looping',
    error: undefined,
  });

  const stream = (async function* () {
    yield { functionCalls: [functionCall] } as GenerateContentResponse;
  })();
  mockChat.sendMessageStream.mockResolvedValue(stream);

  // Silence and capture console.error so the limit message can be asserted on.
  const consoleErrorSpy = vi
    .spyOn(console, 'error')
    .mockImplementation(() => {});

  await runNonInteractive(mockConfig, 'Trigger loop');

  // Only the first turn reaches the model; the second is rejected by the limit.
  expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(1);
  // Use the same escaped literal as the production message so the comparison
  // does not depend on raw line breaks or indentation inside a template literal.
  expect(consoleErrorSpy).toHaveBeenCalledWith(
    '\n Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.',
  );
  expect(mockProcessExit).not.toHaveBeenCalled();
});
}); });

View File

@ -63,9 +63,19 @@ export async function runNonInteractive(
const chat = await geminiClient.getChat(); const chat = await geminiClient.getChat();
const abortController = new AbortController(); const abortController = new AbortController();
let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }]; let currentMessages: Content[] = [{ role: 'user', parts: [{ text: input }] }];
let turnCount = 0;
try { try {
while (true) { while (true) {
turnCount++;
if (
config.getMaxSessionTurns() > 0 &&
turnCount > config.getMaxSessionTurns()
) {
console.error(
'\n Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.',
);
return;
}
const functionCalls: FunctionCall[] = []; const functionCalls: FunctionCall[] = [];
const responseStream = await chat.sendMessageStream( const responseStream = await chat.sendMessageStream(

View File

@ -431,6 +431,20 @@ export const useGeminiStream = (
[addItem, config], [addItem, config],
); );
// Adds an informational history item telling the user that the session hit its
// configured turn limit and which file to edit to raise it.
const handleMaxSessionTurnsEvent = useCallback(
  () =>
    addItem(
      {
        type: 'info',
        text:
          `The session has reached the maximum number of turns: ${config.getMaxSessionTurns()}. ` +
          // The settings file is named settings.json (plural).
          `Please update this limit in your settings.json file.`,
      },
      Date.now(),
    ),
  [addItem, config],
);
const processGeminiStreamEvents = useCallback( const processGeminiStreamEvents = useCallback(
async ( async (
stream: AsyncIterable<GeminiEvent>, stream: AsyncIterable<GeminiEvent>,
@ -467,6 +481,9 @@ export const useGeminiStream = (
case ServerGeminiEventType.ToolCallResponse: case ServerGeminiEventType.ToolCallResponse:
// do nothing // do nothing
break; break;
case ServerGeminiEventType.MaxSessionTurns:
handleMaxSessionTurnsEvent();
break;
default: { default: {
// enforces exhaustive switch-case // enforces exhaustive switch-case
const unreachable: never = event; const unreachable: never = event;
@ -485,6 +502,7 @@ export const useGeminiStream = (
handleErrorEvent, handleErrorEvent,
scheduleToolCalls, scheduleToolCalls,
handleChatCompressionEvent, handleChatCompressionEvent,
handleMaxSessionTurnsEvent,
], ],
); );

View File

@ -139,6 +139,7 @@ export interface ConfigParameters {
bugCommand?: BugCommandSettings; bugCommand?: BugCommandSettings;
model: string; model: string;
extensionContextFilePaths?: string[]; extensionContextFilePaths?: string[];
maxSessionTurns?: number;
listExtensions?: boolean; listExtensions?: boolean;
activeExtensions?: ActiveExtension[]; activeExtensions?: ActiveExtension[];
noBrowser?: boolean; noBrowser?: boolean;
@ -182,6 +183,7 @@ export class Config {
private readonly extensionContextFilePaths: string[]; private readonly extensionContextFilePaths: string[];
private readonly noBrowser: boolean; private readonly noBrowser: boolean;
private modelSwitchedDuringSession: boolean = false; private modelSwitchedDuringSession: boolean = false;
private readonly maxSessionTurns: number;
private readonly listExtensions: boolean; private readonly listExtensions: boolean;
private readonly _activeExtensions: ActiveExtension[]; private readonly _activeExtensions: ActiveExtension[];
flashFallbackHandler?: FlashFallbackHandler; flashFallbackHandler?: FlashFallbackHandler;
@ -227,6 +229,7 @@ export class Config {
this.bugCommand = params.bugCommand; this.bugCommand = params.bugCommand;
this.model = params.model; this.model = params.model;
this.extensionContextFilePaths = params.extensionContextFilePaths ?? []; this.extensionContextFilePaths = params.extensionContextFilePaths ?? [];
this.maxSessionTurns = params.maxSessionTurns ?? -1;
this.listExtensions = params.listExtensions ?? false; this.listExtensions = params.listExtensions ?? false;
this._activeExtensions = params.activeExtensions ?? []; this._activeExtensions = params.activeExtensions ?? [];
this.noBrowser = params.noBrowser ?? false; this.noBrowser = params.noBrowser ?? false;
@ -308,6 +311,10 @@ export class Config {
this.flashFallbackHandler = handler; this.flashFallbackHandler = handler;
} }
/**
 * Returns the configured session turn limit.
 *
 * @returns The value supplied via `ConfigParameters.maxSessionTurns`, or `-1`
 *   when none was supplied.
 */
getMaxSessionTurns(): number {
return this.maxSessionTurns;
}
setQuotaErrorOccurred(value: boolean): void { setQuotaErrorOccurred(value: boolean): void {
this.quotaErrorOccurred = value; this.quotaErrorOccurred = value;
} }

View File

@ -17,7 +17,7 @@ import { findIndexAfterFraction, GeminiClient } from './client.js';
import { AuthType, ContentGenerator } from './contentGenerator.js'; import { AuthType, ContentGenerator } from './contentGenerator.js';
import { GeminiChat } from './geminiChat.js'; import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { Turn } from './turn.js'; import { GeminiEventType, Turn } from './turn.js';
import { getCoreSystemPrompt } from './prompts.js'; import { getCoreSystemPrompt } from './prompts.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
@ -43,7 +43,13 @@ vi.mock('./turn', () => {
} }
} }
// Export the mock class as 'Turn' // Export the mock class as 'Turn'
return { Turn: MockTurn }; return {
Turn: MockTurn,
GeminiEventType: {
MaxSessionTurns: 'MaxSessionTurns',
ChatCompressed: 'ChatCompressed',
},
};
}); });
vi.mock('../config/config.js'); vi.mock('../config/config.js');
@ -68,12 +74,13 @@ vi.mock('../telemetry/index.js', () => ({
describe('findIndexAfterFraction', () => { describe('findIndexAfterFraction', () => {
const history: Content[] = [ const history: Content[] = [
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, { role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, { role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, { role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, { role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68
{ role: 'user', parts: [{ text: 'This is the fifth message.' }] }, { role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65
]; ];
// Total length: 333
it('should throw an error for non-positive numbers', () => { it('should throw an error for non-positive numbers', () => {
expect(() => findIndexAfterFraction(history, 0)).toThrow( expect(() => findIndexAfterFraction(history, 0)).toThrow(
@ -88,14 +95,23 @@ describe('findIndexAfterFraction', () => {
}); });
it('should handle a fraction in the middle', () => { it('should handle a fraction in the middle', () => {
// Total length is 257. 257 * 0.5 = 128.5 // 333 * 0.5 = 166.5
// 0: 53 // 0: 66
// 1: 53 + 54 = 107 // 1: 66 + 68 = 134
// 2: 107 + 53 = 160 // 2: 134 + 66 = 200
// 160 >= 128.5, so index is 2 // 200 >= 166.5, so index is 2
expect(findIndexAfterFraction(history, 0.5)).toBe(2); expect(findIndexAfterFraction(history, 0.5)).toBe(2);
}); });
// Verifies that a fraction close to 1 resolves to the final index of `history`.
// NOTE(review): the per-entry lengths used in the arithmetic below come from the
// fixture annotations; confirm they match how findIndexAfterFraction measures entries.
it('should handle a fraction that results in the last index', () => {
// 333 * 0.9 = 299.7
// ...
// 3: 200 + 68 = 268
// 4: 268 + 65 = 333
// 333 >= 299.7, so index is 4
expect(findIndexAfterFraction(history, 0.9)).toBe(4);
});
it('should handle an empty history', () => { it('should handle an empty history', () => {
expect(findIndexAfterFraction([], 0.5)).toBe(0); expect(findIndexAfterFraction([], 0.5)).toBe(0);
}); });
@ -178,6 +194,7 @@ describe('Gemini Client (client.ts)', () => {
getProxy: vi.fn().mockReturnValue(undefined), getProxy: vi.fn().mockReturnValue(undefined),
getWorkingDir: vi.fn().mockReturnValue('/test/dir'), getWorkingDir: vi.fn().mockReturnValue('/test/dir'),
getFileService: vi.fn().mockReturnValue(fileService), getFileService: vi.fn().mockReturnValue(fileService),
getMaxSessionTurns: vi.fn().mockReturnValue(0),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false), getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(), setQuotaErrorOccurred: vi.fn(),
getNoBrowser: vi.fn().mockReturnValue(false), getNoBrowser: vi.fn().mockReturnValue(false),
@ -366,6 +383,42 @@ describe('Gemini Client (client.ts)', () => {
contents, contents,
}); });
}); });
// Verifies that generateJson forwards a caller-supplied model name and merges a
// caller-supplied generation config into the request passed to generateContent.
it('should allow overriding model and config', async () => {
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
const schema = { type: 'string' };
const abortSignal = new AbortController().signal;
const customModel = 'custom-json-model';
const customConfig = { temperature: 0.9, topK: 20 };
// Stub generator: only the two methods this test assigns are provided.
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 1 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
await client.generateJson(
contents,
schema,
abortSignal,
customModel,
customConfig,
);
// Expect the overrides (model, temperature, topK) alongside values that were not overridden.
expect(mockGenerateContentFn).toHaveBeenCalledWith({
model: customModel,
config: {
abortSignal,
systemInstruction: getCoreSystemPrompt(''),
temperature: 0.9,
topP: 1, // from default
topK: 20,
responseSchema: schema,
responseMimeType: 'application/json',
},
contents,
});
});
}); });
describe('addHistory', () => { describe('addHistory', () => {
@ -660,6 +713,59 @@ describe('Gemini Client (client.ts)', () => {
expect(eventCount).toBeLessThan(200); // Should not exceed our safety limit expect(eventCount).toBeLessThan(200); // Should not exceed our safety limit
}); });
// Verifies that sendMessageStream emits a single MaxSessionTurns event, and does
// not run another turn, once more calls have been made than the configured limit.
it('should yield MaxSessionTurns and stop when session turn limit is reached', async () => {
// Arrange
const MAX_SESSION_TURNS = 5;
vi.spyOn(client['config'], 'getMaxSessionTurns').mockReturnValue(
MAX_SESSION_TURNS,
);
// NOTE(review): mockReturnValue hands back this same generator object on every
// call, so it can only yield content once; the assertions below only count
// calls, so this looks intentional — confirm.
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Act & Assert
// Run up to the limit
for (let i = 0; i < MAX_SESSION_TURNS; i++) {
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-id-4',
);
// consume stream
for await (const _event of stream) {
// do nothing
}
}
// This call should exceed the limit
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-id-5',
);
const events = [];
for await (const event of stream) {
events.push(event);
}
// The over-limit call must produce only the limit event...
expect(events).toEqual([{ type: GeminiEventType.MaxSessionTurns }]);
// ...and must not have started an additional turn.
expect(mockTurnRunFn).toHaveBeenCalledTimes(MAX_SESSION_TURNS);
});
it('should respect MAX_TURNS limit even when turns parameter is set to a large value', async () => { it('should respect MAX_TURNS limit even when turns parameter is set to a large value', async () => {
// This test verifies that the infinite loop protection works even when // This test verifies that the infinite loop protection works even when
// someone tries to bypass it by calling with a very large turns value // someone tries to bypass it by calling with a very large turns value

View File

@ -86,6 +86,7 @@ export class GeminiClient {
temperature: 0, temperature: 0,
topP: 1, topP: 1,
}; };
private sessionTurnCount = 0;
private readonly MAX_TURNS = 100; private readonly MAX_TURNS = 100;
/** /**
* Threshold for compression token count as a fraction of the model's token limit. * Threshold for compression token count as a fraction of the model's token limit.
@ -266,6 +267,14 @@ export class GeminiClient {
turns: number = this.MAX_TURNS, turns: number = this.MAX_TURNS,
originalModel?: string, originalModel?: string,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> { ): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
this.sessionTurnCount++;
if (
this.config.getMaxSessionTurns() > 0 &&
this.sessionTurnCount > this.config.getMaxSessionTurns()
) {
yield { type: GeminiEventType.MaxSessionTurns };
return new Turn(this.getChat(), prompt_id);
}
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops // Ensure turns never exceeds MAX_TURNS to prevent infinite loops
const boundedTurns = Math.min(turns, this.MAX_TURNS); const boundedTurns = Math.min(turns, this.MAX_TURNS);
if (!boundedTurns) { if (!boundedTurns) {

View File

@ -48,6 +48,7 @@ export enum GeminiEventType {
Error = 'error', Error = 'error',
ChatCompressed = 'chat_compressed', ChatCompressed = 'chat_compressed',
Thought = 'thought', Thought = 'thought',
MaxSessionTurns = 'max_session_turns',
} }
export interface StructuredError { export interface StructuredError {
@ -128,6 +129,10 @@ export type ServerGeminiChatCompressedEvent = {
value: ChatCompressionInfo | null; value: ChatCompressionInfo | null;
}; };
/**
 * Stream event signalling that the session turn limit has been exceeded.
 * It carries no payload beyond its discriminating `type`.
 */
export type ServerGeminiMaxSessionTurnsEvent = {
type: GeminiEventType.MaxSessionTurns;
};
// The original union type, now composed of the individual types // The original union type, now composed of the individual types
export type ServerGeminiStreamEvent = export type ServerGeminiStreamEvent =
| ServerGeminiContentEvent | ServerGeminiContentEvent
@ -137,7 +142,8 @@ export type ServerGeminiStreamEvent =
| ServerGeminiUserCancelledEvent | ServerGeminiUserCancelledEvent
| ServerGeminiErrorEvent | ServerGeminiErrorEvent
| ServerGeminiChatCompressedEvent | ServerGeminiChatCompressedEvent
| ServerGeminiThoughtEvent; | ServerGeminiThoughtEvent
| ServerGeminiMaxSessionTurnsEvent;
// A turn manages the agentic loop turn within the server context. // A turn manages the agentic loop turn within the server context.
export class Turn { export class Turn {