// gemini-cli/packages/server/src/core/turn.ts
//
// 268 lines
// 8.1 KiB
// TypeScript

/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
Part,
Chat,
PartListUnion,
GenerateContentResponse,
FunctionCall,
FunctionDeclaration,
} from '@google/genai';
// Removed UI type imports
import {
ToolCallConfirmationDetails,
ToolResult,
ToolResultDisplay,
} from '../tools/tools.js'; // Keep ToolResult for now
// Removed gemini-stream import (types defined locally)
// --- Types for Server Logic ---
/**
* Outcome of handling one requested tool call inside the server.
*
* `result` and `error` are optional properties. `confirmationDetails` must
* always be present on the object, but its value may be `undefined`.
*/
interface ServerToolExecutionOutcome {
// Identifier of the tool call this outcome belongs to.
callId: string;
// Name of the tool that was requested.
name: string;
// Arguments the tool call was requested with.
args: Record<string, unknown>; // Use unknown for broader compatibility
// Value returned by the tool's `execute`, when it ran and did not throw.
result?: ToolResult;
// Failure recorded for this call (unknown tool name or a thrown error).
error?: Error;
// Set when the tool asked for confirmation instead of being executed.
confirmationDetails: ToolCallConfirmationDetails | undefined;
}
/**
* Shape of a tool passed to the server. `Turn` indexes the tools it is given
* by `name` and looks them up by the name of each requested function call.
*/
export interface ServerTool {
// Lookup key; compared against the name of the requested function call.
name: string;
// Function declaration describing the tool.
schema: FunctionDeclaration; // Schema is needed
// The execute method signature might differ slightly or be wrapped
// Runs the tool with the given arguments and resolves with its result.
execute(params: Record<string, unknown>): Promise<ToolResult>;
// Resolves with confirmation details when the call needs confirmation, or
// `false` when it does not. `Turn.run` only calls `execute` when this
// resolves to a falsy value.
shouldConfirmExecute(
params: Record<string, unknown>,
): Promise<ToolCallConfirmationDetails | false>;
// validation and description might be handled differently or passed
}
// Redefine necessary event types locally
/**
* Discriminator values for the `type` field of `ServerGeminiStreamEvent`.
*/
export enum GeminiEventType {
// Text content; the event value is a string.
Content = 'content',
// A requested tool call; the event value is a `ToolCallRequestInfo`.
ToolCallRequest = 'tool_call_request',
// The response for a tool call; the event value is a `ToolCallResponseInfo`.
ToolCallResponse = 'tool_call_response',
// A tool call that asked for confirmation; the event value is a
// `ServerToolCallConfirmationDetails`.
ToolCallConfirmation = 'tool_call_confirmation',
}
/**
* Describes one requested tool call: its identifier, the tool name and the
* arguments to call it with.
*/
export interface ToolCallRequestInfo {
// Identifier used to match the request with its later response.
callId: string;
// Name of the requested tool.
name: string;
// Arguments for the tool call.
args: Record<string, unknown>;
}
/**
* Describes the response produced for one tool call.
*/
export interface ToolCallResponseInfo {
// Identifier of the tool call being answered.
callId: string;
// The function-response part built for this call.
responsePart: Part;
// Display value taken from the tool result, or `undefined` when there is none.
resultDisplay: ToolResultDisplay | undefined;
// Error recorded for the call, or `undefined` when there is none.
error: Error | undefined;
}
/**
* Pairs a tool call request with the confirmation details the tool returned
* for it.
*/
export interface ServerToolCallConfirmationDetails {
// The tool call that asked for confirmation.
request: ToolCallRequestInfo;
// Confirmation details returned by the tool's `shouldConfirmExecute`.
details: ToolCallConfirmationDetails;
}
/**
* Events yielded by `Turn.run`. This is a discriminated union: the `type`
* field selects which payload shape `value` has.
*/
export type ServerGeminiStreamEvent =
| { type: GeminiEventType.Content; value: string }
| { type: GeminiEventType.ToolCallRequest; value: ToolCallRequestInfo }
| { type: GeminiEventType.ToolCallResponse; value: ToolCallResponseInfo }
| {
type: GeminiEventType.ToolCallConfirmation;
value: ServerToolCallConfirmationDetails;
};
// --- Turn Class (Refactored for Server) ---
/**
 * Manages a single turn of the agentic loop within the server context.
 *
 * A turn sends one request to the chat, relays streamed text and tool-call
 * requests as events, and resolves every requested tool call once the stream
 * has finished. Function responses, confirmation details and the raw streamed
 * responses accumulate on the instance and are exposed through the getters.
 */
