Take the turn management out of GeminiClient (#42)
This commit is contained in:
parent
65e8e3ed1f
commit
24371a3954
|
@ -15,31 +15,17 @@ import {
|
||||||
Content,
|
Content,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { CoreSystemPrompt } from './prompts.js';
|
import { CoreSystemPrompt } from './prompts.js';
|
||||||
import {
|
|
||||||
type ToolCallEvent,
|
|
||||||
type ToolCallConfirmationDetails,
|
|
||||||
ToolCallStatus,
|
|
||||||
} from '../ui/types.js';
|
|
||||||
import process from 'node:process';
|
import process from 'node:process';
|
||||||
import { toolRegistry } from '../tools/tool-registry.js';
|
import { toolRegistry } from '../tools/tool-registry.js';
|
||||||
import { ToolResult } from '../tools/tools.js';
|
|
||||||
import { getFolderStructure } from '../utils/getFolderStructure.js';
|
import { getFolderStructure } from '../utils/getFolderStructure.js';
|
||||||
import { GeminiEventType, GeminiStream } from './gemini-stream.js';
|
import { GeminiEventType, GeminiStream } from './gemini-stream.js';
|
||||||
import { Config } from '../config/config.js';
|
import { Config } from '../config/config.js';
|
||||||
|
import { Turn } from './turn.js';
|
||||||
type ToolExecutionOutcome = {
|
|
||||||
callId: string;
|
|
||||||
name: string;
|
|
||||||
args: Record<string, never>;
|
|
||||||
result?: ToolResult;
|
|
||||||
error?: Error;
|
|
||||||
confirmationDetails?: ToolCallConfirmationDetails;
|
|
||||||
};
|
|
||||||
|
|
||||||
export class GeminiClient {
|
export class GeminiClient {
|
||||||
private config: Config;
|
private config: Config;
|
||||||
private ai: GoogleGenAI;
|
private ai: GoogleGenAI;
|
||||||
private defaultHyperParameters: GenerateContentConfig = {
|
private generateContentConfig: GenerateContentConfig = {
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
topP: 1,
|
topP: 1,
|
||||||
};
|
};
|
||||||
|
@ -50,14 +36,9 @@ export class GeminiClient {
|
||||||
this.ai = new GoogleGenAI({ apiKey: config.getApiKey() });
|
this.ai = new GoogleGenAI({ apiKey: config.getApiKey() });
|
||||||
}
|
}
|
||||||
|
|
||||||
async startChat(): Promise<Chat> {
|
private async getEnvironment(): Promise<Part> {
|
||||||
const tools = toolRegistry.getToolSchemas();
|
|
||||||
const model = this.config.getModel();
|
|
||||||
|
|
||||||
// --- Get environmental information ---
|
|
||||||
const cwd = process.cwd();
|
const cwd = process.cwd();
|
||||||
const today = new Date().toLocaleDateString(undefined, {
|
const today = new Date().toLocaleDateString(undefined, {
|
||||||
// Use locale-aware date formatting
|
|
||||||
weekday: 'long',
|
weekday: 'long',
|
||||||
year: 'numeric',
|
year: 'numeric',
|
||||||
month: 'long',
|
month: 'long',
|
||||||
|
@ -65,10 +46,9 @@ export class GeminiClient {
|
||||||
});
|
});
|
||||||
const platform = process.platform;
|
const platform = process.platform;
|
||||||
|
|
||||||
// --- Format information into a conversational multi-line string ---
|
|
||||||
const folderStructure = await getFolderStructure(cwd);
|
const folderStructure = await getFolderStructure(cwd);
|
||||||
// --- End folder structure formatting ---)
|
|
||||||
const initialContextText = `
|
const context = `
|
||||||
Okay, just setting up the context for our chat.
|
Okay, just setting up the context for our chat.
|
||||||
Today is ${today}.
|
Today is ${today}.
|
||||||
My operating system is: ${platform}
|
My operating system is: ${platform}
|
||||||
|
@ -76,22 +56,27 @@ I'm currently working in the directory: ${cwd}
|
||||||
${folderStructure}
|
${folderStructure}
|
||||||
`.trim();
|
`.trim();
|
||||||
|
|
||||||
const initialContextPart: Part = { text: initialContextText };
|
return { text: context };
|
||||||
// --- End environmental information formatting ---
|
}
|
||||||
|
|
||||||
|
async startChat(): Promise<Chat> {
|
||||||
|
const envPart = await this.getEnvironment();
|
||||||
|
const model = this.config.getModel();
|
||||||
|
const tools = toolRegistry.getToolSchemas();
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const chat = this.ai.chats.create({
|
const chat = this.ai.chats.create({
|
||||||
model,
|
model,
|
||||||
config: {
|
config: {
|
||||||
systemInstruction: CoreSystemPrompt,
|
systemInstruction: CoreSystemPrompt,
|
||||||
...this.defaultHyperParameters,
|
...this.generateContentConfig,
|
||||||
tools,
|
tools,
|
||||||
},
|
},
|
||||||
history: [
|
history: [
|
||||||
// --- Add the context as a single part in the initial user message ---
|
// --- Add the context as a single part in the initial user message ---
|
||||||
{
|
{
|
||||||
role: 'user',
|
role: 'user',
|
||||||
parts: [initialContextPart], // Pass the single Part object in an array
|
parts: [envPart], // Pass the single Part object in an array
|
||||||
},
|
},
|
||||||
// --- Add an empty model response to balance the history ---
|
// --- Add an empty model response to balance the history ---
|
||||||
{
|
{
|
||||||
|
@ -109,221 +94,32 @@ ${folderStructure}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
addMessageToHistory(chat: Chat, message: Content): void {
|
|
||||||
const history = chat.getHistory();
|
|
||||||
history.push(message);
|
|
||||||
}
|
|
||||||
|
|
||||||
async *sendMessageStream(
|
async *sendMessageStream(
|
||||||
chat: Chat,
|
chat: Chat,
|
||||||
request: PartListUnion,
|
request: PartListUnion,
|
||||||
signal?: AbortSignal,
|
signal?: AbortSignal,
|
||||||
): GeminiStream {
|
): GeminiStream {
|
||||||
let currentMessageToSend: PartListUnion = request;
|
|
||||||
let turns = 0;
|
let turns = 0;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
while (turns < this.MAX_TURNS) {
|
while (turns < this.MAX_TURNS) {
|
||||||
turns++;
|
turns++;
|
||||||
const resultStream = await chat.sendMessageStream({
|
// A turn either yields a text response or returns
|
||||||
message: currentMessageToSend,
|
// function responses to be used in the next turn.
|
||||||
});
|
// This callsite is responsible to handle the buffered
|
||||||
let functionResponseParts: Part[] = [];
|
// function responses and use it on the next turn.
|
||||||
let pendingToolCalls: Array<{
|
const turn = new Turn(chat);
|
||||||
callId: string;
|
const resultStream = turn.run(request, signal);
|
||||||
name: string;
|
|
||||||
args: Record<string, never>;
|
|
||||||
}> = [];
|
|
||||||
let yieldedTextInTurn = false;
|
|
||||||
const chunksForDebug = [];
|
|
||||||
|
|
||||||
for await (const chunk of resultStream) {
|
for await (const event of resultStream) {
|
||||||
chunksForDebug.push(chunk);
|
yield event;
|
||||||
if (signal?.aborted) {
|
}
|
||||||
const abortError = new Error(
|
const fnResponses = turn.getFunctionResponses();
|
||||||
'Request cancelled by user during stream.',
|
if (fnResponses.length > 0) {
|
||||||
);
|
request = fnResponses;
|
||||||
abortError.name = 'AbortError';
|
continue; // use the responses in the next turn
|
||||||
throw abortError;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const functionCalls = chunk.functionCalls;
|
|
||||||
if (functionCalls && functionCalls.length > 0) {
|
|
||||||
for (const call of functionCalls) {
|
|
||||||
const callId =
|
|
||||||
call.id ??
|
|
||||||
`${call.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`;
|
|
||||||
const name = call.name || 'undefined_tool_name';
|
|
||||||
const args = (call.args || {}) as Record<string, never>;
|
|
||||||
|
|
||||||
pendingToolCalls.push({ callId, name, args });
|
|
||||||
const evtValue: ToolCallEvent = {
|
|
||||||
type: 'tool_call',
|
|
||||||
status: ToolCallStatus.Pending,
|
|
||||||
callId,
|
|
||||||
name,
|
|
||||||
args,
|
|
||||||
resultDisplay: undefined,
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
};
|
|
||||||
yield {
|
|
||||||
type: GeminiEventType.ToolCallInfo,
|
|
||||||
value: evtValue,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const text = chunk.text;
|
|
||||||
if (text) {
|
|
||||||
yieldedTextInTurn = true;
|
|
||||||
yield {
|
|
||||||
type: GeminiEventType.Content,
|
|
||||||
value: text,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (pendingToolCalls.length > 0) {
|
|
||||||
const toolPromises: Array<Promise<ToolExecutionOutcome>> =
|
|
||||||
pendingToolCalls.map(async (pendingToolCall) => {
|
|
||||||
const tool = toolRegistry.getTool(pendingToolCall.name);
|
|
||||||
|
|
||||||
if (!tool) {
|
|
||||||
// Directly return error outcome if tool not found
|
|
||||||
return {
|
|
||||||
...pendingToolCall,
|
|
||||||
error: new Error(
|
|
||||||
`Tool "${pendingToolCall.name}" not found or is not registered.`,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const confirmation = await tool.shouldConfirmExecute(
|
|
||||||
pendingToolCall.args,
|
|
||||||
);
|
|
||||||
if (confirmation) {
|
|
||||||
return {
|
|
||||||
...pendingToolCall,
|
|
||||||
confirmationDetails: confirmation,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
return {
|
|
||||||
...pendingToolCall,
|
|
||||||
error: new Error(
|
|
||||||
`Tool failed to check tool confirmation: ${error}`,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const result = await tool.execute(pendingToolCall.args);
|
|
||||||
return { ...pendingToolCall, result };
|
|
||||||
} catch (error) {
|
|
||||||
return {
|
|
||||||
...pendingToolCall,
|
|
||||||
error: new Error(`Tool failed to execute: ${error}`),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
});
|
|
||||||
const toolExecutionOutcomes: ToolExecutionOutcome[] =
|
|
||||||
await Promise.all(toolPromises);
|
|
||||||
|
|
||||||
for (const executedTool of toolExecutionOutcomes) {
|
|
||||||
const { callId, name, args, result, error, confirmationDetails } =
|
|
||||||
executedTool;
|
|
||||||
|
|
||||||
if (error) {
|
|
||||||
const errorMessage = error?.message || String(error);
|
|
||||||
yield {
|
|
||||||
type: GeminiEventType.Content,
|
|
||||||
value: `[Error invoking tool ${name}: ${errorMessage}]`,
|
|
||||||
};
|
|
||||||
} else if (
|
|
||||||
result &&
|
|
||||||
typeof result === 'object' &&
|
|
||||||
result !== null &&
|
|
||||||
'error' in result
|
|
||||||
) {
|
|
||||||
const errorMessage = String(result.error);
|
|
||||||
yield {
|
|
||||||
type: GeminiEventType.Content,
|
|
||||||
value: `[Error executing tool ${name}: ${errorMessage}]`,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
const status = confirmationDetails
|
|
||||||
? ToolCallStatus.Confirming
|
|
||||||
: ToolCallStatus.Invoked;
|
|
||||||
const evtValue: ToolCallEvent = {
|
|
||||||
type: 'tool_call',
|
|
||||||
status,
|
|
||||||
callId,
|
|
||||||
name,
|
|
||||||
args,
|
|
||||||
resultDisplay: result?.returnDisplay,
|
|
||||||
confirmationDetails,
|
|
||||||
};
|
|
||||||
yield {
|
|
||||||
type: GeminiEventType.ToolCallInfo,
|
|
||||||
value: evtValue,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pendingToolCalls = [];
|
|
||||||
|
|
||||||
const waitingOnConfirmations =
|
|
||||||
toolExecutionOutcomes.filter(
|
|
||||||
(outcome) => outcome.confirmationDetails,
|
|
||||||
).length > 0;
|
|
||||||
if (waitingOnConfirmations) {
|
|
||||||
// Stop processing content, wait for user.
|
|
||||||
// TODO: Kill token processing once API supports signals.
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
functionResponseParts = toolExecutionOutcomes.map(
|
|
||||||
(executedTool: ToolExecutionOutcome): Part => {
|
|
||||||
const { name, result, error } = executedTool;
|
|
||||||
const output = { output: result?.llmContent };
|
|
||||||
let toolOutcomePayload: Record<string, unknown>;
|
|
||||||
|
|
||||||
if (error) {
|
|
||||||
const errorMessage = error?.message || String(error);
|
|
||||||
toolOutcomePayload = {
|
|
||||||
error: `Invocation failed: ${errorMessage}`,
|
|
||||||
};
|
|
||||||
console.error(
|
|
||||||
`[Turn ${turns}] Critical error invoking tool ${name}:`,
|
|
||||||
error,
|
|
||||||
);
|
|
||||||
} else if (
|
|
||||||
result &&
|
|
||||||
typeof result === 'object' &&
|
|
||||||
result !== null &&
|
|
||||||
'error' in result
|
|
||||||
) {
|
|
||||||
toolOutcomePayload = output;
|
|
||||||
console.warn(
|
|
||||||
`[Turn ${turns}] Tool ${name} returned an error structure:`,
|
|
||||||
result.error,
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
toolOutcomePayload = output;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
functionResponse: {
|
|
||||||
name,
|
|
||||||
id: executedTool.callId,
|
|
||||||
response: toolOutcomePayload,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
},
|
|
||||||
);
|
|
||||||
currentMessageToSend = functionResponseParts;
|
|
||||||
} else if (yieldedTextInTurn) {
|
|
||||||
const history = chat.getHistory();
|
const history = chat.getHistory();
|
||||||
const checkPrompt = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
|
const checkPrompt = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
|
||||||
|
|
||||||
|
@ -394,7 +190,7 @@ Respond *only* in JSON format according to the following schema. Do not include
|
||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
if (nextSpeaker === 'model') {
|
if (nextSpeaker === 'model') {
|
||||||
currentMessageToSend = { text: 'alright' }; // Or potentially a more meaningful continuation prompt
|
request = { text: 'alright' }; // Or potentially a more meaningful continuation prompt
|
||||||
} else {
|
} else {
|
||||||
// 'user' should speak next, or value is missing/invalid. End the turn.
|
// 'user' should speak next, or value is missing/invalid. End the turn.
|
||||||
break;
|
break;
|
||||||
|
@ -407,12 +203,6 @@ Respond *only* in JSON format according to the following schema. Do not include
|
||||||
// If the check fails, assume user should speak next to avoid infinite loops
|
// If the check fails, assume user should speak next to avoid infinite loops
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
console.warn(
|
|
||||||
`[Turn ${turns}] No text or function calls received from Gemini. Ending interaction.`,
|
|
||||||
);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (turns >= this.MAX_TURNS) {
|
if (turns >= this.MAX_TURNS) {
|
||||||
|
@ -426,6 +216,8 @@ Respond *only* in JSON format according to the following schema. Do not include
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
} catch (error: unknown) {
|
} catch (error: unknown) {
|
||||||
|
// TODO(jbd): There is so much of packing/unpacking of error types.
|
||||||
|
// Figure out a way to remove the redundant work.
|
||||||
if (error instanceof Error && error.name === 'AbortError') {
|
if (error instanceof Error && error.name === 'AbortError') {
|
||||||
console.log('Gemini stream request aborted by user.');
|
console.log('Gemini stream request aborted by user.');
|
||||||
throw error;
|
throw error;
|
||||||
|
@ -457,7 +249,7 @@ Respond *only* in JSON format according to the following schema. Do not include
|
||||||
const result = await this.ai.models.generateContent({
|
const result = await this.ai.models.generateContent({
|
||||||
model,
|
model,
|
||||||
config: {
|
config: {
|
||||||
...this.defaultHyperParameters,
|
...this.generateContentConfig,
|
||||||
systemInstruction: CoreSystemPrompt,
|
systemInstruction: CoreSystemPrompt,
|
||||||
responseSchema: schema,
|
responseSchema: schema,
|
||||||
responseMimeType: 'application/json',
|
responseMimeType: 'application/json',
|
||||||
|
|
|
@ -0,0 +1,233 @@
|
||||||
|
import {
|
||||||
|
Part,
|
||||||
|
Chat,
|
||||||
|
PartListUnion,
|
||||||
|
GenerateContentResponse,
|
||||||
|
FunctionCall,
|
||||||
|
} from '@google/genai';
|
||||||
|
import {
|
||||||
|
type ToolCallConfirmationDetails,
|
||||||
|
ToolCallStatus,
|
||||||
|
ToolCallEvent,
|
||||||
|
} from '../ui/types.js';
|
||||||
|
import { ToolResult } from '../tools/tools.js';
|
||||||
|
import { toolRegistry } from '../tools/tool-registry.js';
|
||||||
|
import { GeminiEventType, GeminiStream } from './gemini-stream.js';
|
||||||
|
|
||||||
|
export type ToolExecutionOutcome = {
|
||||||
|
callId: string;
|
||||||
|
name: string;
|
||||||
|
args: Record<string, never>;
|
||||||
|
result?: ToolResult;
|
||||||
|
error?: Error;
|
||||||
|
confirmationDetails?: ToolCallConfirmationDetails;
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO(jbd): Move ToolExecutionOutcome to somewhere else?
|
||||||
|
|
||||||
|
// A turn manages the agentic loop turn.
|
||||||
|
// Turn.run emits throught the turn events that could be used
|
||||||
|
// as immediate feedback to the user.
|
||||||
|
export class Turn {
|
||||||
|
private readonly chat: Chat;
|
||||||
|
private pendingToolCalls: Array<{
|
||||||
|
callId: string;
|
||||||
|
name: string;
|
||||||
|
args: Record<string, never>;
|
||||||
|
}>;
|
||||||
|
private fnResponses: Part[];
|
||||||
|
private debugResponses: GenerateContentResponse[];
|
||||||
|
|
||||||
|
constructor(chat: Chat) {
|
||||||
|
this.chat = chat;
|
||||||
|
this.pendingToolCalls = [];
|
||||||
|
this.fnResponses = [];
|
||||||
|
this.debugResponses = [];
|
||||||
|
}
|
||||||
|
|
||||||
|
async *run(req: PartListUnion, signal?: AbortSignal): GeminiStream {
|
||||||
|
const responseStream = await this.chat.sendMessageStream({
|
||||||
|
message: req,
|
||||||
|
});
|
||||||
|
for await (const resp of responseStream) {
|
||||||
|
this.debugResponses.push(resp);
|
||||||
|
if (signal?.aborted) {
|
||||||
|
throw this.abortError();
|
||||||
|
}
|
||||||
|
if (resp.text) {
|
||||||
|
yield {
|
||||||
|
type: GeminiEventType.Content,
|
||||||
|
value: resp.text,
|
||||||
|
};
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!resp.functionCalls) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (const fnCall of resp.functionCalls) {
|
||||||
|
for await (const event of this.handlePendingFunctionCall(fnCall)) {
|
||||||
|
yield event;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create promises to be able to wait for executions to complete.
|
||||||
|
const toolPromises = this.pendingToolCalls.map(
|
||||||
|
async (pendingToolCall) => {
|
||||||
|
const tool = toolRegistry.getTool(pendingToolCall.name);
|
||||||
|
if (!tool) {
|
||||||
|
return {
|
||||||
|
...pendingToolCall,
|
||||||
|
error: new Error(
|
||||||
|
`Tool "${pendingToolCall.name}" not found or is not registered.`,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
const shouldConfirm = await tool.shouldConfirmExecute(
|
||||||
|
pendingToolCall.args,
|
||||||
|
);
|
||||||
|
if (shouldConfirm) {
|
||||||
|
return {
|
||||||
|
// TODO(jbd): Should confirm isn't confirmation details.
|
||||||
|
...pendingToolCall,
|
||||||
|
confirmationDetails: shouldConfirm,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
const result = await tool.execute(pendingToolCall.args);
|
||||||
|
return { ...pendingToolCall, result };
|
||||||
|
},
|
||||||
|
);
|
||||||
|
const outcomes = await Promise.all(toolPromises);
|
||||||
|
for await (const event of this.handleToolOutcomes(outcomes)) {
|
||||||
|
yield event;
|
||||||
|
}
|
||||||
|
this.pendingToolCalls = [];
|
||||||
|
|
||||||
|
// TODO(jbd): Make it harder for the caller to ignore the
|
||||||
|
// buffered function responses.
|
||||||
|
this.fnResponses = this.buildFunctionResponses(outcomes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async *handlePendingFunctionCall(fnCall: FunctionCall): GeminiStream {
|
||||||
|
const callId =
|
||||||
|
fnCall.id ??
|
||||||
|
`${fnCall.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`;
|
||||||
|
// TODO(jbd): replace with uuid.
|
||||||
|
const name = fnCall.name || 'undefined_tool_name';
|
||||||
|
const args = (fnCall.args || {}) as Record<string, never>;
|
||||||
|
|
||||||
|
this.pendingToolCalls.push({ callId, name, args });
|
||||||
|
const value: ToolCallEvent = {
|
||||||
|
type: 'tool_call',
|
||||||
|
status: ToolCallStatus.Pending,
|
||||||
|
callId,
|
||||||
|
name,
|
||||||
|
args,
|
||||||
|
resultDisplay: undefined,
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
yield {
|
||||||
|
type: GeminiEventType.ToolCallInfo,
|
||||||
|
value,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private async *handleToolOutcomes(
|
||||||
|
outcomes: ToolExecutionOutcome[],
|
||||||
|
): GeminiStream {
|
||||||
|
for (const outcome of outcomes) {
|
||||||
|
const { callId, name, args, result, error, confirmationDetails } =
|
||||||
|
outcome;
|
||||||
|
if (error) {
|
||||||
|
// TODO(jbd): Error handling needs a cleanup.
|
||||||
|
const errorMessage = error?.message || String(error);
|
||||||
|
yield {
|
||||||
|
type: GeminiEventType.Content,
|
||||||
|
value: `[Error invoking tool ${name}: ${errorMessage}]`,
|
||||||
|
};
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (
|
||||||
|
result &&
|
||||||
|
typeof result === 'object' &&
|
||||||
|
result !== null &&
|
||||||
|
'error' in result
|
||||||
|
) {
|
||||||
|
const errorMessage = String(result.error);
|
||||||
|
yield {
|
||||||
|
type: GeminiEventType.Content,
|
||||||
|
value: `[Error executing tool ${name}: ${errorMessage}]`,
|
||||||
|
};
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const status = confirmationDetails
|
||||||
|
? ToolCallStatus.Confirming
|
||||||
|
: ToolCallStatus.Invoked;
|
||||||
|
const value: ToolCallEvent = {
|
||||||
|
type: 'tool_call',
|
||||||
|
status,
|
||||||
|
callId,
|
||||||
|
name,
|
||||||
|
args,
|
||||||
|
resultDisplay: result?.returnDisplay,
|
||||||
|
confirmationDetails,
|
||||||
|
};
|
||||||
|
yield {
|
||||||
|
type: GeminiEventType.ToolCallInfo,
|
||||||
|
value,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private buildFunctionResponses(outcomes: ToolExecutionOutcome[]): Part[] {
|
||||||
|
return outcomes.map((outcome: ToolExecutionOutcome): Part => {
|
||||||
|
const { name, result, error } = outcome;
|
||||||
|
const output = { output: result?.llmContent };
|
||||||
|
let fnResponse: Record<string, unknown>;
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
const errorMessage = error?.message || String(error);
|
||||||
|
fnResponse = {
|
||||||
|
error: `Invocation failed: ${errorMessage}`,
|
||||||
|
};
|
||||||
|
console.error(`[Turn] Critical error invoking tool ${name}:`, error);
|
||||||
|
} else if (
|
||||||
|
result &&
|
||||||
|
typeof result === 'object' &&
|
||||||
|
result !== null &&
|
||||||
|
'error' in result
|
||||||
|
) {
|
||||||
|
fnResponse = output;
|
||||||
|
console.warn(
|
||||||
|
`[Turn] Tool ${name} returned an error structure:`,
|
||||||
|
result.error,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
fnResponse = output;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
functionResponse: {
|
||||||
|
name,
|
||||||
|
id: outcome.callId,
|
||||||
|
response: fnResponse,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private abortError(): Error {
|
||||||
|
// TODO(jbd): Move it out of this class.
|
||||||
|
const error = new Error('Request cancelled by user during stream.');
|
||||||
|
error.name = 'AbortError';
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
getFunctionResponses(): Part[] {
|
||||||
|
return this.fnResponses;
|
||||||
|
}
|
||||||
|
|
||||||
|
getDebugResponses(): GenerateContentResponse[] {
|
||||||
|
return this.debugResponses;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue