Take the turn management out of GeminiClient (#42)

This commit is contained in:
Jaana Dogan 2025-04-18 23:11:33 -07:00 committed by GitHub
parent 65e8e3ed1f
commit 24371a3954
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 335 additions and 310 deletions

View File

@ -15,31 +15,17 @@ import {
Content, Content,
} from '@google/genai'; } from '@google/genai';
import { CoreSystemPrompt } from './prompts.js'; import { CoreSystemPrompt } from './prompts.js';
import {
type ToolCallEvent,
type ToolCallConfirmationDetails,
ToolCallStatus,
} from '../ui/types.js';
import process from 'node:process'; import process from 'node:process';
import { toolRegistry } from '../tools/tool-registry.js'; import { toolRegistry } from '../tools/tool-registry.js';
import { ToolResult } from '../tools/tools.js';
import { getFolderStructure } from '../utils/getFolderStructure.js'; import { getFolderStructure } from '../utils/getFolderStructure.js';
import { GeminiEventType, GeminiStream } from './gemini-stream.js'; import { GeminiEventType, GeminiStream } from './gemini-stream.js';
import { Config } from '../config/config.js'; import { Config } from '../config/config.js';
import { Turn } from './turn.js';
type ToolExecutionOutcome = {
callId: string;
name: string;
args: Record<string, never>;
result?: ToolResult;
error?: Error;
confirmationDetails?: ToolCallConfirmationDetails;
};
export class GeminiClient { export class GeminiClient {
private config: Config; private config: Config;
private ai: GoogleGenAI; private ai: GoogleGenAI;
private defaultHyperParameters: GenerateContentConfig = { private generateContentConfig: GenerateContentConfig = {
temperature: 0, temperature: 0,
topP: 1, topP: 1,
}; };
@ -50,14 +36,9 @@ export class GeminiClient {
this.ai = new GoogleGenAI({ apiKey: config.getApiKey() }); this.ai = new GoogleGenAI({ apiKey: config.getApiKey() });
} }
async startChat(): Promise<Chat> { private async getEnvironment(): Promise<Part> {
const tools = toolRegistry.getToolSchemas();
const model = this.config.getModel();
// --- Get environmental information ---
const cwd = process.cwd(); const cwd = process.cwd();
const today = new Date().toLocaleDateString(undefined, { const today = new Date().toLocaleDateString(undefined, {
// Use locale-aware date formatting
weekday: 'long', weekday: 'long',
year: 'numeric', year: 'numeric',
month: 'long', month: 'long',
@ -65,10 +46,9 @@ export class GeminiClient {
}); });
const platform = process.platform; const platform = process.platform;
// --- Format information into a conversational multi-line string ---
const folderStructure = await getFolderStructure(cwd); const folderStructure = await getFolderStructure(cwd);
// --- End folder structure formatting ---)
const initialContextText = ` const context = `
Okay, just setting up the context for our chat. Okay, just setting up the context for our chat.
Today is ${today}. Today is ${today}.
My operating system is: ${platform} My operating system is: ${platform}
@ -76,22 +56,27 @@ I'm currently working in the directory: ${cwd}
${folderStructure} ${folderStructure}
`.trim(); `.trim();
const initialContextPart: Part = { text: initialContextText }; return { text: context };
// --- End environmental information formatting --- }
async startChat(): Promise<Chat> {
const envPart = await this.getEnvironment();
const model = this.config.getModel();
const tools = toolRegistry.getToolSchemas();
try { try {
const chat = this.ai.chats.create({ const chat = this.ai.chats.create({
model, model,
config: { config: {
systemInstruction: CoreSystemPrompt, systemInstruction: CoreSystemPrompt,
...this.defaultHyperParameters, ...this.generateContentConfig,
tools, tools,
}, },
history: [ history: [
// --- Add the context as a single part in the initial user message --- // --- Add the context as a single part in the initial user message ---
{ {
role: 'user', role: 'user',
parts: [initialContextPart], // Pass the single Part object in an array parts: [envPart], // Pass the single Part object in an array
}, },
// --- Add an empty model response to balance the history --- // --- Add an empty model response to balance the history ---
{ {
@ -109,221 +94,32 @@ ${folderStructure}
} }
} }
addMessageToHistory(chat: Chat, message: Content): void {
const history = chat.getHistory();
history.push(message);
}
async *sendMessageStream( async *sendMessageStream(
chat: Chat, chat: Chat,
request: PartListUnion, request: PartListUnion,
signal?: AbortSignal, signal?: AbortSignal,
): GeminiStream { ): GeminiStream {
let currentMessageToSend: PartListUnion = request;
let turns = 0; let turns = 0;
try { try {
while (turns < this.MAX_TURNS) { while (turns < this.MAX_TURNS) {
turns++; turns++;
const resultStream = await chat.sendMessageStream({ // A turn either yields a text response or returns
message: currentMessageToSend, // function responses to be used in the next turn.
}); // This callsite is responsible to handle the buffered
let functionResponseParts: Part[] = []; // function responses and use it on the next turn.
let pendingToolCalls: Array<{ const turn = new Turn(chat);
callId: string; const resultStream = turn.run(request, signal);
name: string;
args: Record<string, never>;
}> = [];
let yieldedTextInTurn = false;
const chunksForDebug = [];
for await (const chunk of resultStream) { for await (const event of resultStream) {
chunksForDebug.push(chunk); yield event;
if (signal?.aborted) { }
const abortError = new Error( const fnResponses = turn.getFunctionResponses();
'Request cancelled by user during stream.', if (fnResponses.length > 0) {
); request = fnResponses;
abortError.name = 'AbortError'; continue; // use the responses in the next turn
throw abortError;
} }
const functionCalls = chunk.functionCalls;
if (functionCalls && functionCalls.length > 0) {
for (const call of functionCalls) {
const callId =
call.id ??
`${call.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`;
const name = call.name || 'undefined_tool_name';
const args = (call.args || {}) as Record<string, never>;
pendingToolCalls.push({ callId, name, args });
const evtValue: ToolCallEvent = {
type: 'tool_call',
status: ToolCallStatus.Pending,
callId,
name,
args,
resultDisplay: undefined,
confirmationDetails: undefined,
};
yield {
type: GeminiEventType.ToolCallInfo,
value: evtValue,
};
}
} else {
const text = chunk.text;
if (text) {
yieldedTextInTurn = true;
yield {
type: GeminiEventType.Content,
value: text,
};
}
}
}
if (pendingToolCalls.length > 0) {
const toolPromises: Array<Promise<ToolExecutionOutcome>> =
pendingToolCalls.map(async (pendingToolCall) => {
const tool = toolRegistry.getTool(pendingToolCall.name);
if (!tool) {
// Directly return error outcome if tool not found
return {
...pendingToolCall,
error: new Error(
`Tool "${pendingToolCall.name}" not found or is not registered.`,
),
};
}
try {
const confirmation = await tool.shouldConfirmExecute(
pendingToolCall.args,
);
if (confirmation) {
return {
...pendingToolCall,
confirmationDetails: confirmation,
};
}
} catch (error) {
return {
...pendingToolCall,
error: new Error(
`Tool failed to check tool confirmation: ${error}`,
),
};
}
try {
const result = await tool.execute(pendingToolCall.args);
return { ...pendingToolCall, result };
} catch (error) {
return {
...pendingToolCall,
error: new Error(`Tool failed to execute: ${error}`),
};
}
});
const toolExecutionOutcomes: ToolExecutionOutcome[] =
await Promise.all(toolPromises);
for (const executedTool of toolExecutionOutcomes) {
const { callId, name, args, result, error, confirmationDetails } =
executedTool;
if (error) {
const errorMessage = error?.message || String(error);
yield {
type: GeminiEventType.Content,
value: `[Error invoking tool ${name}: ${errorMessage}]`,
};
} else if (
result &&
typeof result === 'object' &&
result !== null &&
'error' in result
) {
const errorMessage = String(result.error);
yield {
type: GeminiEventType.Content,
value: `[Error executing tool ${name}: ${errorMessage}]`,
};
} else {
const status = confirmationDetails
? ToolCallStatus.Confirming
: ToolCallStatus.Invoked;
const evtValue: ToolCallEvent = {
type: 'tool_call',
status,
callId,
name,
args,
resultDisplay: result?.returnDisplay,
confirmationDetails,
};
yield {
type: GeminiEventType.ToolCallInfo,
value: evtValue,
};
}
}
pendingToolCalls = [];
const waitingOnConfirmations =
toolExecutionOutcomes.filter(
(outcome) => outcome.confirmationDetails,
).length > 0;
if (waitingOnConfirmations) {
// Stop processing content, wait for user.
// TODO: Kill token processing once API supports signals.
break;
}
functionResponseParts = toolExecutionOutcomes.map(
(executedTool: ToolExecutionOutcome): Part => {
const { name, result, error } = executedTool;
const output = { output: result?.llmContent };
let toolOutcomePayload: Record<string, unknown>;
if (error) {
const errorMessage = error?.message || String(error);
toolOutcomePayload = {
error: `Invocation failed: ${errorMessage}`,
};
console.error(
`[Turn ${turns}] Critical error invoking tool ${name}:`,
error,
);
} else if (
result &&
typeof result === 'object' &&
result !== null &&
'error' in result
) {
toolOutcomePayload = output;
console.warn(
`[Turn ${turns}] Tool ${name} returned an error structure:`,
result.error,
);
} else {
toolOutcomePayload = output;
}
return {
functionResponse: {
name,
id: executedTool.callId,
response: toolOutcomePayload,
},
};
},
);
currentMessageToSend = functionResponseParts;
} else if (yieldedTextInTurn) {
const history = chat.getHistory(); const history = chat.getHistory();
const checkPrompt = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you). const checkPrompt = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
@ -394,7 +190,7 @@ Respond *only* in JSON format according to the following schema. Do not include
: undefined; : undefined;
if (nextSpeaker === 'model') { if (nextSpeaker === 'model') {
currentMessageToSend = { text: 'alright' }; // Or potentially a more meaningful continuation prompt request = { text: 'alright' }; // Or potentially a more meaningful continuation prompt
} else { } else {
// 'user' should speak next, or value is missing/invalid. End the turn. // 'user' should speak next, or value is missing/invalid. End the turn.
break; break;
@ -407,12 +203,6 @@ Respond *only* in JSON format according to the following schema. Do not include
// If the check fails, assume user should speak next to avoid infinite loops // If the check fails, assume user should speak next to avoid infinite loops
break; break;
} }
} else {
console.warn(
`[Turn ${turns}] No text or function calls received from Gemini. Ending interaction.`,
);
break;
}
} }
if (turns >= this.MAX_TURNS) { if (turns >= this.MAX_TURNS) {
@ -426,6 +216,8 @@ Respond *only* in JSON format according to the following schema. Do not include
}; };
} }
} catch (error: unknown) { } catch (error: unknown) {
// TODO(jbd): There is so much of packing/unpacking of error types.
// Figure out a way to remove the redundant work.
if (error instanceof Error && error.name === 'AbortError') { if (error instanceof Error && error.name === 'AbortError') {
console.log('Gemini stream request aborted by user.'); console.log('Gemini stream request aborted by user.');
throw error; throw error;
@ -457,7 +249,7 @@ Respond *only* in JSON format according to the following schema. Do not include
const result = await this.ai.models.generateContent({ const result = await this.ai.models.generateContent({
model, model,
config: { config: {
...this.defaultHyperParameters, ...this.generateContentConfig,
systemInstruction: CoreSystemPrompt, systemInstruction: CoreSystemPrompt,
responseSchema: schema, responseSchema: schema,
responseMimeType: 'application/json', responseMimeType: 'application/json',

View File

@ -0,0 +1,233 @@
import {
Part,
Chat,
PartListUnion,
GenerateContentResponse,
FunctionCall,
} from '@google/genai';
import {
type ToolCallConfirmationDetails,
ToolCallStatus,
ToolCallEvent,
} from '../ui/types.js';
import { ToolResult } from '../tools/tools.js';
import { toolRegistry } from '../tools/tool-registry.js';
import { GeminiEventType, GeminiStream } from './gemini-stream.js';
/**
 * Outcome of a single tool call requested by the model.
 *
 * Carries the call identity (`callId`, `name`, `args`) plus at most one of:
 * - `result` — the tool executed successfully;
 * - `error` — tool lookup, confirmation check, or execution failed;
 * - `confirmationDetails` — execution is paused awaiting user confirmation.
 */
export type ToolExecutionOutcome = {
  callId: string;
  name: string;
  // NOTE(review): `Record<string, never>` forbids all properties; presumably
  // `Record<string, unknown>` was intended for tool arguments — confirm.
  args: Record<string, never>;
  result?: ToolResult;
  error?: Error;
  confirmationDetails?: ToolCallConfirmationDetails;
};
// TODO(jbd): Move ToolExecutionOutcome to somewhere else?
// A Turn manages one turn of the agentic loop.
// Turn.run emits events through the turn that can be used
// as immediate feedback to the user.
export class Turn {
  private readonly chat: Chat;
  // Tool calls received from the model that have not been executed yet.
  private pendingToolCalls: Array<{
    callId: string;
    name: string;
    args: Record<string, never>;
  }>;
  // Buffered function responses for the caller to send on the next turn.
  private fnResponses: Part[];
  // Raw model responses, retained for debugging only.
  private debugResponses: GenerateContentResponse[];

  constructor(chat: Chat) {
    this.chat = chat;
    this.pendingToolCalls = [];
    this.fnResponses = [];
    this.debugResponses = [];
  }

  /**
   * Runs a single turn: streams the model response, yielding content and
   * tool-call events, executes requested tools, and buffers their function
   * responses (see getFunctionResponses) for the caller's next turn.
   *
   * @param req the message (or function responses) to send to the model.
   * @param signal optional abort signal checked between stream chunks.
   * @throws an `AbortError`-named Error if `signal` is aborted mid-stream.
   */
  async *run(req: PartListUnion, signal?: AbortSignal): GeminiStream {
    const responseStream = await this.chat.sendMessageStream({
      message: req,
    });
    // NOTE(review): tools execute per response chunk; the pre-refactor code
    // executed them only after the whole stream finished — confirm intended.
    for await (const resp of responseStream) {
      this.debugResponses.push(resp);
      if (signal?.aborted) {
        throw this.abortError();
      }
      if (resp.text) {
        yield {
          type: GeminiEventType.Content,
          value: resp.text,
        };
        continue;
      }
      if (!resp.functionCalls) {
        continue;
      }
      for (const fnCall of resp.functionCalls) {
        yield* this.handlePendingFunctionCall(fnCall);
      }
      // Create promises to be able to wait for executions to complete.
      const toolPromises = this.pendingToolCalls.map(
        async (pendingToolCall): Promise<ToolExecutionOutcome> => {
          const tool = toolRegistry.getTool(pendingToolCall.name);
          if (!tool) {
            return {
              ...pendingToolCall,
              error: new Error(
                `Tool "${pendingToolCall.name}" not found or is not registered.`,
              ),
            };
          }
          // A throwing confirmation check must not reject Promise.all and
          // abort the whole turn; convert it into an error outcome instead.
          try {
            const shouldConfirm = await tool.shouldConfirmExecute(
              pendingToolCall.args,
            );
            if (shouldConfirm) {
              return {
                // TODO(jbd): Should confirm isn't confirmation details.
                ...pendingToolCall,
                confirmationDetails: shouldConfirm,
              };
            }
          } catch (error) {
            return {
              ...pendingToolCall,
              error: new Error(
                `Tool failed to check tool confirmation: ${error}`,
              ),
            };
          }
          // Likewise, a throwing tool becomes an error outcome so its
          // siblings still complete and get reported.
          try {
            const result = await tool.execute(pendingToolCall.args);
            return { ...pendingToolCall, result };
          } catch (error) {
            return {
              ...pendingToolCall,
              error: new Error(`Tool failed to execute: ${error}`),
            };
          }
        },
      );
      const outcomes = await Promise.all(toolPromises);
      yield* this.handleToolOutcomes(outcomes);
      this.pendingToolCalls = [];
      // TODO(jbd): Make it harder for the caller to ignore the
      // buffered function responses.
      // Accumulate rather than overwrite, so function calls that arrived in
      // an earlier chunk are not lost when a later chunk also carries calls.
      this.fnResponses.push(...this.buildFunctionResponses(outcomes));
    }
  }

  // Records a pending tool call and emits a Pending tool_call event.
  private async *handlePendingFunctionCall(fnCall: FunctionCall): GeminiStream {
    const callId =
      fnCall.id ??
      `${fnCall.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`;
    // TODO(jbd): replace with uuid.
    const name = fnCall.name || 'undefined_tool_name';
    const args = (fnCall.args || {}) as Record<string, never>;
    this.pendingToolCalls.push({ callId, name, args });
    const value: ToolCallEvent = {
      type: 'tool_call',
      status: ToolCallStatus.Pending,
      callId,
      name,
      args,
      resultDisplay: undefined,
      confirmationDetails: undefined,
    };
    yield {
      type: GeminiEventType.ToolCallInfo,
      value,
    };
  }

  // Emits one event per outcome: an error message for failed calls, or a
  // tool_call event (Confirming/Invoked) for the rest.
  private async *handleToolOutcomes(
    outcomes: ToolExecutionOutcome[],
  ): GeminiStream {
    for (const outcome of outcomes) {
      const { callId, name, args, result, error, confirmationDetails } =
        outcome;
      if (error) {
        // TODO(jbd): Error handling needs a cleanup.
        const errorMessage = error?.message || String(error);
        yield {
          type: GeminiEventType.Content,
          value: `[Error invoking tool ${name}: ${errorMessage}]`,
        };
        // `continue`, not `return`: one failed tool must not swallow the
        // events of the remaining outcomes.
        continue;
      }
      if (
        result &&
        typeof result === 'object' &&
        result !== null &&
        'error' in result
      ) {
        const errorMessage = String(result.error);
        yield {
          type: GeminiEventType.Content,
          value: `[Error executing tool ${name}: ${errorMessage}]`,
        };
        continue;
      }
      const status = confirmationDetails
        ? ToolCallStatus.Confirming
        : ToolCallStatus.Invoked;
      const value: ToolCallEvent = {
        type: 'tool_call',
        status,
        callId,
        name,
        args,
        resultDisplay: result?.returnDisplay,
        confirmationDetails,
      };
      yield {
        type: GeminiEventType.ToolCallInfo,
        value,
      };
    }
  }

  // Converts outcomes into functionResponse Parts for the next model turn.
  private buildFunctionResponses(outcomes: ToolExecutionOutcome[]): Part[] {
    return outcomes.map((outcome: ToolExecutionOutcome): Part => {
      const { name, result, error } = outcome;
      const output = { output: result?.llmContent };
      let fnResponse: Record<string, unknown>;
      if (error) {
        const errorMessage = error?.message || String(error);
        fnResponse = {
          error: `Invocation failed: ${errorMessage}`,
        };
        console.error(`[Turn] Critical error invoking tool ${name}:`, error);
      } else if (
        result &&
        typeof result === 'object' &&
        result !== null &&
        'error' in result
      ) {
        fnResponse = output;
        console.warn(
          `[Turn] Tool ${name} returned an error structure:`,
          result.error,
        );
      } else {
        fnResponse = output;
      }
      return {
        functionResponse: {
          name,
          id: outcome.callId,
          response: fnResponse,
        },
      };
    });
  }

  // Builds (and returns — callers `throw` the result) the cancellation error.
  private abortError(): Error {
    // TODO(jbd): Move it out of this class.
    const error = new Error('Request cancelled by user during stream.');
    error.name = 'AbortError';
    return error;
  }

  // Function responses buffered by run(); send these on the next turn.
  getFunctionResponses(): Part[] {
    return this.fnResponses;
  }

  // Raw streamed responses, for debugging.
  getDebugResponses(): GenerateContentResponse[] {
    return this.debugResponses;
  }
}