export class Turn {
  private readonly chat: Chat;
  /** Tools supplied to the constructor, keyed by tool name. */
  private readonly availableTools: Map<string, ServerTool>;
  /** Tool calls requested during the stream that have not been resolved yet. */
  private pendingToolCalls: ToolCallRequestInfo[];
  private readonly fnResponses: Part[];
  private readonly confirmationDetails: ToolCallConfirmationDetails[];
  private readonly debugResponses: GenerateContentResponse[];

  /**
   * @param chat - Chat used to send the turn's request.
   * @param availableTools - Tools that requested function calls may resolve to.
   */
  constructor(chat: Chat, availableTools: ServerTool[]) {
    this.chat = chat;
    this.availableTools = new Map(availableTools.map((t) => [t.name, t]));
    this.pendingToolCalls = [];
    this.fnResponses = [];
    this.confirmationDetails = [];
    this.debugResponses = [];
  }

  /**
   * Sends `req` to the chat and streams the events of this turn.
   *
   * While the response streams, text is yielded as `Content` events and each
   * function call as a `ToolCallRequest` event. After the stream ends, all
   * requested tools are resolved concurrently and one `ToolCallResponse`
   * event is yielded per call, preceded by a `ToolCallConfirmation` event
   * when the tool asked for confirmation.
   *
   * @param req - Message content forwarded to `chat.sendMessageStream`.
   * @param signal - Optional abort signal, checked as each streamed response
   *   arrives.
   * @returns An async generator of server stream events.
   * @throws An error named `AbortError` when `signal` is aborted at the time
   *   a streamed response arrives.
   */
  async *run(
    req: PartListUnion,
    signal?: AbortSignal,
  ): AsyncGenerator<ServerGeminiStreamEvent> {
    const responseStream = await this.chat.sendMessageStream({ message: req });
    for await (const resp of responseStream) {
      this.debugResponses.push(resp);
      if (signal?.aborted) {
        throw this.abortError();
      }
      const text = resp.text;
      if (text) {
        yield { type: GeminiEventType.Content, value: text };
      }
      // A single response may carry text and function calls together, so
      // yielding text must not skip the function calls of that response.
      for (const fnCall of resp.functionCalls ?? []) {
        yield this.handlePendingFunctionCall(fnCall);
      }
    }

    // Resolve all pending tool calls concurrently; the helper never rejects.
    const outcomes = await Promise.all(
      this.pendingToolCalls.map((call) => this.resolveToolCall(call)),
    );
    this.pendingToolCalls = []; // Clear pending calls for this turn

    // Report every outcome, not only the first one.
    for (const outcome of outcomes) {
      if (outcome.confirmationDetails) {
        this.confirmationDetails.push(outcome.confirmationDetails);
        const serverConfirmationDetails: ServerToolCallConfirmationDetails = {
          request: {
            callId: outcome.callId,
            name: outcome.name,
            args: outcome.args,
          },
          details: outcome.confirmationDetails,
        };
        yield {
          type: GeminiEventType.ToolCallConfirmation,
          value: serverConfirmationDetails,
        };
      }
      // NOTE(review): a call awaiting confirmation has neither result nor
      // error, so it also receives a response with empty output here. This
      // matches the previous behavior; confirm callers rely on it.
      const responsePart = this.buildFunctionResponse(outcome);
      this.fnResponses.push(responsePart);
      const responseInfo: ToolCallResponseInfo = {
        callId: outcome.callId,
        responsePart,
        resultDisplay: outcome.result?.returnDisplay,
        error: outcome.error,
      };
      yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
    }
  }

