// gemini-cli/packages/core/src/core/client.ts

/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import {
  EmbedContentParameters,
  GenerateContentConfig,
  Part,
  SchemaUnion,
  PartListUnion,
  Content,
  Tool,
  GenerateContentResponse,
} from '@google/genai';
import { getFolderStructure } from '../utils/getFolderStructure.js';
import {
  Turn,
  ServerGeminiStreamEvent,
  GeminiEventType,
  ChatCompressionInfo,
} from './turn.js';
import { Config } from '../config/config.js';
import { UserTierId } from '../code_assist/types.js';
import { getCoreSystemPrompt, getCompressionPrompt } from './prompts.js';
import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
import { reportError } from '../utils/errorReporting.js';
import { GeminiChat } from './geminiChat.js';
import { retryWithBackoff } from '../utils/retry.js';
import { getErrorMessage } from '../utils/errors.js';
import { isFunctionResponse } from '../utils/messageInspectors.js';
import { tokenLimit } from './tokenLimits.js';
import {
  AuthType,
  ContentGenerator,
  ContentGeneratorConfig,
  createContentGenerator,
} from './contentGenerator.js';
import { ProxyAgent, setGlobalDispatcher } from 'undici';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { LoopDetectionService } from '../services/loopDetectionService.js';
import { ideContext } from '../ide/ideContext.js';
import { logFlashDecidedToContinue } from '../telemetry/loggers.js';
import { FlashDecidedToContinueEvent } from '../telemetry/types.js';

function isThinkingSupported(model: string) {
  if (model.startsWith('gemini-2.5')) return true;
  return false;
}

/**
 * Returns the index of the content item at which the cumulative character
 * count (each item measured by its JSON-serialized length) first reaches the
 * given fraction of the history's total characters.
 *
 * Exported for testing purposes.
 */
export function findIndexAfterFraction(
  history: Content[],
  fraction: number,
): number {
  if (fraction <= 0 || fraction >= 1) {
    throw new Error('Fraction must be between 0 and 1');
  }

  const contentLengths = history.map(
    (content) => JSON.stringify(content).length,
  );

  const totalCharacters = contentLengths.reduce(
    (sum, length) => sum + length,
    0,
  );
  const targetCharacters = totalCharacters * fraction;

  let charactersSoFar = 0;
  for (let i = 0; i < contentLengths.length; i++) {
    charactersSoFar += contentLengths[i];
    if (charactersSoFar >= targetCharacters) {
      return i;
    }
  }
  return contentLengths.length;
}
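
// Worked example (hypothetical history): if three items serialize to 100, 100,
// and 200 characters, the total is 400, so findIndexAfterFraction(history, 0.5)
// targets 200 characters and returns index 1, where the running sum first
// reaches the target.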

export class GeminiClient {
  private chat?: GeminiChat;
  private contentGenerator?: ContentGenerator;
  private embeddingModel: string;
  private generateContentConfig: GenerateContentConfig = {
    temperature: 0,
    topP: 1,
  };
  private sessionTurnCount = 0;
  private readonly MAX_TURNS = 100;

  /**
   * Threshold for compression token count as a fraction of the model's token
   * limit. If the chat history exceeds this threshold, it will be compressed.
   */
  private readonly COMPRESSION_TOKEN_THRESHOLD = 0.7;

  /**
   * The fraction of the latest chat history to keep. A value of 0.3 means
   * that only the last 30% of the chat history will be kept after compression.
   */
  private readonly COMPRESSION_PRESERVE_THRESHOLD = 0.3;
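
  // Concretely: for a model whose tokenLimit() is 1,000,000 tokens, compression
  // triggers once the curated history exceeds 700,000 tokens, and the split
  // point is chosen so that roughly the most recent 30% of the history (by
  // serialized character count) survives uncompressed.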

  private readonly loopDetector: LoopDetectionService;
  private lastPromptId: string;

  constructor(private config: Config) {
    if (config.getProxy()) {
      setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
    }

    this.embeddingModel = config.getEmbeddingModel();
    this.loopDetector = new LoopDetectionService(config);
    this.lastPromptId = this.config.getSessionId();
  }
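
  // Typical lifecycle (hypothetical caller): construct with a Config, then
  // await initialize() before using any other method, since getChat() and
  // getContentGenerator() throw until initialization has completed.
  //
  //   const client = new GeminiClient(config);
  //   await client.initialize(contentGeneratorConfig);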

  async initialize(contentGeneratorConfig: ContentGeneratorConfig) {
    this.contentGenerator = await createContentGenerator(
      contentGeneratorConfig,
      this.config,
      this.config.getSessionId(),
    );
    this.chat = await this.startChat();
  }
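
  /** Returns the content generator, throwing if initialize() has not run. */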
  getContentGenerator(): ContentGenerator {
    if (!this.contentGenerator) {
      throw new Error('Content generator not initialized');
    }
    return this.contentGenerator;
  }

  getUserTier(): UserTierId | undefined {
    return this.contentGenerator?.userTier;
  }

  async addHistory(content: Content) {
    this.getChat().addHistory(content);
  }

  getChat(): GeminiChat {
    if (!this.chat) {
      throw new Error('Chat not initialized');
    }
    return this.chat;
  }

  isInitialized(): boolean {
    return this.chat !== undefined && this.contentGenerator !== undefined;
  }

  getHistory(): Content[] {
    return this.getChat().getHistory();
  }

  setHistory(history: Content[]) {
    this.getChat().setHistory(history);
  }

  async setTools(): Promise<void> {
    const toolRegistry = await this.config.getToolRegistry();
    const toolDeclarations = toolRegistry.getFunctionDeclarations();
    const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
    this.getChat().setTools(tools);
  }

  async resetChat(): Promise<void> {
    this.chat = await this.startChat();
  }

  private async getEnvironment(): Promise<Part[]> {
    const cwd = this.config.getWorkingDir();
    const today = new Date().toLocaleDateString(undefined, {
      weekday: 'long',
      year: 'numeric',
      month: 'long',
      day: 'numeric',
    });
    const platform = process.platform;
    const folderStructure = await getFolderStructure(cwd, {
      fileService: this.config.getFileService(),
    });
    const context = `
This is the Gemini CLI. We are setting up the context for our chat.
Today's date is ${today}.
My operating system is: ${platform}
I'm currently working in the directory: ${cwd}
${folderStructure}
`.trim();

    const initialParts: Part[] = [{ text: context }];
    const toolRegistry = await this.config.getToolRegistry();

    // Add full file context if the flag is set
    if (this.config.getFullContext()) {
      try {
        const readManyFilesTool = toolRegistry.getTool(
          'read_many_files',
        ) as ReadManyFilesTool;
        if (readManyFilesTool) {
          // Read all files in the target directory
          const result = await readManyFilesTool.execute(
            {
              paths: ['**/*'], // Read everything recursively
              useDefaultExcludes: true, // Use default excludes
            },
            AbortSignal.timeout(30000),
          );
          if (result.llmContent) {
            initialParts.push({
              text: `\n--- Full File Context ---\n${result.llmContent}`,
            });
          } else {
            console.warn(
              'Full context requested, but read_many_files returned no content.',
            );
          }
        } else {
          console.warn(
            'Full context requested, but read_many_files tool not found.',
          );
        }
      } catch (error) {
        // Not using reportError here as it's a startup/config phase, not a
        // chat/generation phase error.
        console.error('Error reading full file context:', error);
        initialParts.push({
          text: '\n--- Error reading full file context ---',
        });
      }
    }

    return initialParts;
  }
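
  /**
   * Creates a new chat session seeded with the environment context (date,
   * platform, working directory, folder structure) and the currently
   * registered tool declarations.
   */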
  async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
    const envParts = await this.getEnvironment();
    const toolRegistry = await this.config.getToolRegistry();
    const toolDeclarations = toolRegistry.getFunctionDeclarations();
    const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
    const history: Content[] = [
      {
        role: 'user',
        parts: envParts,
      },
      {
        role: 'model',
        parts: [{ text: 'Got it. Thanks for the context!' }],
      },
      ...(extraHistory ?? []),
    ];
    try {
      const userMemory = this.config.getUserMemory();
      const systemInstruction = getCoreSystemPrompt(userMemory);
      const generateContentConfigWithThinking = isThinkingSupported(
        this.config.getModel(),
      )
        ? {
            ...this.generateContentConfig,
            thinkingConfig: {
              includeThoughts: true,
            },
          }
        : this.generateContentConfig;
      return new GeminiChat(
        this.config,
        this.getContentGenerator(),
        {
          systemInstruction,
          ...generateContentConfigWithThinking,
          tools,
        },
        history,
      );
    } catch (error) {
      await reportError(
        error,
        'Error initializing Gemini chat session.',
        history,
        'startChat',
      );
      throw new Error(`Failed to initialize chat: ${getErrorMessage(error)}`);
    }
  }
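
  /**
   * Runs one user turn against the chat, yielding stream events as they
   * arrive and returning the completed Turn. May recurse (bounded by
   * MAX_TURNS) when the next-speaker check decides the model should continue.
   */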
  async *sendMessageStream(
    request: PartListUnion,
    signal: AbortSignal,
    prompt_id: string,
    turns: number = this.MAX_TURNS,
    originalModel?: string,
  ): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
    if (this.lastPromptId !== prompt_id) {
      this.loopDetector.reset(prompt_id);
      this.lastPromptId = prompt_id;
    }
    this.sessionTurnCount++;
    if (
      this.config.getMaxSessionTurns() > 0 &&
      this.sessionTurnCount > this.config.getMaxSessionTurns()
    ) {
      yield { type: GeminiEventType.MaxSessionTurns };
      return new Turn(this.getChat(), prompt_id);
    }

    // Ensure turns never exceeds MAX_TURNS to prevent infinite loops.
    const boundedTurns = Math.min(turns, this.MAX_TURNS);
    if (!boundedTurns) {
      return new Turn(this.getChat(), prompt_id);
    }

    // Track the original model from the first call to detect model switching.
    const initialModel = originalModel || this.config.getModel();

    const compressed = await this.tryCompressChat(prompt_id);
    if (compressed) {
      yield { type: GeminiEventType.ChatCompressed, value: compressed };
    }

    if (this.config.getIdeMode()) {
      const openFiles = ideContext.getOpenFilesContext();
      if (openFiles) {
        const contextParts: string[] = [];
        if (openFiles.activeFile) {
          contextParts.push(
            `This is the file that the user was most recently looking at:\n- Path: ${openFiles.activeFile}`,
          );
          if (openFiles.cursor) {
            contextParts.push(
              `This is the cursor position in the file:\n- Cursor Position: Line ${openFiles.cursor.line}, Character ${openFiles.cursor.character}`,
            );
          }
          if (openFiles.selectedText) {
            contextParts.push(
              `This is the selected text in the active file:\n- ${openFiles.selectedText}`,
            );
          }
        }
        if (openFiles.recentOpenFiles && openFiles.recentOpenFiles.length > 0) {
          const recentFiles = openFiles.recentOpenFiles
            .map((file) => `- ${file.filePath}`)
            .join('\n');
          contextParts.push(
            `Here are files the user has recently opened, with the most recent at the top:\n${recentFiles}`,
          );
        }
        if (contextParts.length > 0) {
          request = [
            { text: contextParts.join('\n') },
            ...(Array.isArray(request) ? request : [request]),
          ];
        }
      }
    }

    const turn = new Turn(this.getChat(), prompt_id);

    const loopDetected = await this.loopDetector.turnStarted(signal);
    if (loopDetected) {
      yield { type: GeminiEventType.LoopDetected };
      return turn;
    }

    const resultStream = turn.run(request, signal);
    for await (const event of resultStream) {
      if (this.loopDetector.addAndCheck(event)) {
        yield { type: GeminiEventType.LoopDetected };
        return turn;
      }
      yield event;
    }
    if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
      // If the model was switched during the call (likely a quota-error
      // fallback), stop here rather than recursing, to prevent an unwanted
      // continuation on the Flash model.
      const currentModel = this.config.getModel();
      if (currentModel !== initialModel) {
        return turn;
      }

      const nextSpeakerCheck = await checkNextSpeaker(
        this.getChat(),
        this,
        signal,
      );
      if (nextSpeakerCheck?.next_speaker === 'model') {
        logFlashDecidedToContinue(
          this.config,
          new FlashDecidedToContinueEvent(prompt_id),
        );
        const nextRequest = [{ text: 'Please continue.' }];
        // This recursive call's events will be yielded out, but the final
        // turn object will be from the top-level call.
        yield* this.sendMessageStream(
          nextRequest,
          signal,
          prompt_id,
          boundedTurns - 1,
          initialModel,
        );
      }
    }
    return turn;
  }
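
  // Consumption sketch (hypothetical caller): the generator yields stream
  // events, and its *return* value is the completed Turn, so iterate manually
  // if you need both.
  //
  //   const stream = client.sendMessageStream(parts, signal, promptId);
  //   let next = await stream.next();
  //   while (!next.done) {
  //     handleEvent(next.value); // a ServerGeminiStreamEvent
  //     next = await stream.next();
  //   }
  //   const turn = next.value; // the Turn returned by the generator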

  async generateJson(
    contents: Content[],
    schema: SchemaUnion,
    abortSignal: AbortSignal,
    model?: string,
    config: GenerateContentConfig = {},
  ): Promise<Record<string, unknown>> {
    // Use the current model from config, falling back to the Flash model.
    const modelToUse =
      model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
    try {
      const userMemory = this.config.getUserMemory();
      const systemInstruction = getCoreSystemPrompt(userMemory);
      const requestConfig = {
        abortSignal,
        ...this.generateContentConfig,
        ...config,
      };

      const apiCall = () =>
        this.getContentGenerator().generateContent(
          {
            model: modelToUse,
            config: {
              ...requestConfig,
              systemInstruction,
              responseSchema: schema,
              responseMimeType: 'application/json',
            },
            contents,
          },
          this.lastPromptId,
        );

      const result = await retryWithBackoff(apiCall, {
        onPersistent429: async (authType?: string, error?: unknown) =>
          await this.handleFlashFallback(authType, error),
        authType: this.config.getContentGeneratorConfig()?.authType,
      });

      const text = getResponseText(result);
      if (!text) {
        const error = new Error(
          'API returned an empty response for generateJson.',
        );
        await reportError(
          error,
          'Error in generateJson: API returned an empty response.',
          contents,
          'generateJson-empty-response',
        );
        throw error;
      }
      try {
        return JSON.parse(text);
      } catch (parseError) {
        await reportError(
          parseError,
          'Failed to parse JSON response from generateJson.',
          {
            responseTextFailedToParse: text,
            originalRequestContents: contents,
          },
          'generateJson-parse',
        );
        throw new Error(
          `Failed to parse API response as JSON: ${getErrorMessage(parseError)}`,
        );
      }
    } catch (error) {
      if (abortSignal.aborted) {
        throw error;
      }

      // Avoid double reporting for the empty-response case handled above.
      if (
        error instanceof Error &&
        error.message === 'API returned an empty response for generateJson.'
      ) {
        throw error;
      }

      await reportError(
        error,
        'Error generating JSON content via API.',
        contents,
        'generateJson-api',
      );
      throw new Error(
        `Failed to generate JSON content: ${getErrorMessage(error)}`,
      );
    }
  }
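
  // Usage sketch (hypothetical schema and prompt; Type is the schema enum
  // from @google/genai): generateJson forces responseMimeType
  // 'application/json' and parses the reply, so callers receive a plain
  // object or a descriptive error.
  //
  //   const data = await client.generateJson(
  //     [{ role: 'user', parts: [{ text: 'Reply with {"ok": true}.' }] }],
  //     { type: Type.OBJECT, properties: { ok: { type: Type.BOOLEAN } } },
  //     abortSignal,
  //   );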

  async generateContent(
    contents: Content[],
    generationConfig: GenerateContentConfig,
    abortSignal: AbortSignal,
    model?: string,
  ): Promise<GenerateContentResponse> {
    const modelToUse = model ?? this.config.getModel();
    const configToUse: GenerateContentConfig = {
      ...this.generateContentConfig,
      ...generationConfig,
    };

    try {
      const userMemory = this.config.getUserMemory();
      const systemInstruction = getCoreSystemPrompt(userMemory);

      const requestConfig = {
        abortSignal,
        ...configToUse,
        systemInstruction,
      };

      const apiCall = () =>
        this.getContentGenerator().generateContent(
          {
            model: modelToUse,
            config: requestConfig,
            contents,
          },
          this.lastPromptId,
        );

      const result = await retryWithBackoff(apiCall, {
        onPersistent429: async (authType?: string, error?: unknown) =>
          await this.handleFlashFallback(authType, error),
        authType: this.config.getContentGeneratorConfig()?.authType,
      });
      return result;
    } catch (error: unknown) {
      if (abortSignal.aborted) {
        throw error;
      }

      await reportError(
        error,
        `Error generating content via API with model ${modelToUse}.`,
        {
          requestContents: contents,
          requestConfig: configToUse,
        },
        'generateContent-api',
      );
      throw new Error(
        `Failed to generate content with model ${modelToUse}: ${getErrorMessage(error)}`,
      );
    }
  }
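
  /**
   * Embeds each input string with the configured embedding model, validating
   * that the API returns exactly one non-empty vector per input.
   */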
  async generateEmbedding(texts: string[]): Promise<number[][]> {
    if (!texts || texts.length === 0) {
      return [];
    }
    const embedModelParams: EmbedContentParameters = {
      model: this.embeddingModel,
      contents: texts,
    };

    const embedContentResponse =
      await this.getContentGenerator().embedContent(embedModelParams);
    if (
      !embedContentResponse.embeddings ||
      embedContentResponse.embeddings.length === 0
    ) {
      throw new Error('No embeddings found in API response.');
    }

    if (embedContentResponse.embeddings.length !== texts.length) {
      throw new Error(
        `API returned a mismatched number of embeddings. Expected ${texts.length}, got ${embedContentResponse.embeddings.length}.`,
      );
    }

    return embedContentResponse.embeddings.map((embedding, index) => {
      const values = embedding.values;
      if (!values || values.length === 0) {
        throw new Error(
          `API returned an empty embedding for input text at index ${index}: "${texts[index]}"`,
        );
      }
      return values;
    });
  }
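
  /**
   * Compresses the chat history when it exceeds the compression threshold
   * (or when forced): the older portion is summarized via the compression
   * prompt, and roughly the most recent 30% is preserved verbatim, starting
   * at a user-turn boundary.
   */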
  async tryCompressChat(
    prompt_id: string,
    force: boolean = false,
  ): Promise<ChatCompressionInfo | null> {
    const curatedHistory = this.getChat().getHistory(true);

    // Regardless of `force`, don't do anything if the history is empty.
    if (curatedHistory.length === 0) {
      return null;
    }

    const model = this.config.getModel();

    const { totalTokens: originalTokenCount } =
      await this.getContentGenerator().countTokens({
        model,
        contents: curatedHistory,
      });
    if (originalTokenCount === undefined) {
      console.warn(`Could not determine token count for model ${model}.`);
      return null;
    }

    // Don't compress if not forced and we are under the limit.
    if (
      !force &&
      originalTokenCount < this.COMPRESSION_TOKEN_THRESHOLD * tokenLimit(model)
    ) {
      return null;
    }

    let compressBeforeIndex = findIndexAfterFraction(
      curatedHistory,
      1 - this.COMPRESSION_PRESERVE_THRESHOLD,
    );
    // Find the first user message after the index. This is the start of the
    // next turn.
    while (
      compressBeforeIndex < curatedHistory.length &&
      (curatedHistory[compressBeforeIndex]?.role === 'model' ||
        isFunctionResponse(curatedHistory[compressBeforeIndex]))
    ) {
      compressBeforeIndex++;
    }

    const historyToCompress = curatedHistory.slice(0, compressBeforeIndex);
    const historyToKeep = curatedHistory.slice(compressBeforeIndex);

    this.getChat().setHistory(historyToCompress);

    const { text: summary } = await this.getChat().sendMessage(
      {
        message: {
          text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
        },
        config: {
          systemInstruction: { text: getCompressionPrompt() },
        },
      },
      prompt_id,
    );
    this.chat = await this.startChat([
      {
        role: 'user',
        parts: [{ text: summary }],
      },
      {
        role: 'model',
        parts: [{ text: 'Got it. Thanks for the additional context!' }],
      },
      ...historyToKeep,
    ]);

    const { totalTokens: newTokenCount } =
      await this.getContentGenerator().countTokens({
        // model might change after calling `sendMessage`, so we get the
        // newest value from config
        model: this.config.getModel(),
        contents: this.getChat().getHistory(),
      });
    if (newTokenCount === undefined) {
      console.warn('Could not determine compressed history token count.');
      return null;
    }

    return {
      originalTokenCount,
      newTokenCount,
    };
  }

  /**
   * Handles falling back to the Flash model when persistent 429 errors occur
   * for OAuth users. Uses a fallback handler if provided by the config;
   * otherwise, returns null.
   */
  private async handleFlashFallback(
    authType?: string,
    error?: unknown,
  ): Promise<string | null> {
    // Only handle fallback for OAuth users.
    if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
      return null;
    }

    const currentModel = this.config.getModel();
    const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;

    // Don't fall back if we are already using the Flash model.
    if (currentModel === fallbackModel) {
      return null;
    }

    // Check if the config has a fallback handler (set by the CLI package).
    const fallbackHandler = this.config.flashFallbackHandler;
    if (typeof fallbackHandler === 'function') {
      try {
        const accepted = await fallbackHandler(
          currentModel,
          fallbackModel,
          error,
        );
        if (accepted !== false && accepted !== null) {
          this.config.setModel(fallbackModel);
          return fallbackModel;
        }
        // The model may have been switched manually inside the handler; in
        // that case return null so the current prompt does not continue.
        if (this.config.getModel() === fallbackModel) {
          return null;
        }
      } catch (error) {
        console.warn('Flash fallback handler failed:', error);
      }
    }
    return null;
  }
}