  /**
   * Resolves one pending tool call into an outcome without ever rejecting:
   * an unknown tool name or a thrown error is captured in `error`.
   */
  private async resolveToolCall(
    pendingToolCall: ToolCallRequestInfo,
  ): Promise<ServerToolExecutionOutcome> {
    const tool = this.availableTools.get(pendingToolCall.name);
    if (!tool) {
      return {
        ...pendingToolCall,
        error: new Error(
          `Tool "${pendingToolCall.name}" not found or not provided to Turn.`,
        ),
        confirmationDetails: undefined,
      };
    }
    try {
      const confirmationDetails = await tool.shouldConfirmExecute(
        pendingToolCall.args,
      );
      if (confirmationDetails) {
        // The tool wants confirmation first, so it is not executed here.
        return { ...pendingToolCall, confirmationDetails };
      }
      const result = await tool.execute(pendingToolCall.args);
      return { ...pendingToolCall, result, confirmationDetails: undefined };
    } catch (execError: unknown) {
      // NOTE(review): buildFunctionResponse prepends the same
      // "Tool execution failed:" prefix again for the model-facing message.
      return {
        ...pendingToolCall,
        error: new Error(
          `Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`,
        ),
        confirmationDetails: undefined,
      };
    }
  }

  /**
   * Records a requested function call as pending and returns the
   * `ToolCallRequest` event that announces it.
   *
   * When the call has no `id`, one is generated from the call name, the
   * current time and a random suffix. A missing name is replaced by
   * `'undefined_tool_name'` and missing args by an empty object.
   */
  private handlePendingFunctionCall(
    fnCall: FunctionCall,
  ): ServerGeminiStreamEvent {
    const callId =
      fnCall.id ??
      `${fnCall.name}-${Date.now()}-${Math.random().toString(16).slice(2)}`;
    const name = fnCall.name || 'undefined_tool_name';
    const args = (fnCall.args || {}) as Record<string, unknown>;
    this.pendingToolCalls.push({ callId, name, args });
    // Yield a request for the tool call, not the pending/confirming status
    const value: ToolCallRequestInfo = { callId, name, args };
    return { type: GeminiEventType.ToolCallRequest, value };
  }

  /**
   * Builds the single function-response `Part` for an outcome.
   *
   * An outcome with an error is logged and reported under `response.error`;
   * otherwise the result's `llmContent` (or an empty string when there is no
   * result) is reported under `response.output`.
   */
  private buildFunctionResponse(outcome: ServerToolExecutionOutcome): Part {
    const { name, result, error } = outcome;
    if (error) {
      // Format error for the LLM
      const errorMessage = error.message || String(error);
      console.error(`[Server Turn] Error executing tool ${name}:`, error);
      return {
        functionResponse: {
          name,
          id: outcome.callId,
          response: { error: `Tool execution failed: ${errorMessage}` },
        },
      };
    }
    return {
      functionResponse: {
        name,
        id: outcome.callId,
        response: { output: result?.llmContent ?? '' },
      },
    };
  }

  /** Creates (but does not throw) the error used to signal cancellation. */
  private abortError(): Error {
    const error = new Error('Request cancelled by user during stream.');
    error.name = 'AbortError';
    return error; // Return instead of throw, let caller handle
  }

  /** Confirmation details collected from tools that asked for confirmation. */
  getConfirmationDetails(): ToolCallConfirmationDetails[] {
    return this.confirmationDetails;
  }

  /** Allows the service layer to get the responses needed for the next API call. */
  getFunctionResponses(): Part[] {
    return this.fnResponses;
  }

  /** Raw streamed responses collected during `run`, for debugging (optional). */
  getDebugResponses(): GenerateContentResponse[] {
    return this.debugResponses;
  }
}