feat: useToolScheduler hook to manage parallel tool calls (#448)
This commit is contained in:
parent
efee7c6cce
commit
02eec5c8ca
|
@ -134,7 +134,7 @@ export const App = ({
|
|||
cliVersion,
|
||||
);
|
||||
|
||||
const { streamingState, submitQuery, initError, pendingHistoryItem } =
|
||||
const { streamingState, submitQuery, initError, pendingHistoryItems } =
|
||||
useGeminiStream(
|
||||
addItem,
|
||||
refreshStatic,
|
||||
|
@ -209,7 +209,7 @@ export const App = ({
|
|||
}, [terminalHeight, footerHeight]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!pendingHistoryItem) {
|
||||
if (!pendingHistoryItems.length) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -223,7 +223,7 @@ export const App = ({
|
|||
if (pendingItemDimensions.height > availableTerminalHeight) {
|
||||
setStaticNeedsRefresh(true);
|
||||
}
|
||||
}, [pendingHistoryItem, availableTerminalHeight, streamingState]);
|
||||
}, [pendingHistoryItems.length, availableTerminalHeight, streamingState]);
|
||||
|
||||
useEffect(() => {
|
||||
if (streamingState === StreamingState.Idle && staticNeedsRefresh) {
|
||||
|
@ -264,17 +264,18 @@ export const App = ({
|
|||
>
|
||||
{(item) => item}
|
||||
</Static>
|
||||
{pendingHistoryItem && (
|
||||
<Box ref={pendingHistoryItemRef}>
|
||||
<Box ref={pendingHistoryItemRef}>
|
||||
{pendingHistoryItems.map((item, i) => (
|
||||
<HistoryItemDisplay
|
||||
key={i}
|
||||
availableTerminalHeight={availableTerminalHeight}
|
||||
// TODO(taehykim): It seems like references to ids aren't necessary in
|
||||
// HistoryItemDisplay. Refactor later. Use a fake id for now.
|
||||
item={{ ...pendingHistoryItem, id: 0 }}
|
||||
item={{ ...item, id: 0 }}
|
||||
isPending={true}
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
))}
|
||||
</Box>
|
||||
{showHelp && <Help commands={slashCommands} />}
|
||||
|
||||
<Box flexDirection="column" ref={mainControlsRef}>
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import React, { useMemo } from 'react';
|
||||
import { Box } from 'ink';
|
||||
import { IndividualToolCallDisplay, ToolCallStatus } from '../../types.js';
|
||||
import { ToolMessage } from './ToolMessage.js';
|
||||
|
@ -19,7 +19,6 @@ interface ToolGroupMessageProps {
|
|||
|
||||
// Main component renders the border and maps the tools using ToolMessage
|
||||
export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
|
||||
groupId,
|
||||
toolCalls,
|
||||
availableTerminalHeight,
|
||||
}) => {
|
||||
|
@ -30,9 +29,13 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
|
|||
|
||||
const staticHeight = /* border */ 2 + /* marginBottom */ 1;
|
||||
|
||||
const toolAwaitingApproval = useMemo(
|
||||
() => toolCalls.find((tc) => tc.status === ToolCallStatus.Confirming),
|
||||
[toolCalls],
|
||||
);
|
||||
|
||||
return (
|
||||
<Box
|
||||
key={groupId}
|
||||
flexDirection="column"
|
||||
borderStyle="round"
|
||||
/*
|
||||
|
@ -48,7 +51,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
|
|||
marginBottom={1}
|
||||
>
|
||||
{toolCalls.map((tool) => (
|
||||
<Box key={groupId + '-' + tool.callId} flexDirection="column">
|
||||
<Box key={tool.callId} flexDirection="column">
|
||||
<ToolMessage
|
||||
key={tool.callId}
|
||||
callId={tool.callId}
|
||||
|
@ -60,6 +63,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
|
|||
availableTerminalHeight={availableTerminalHeight - staticHeight}
|
||||
/>
|
||||
{tool.status === ToolCallStatus.Confirming &&
|
||||
tool.callId === toolAwaitingApproval?.callId &&
|
||||
tool.confirmationDetails && (
|
||||
<ToolConfirmationMessage
|
||||
confirmationDetails={tool.confirmationDetails}
|
||||
|
|
|
@ -4,34 +4,28 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { useState, useRef, useCallback, useEffect } from 'react';
|
||||
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useInput } from 'ink';
|
||||
import {
|
||||
GeminiClient,
|
||||
GeminiEventType as ServerGeminiEventType,
|
||||
ServerGeminiStreamEvent as GeminiEvent,
|
||||
ServerGeminiContentEvent as ContentEvent,
|
||||
ServerGeminiToolCallRequestEvent as ToolCallRequestEvent,
|
||||
ServerGeminiToolCallResponseEvent as ToolCallResponseEvent,
|
||||
ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent,
|
||||
ServerGeminiErrorEvent as ErrorEvent,
|
||||
getErrorMessage,
|
||||
isNodeError,
|
||||
Config,
|
||||
MessageSenderType,
|
||||
ServerToolCallConfirmationDetails,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolCallResponseInfo,
|
||||
ToolConfirmationOutcome,
|
||||
ToolEditConfirmationDetails,
|
||||
ToolExecuteConfirmationDetails,
|
||||
ToolResultDisplay,
|
||||
partListUnionToString,
|
||||
ToolCallRequestInfo,
|
||||
} from '@gemini-code/server';
|
||||
import { type Chat, type PartListUnion, type Part } from '@google/genai';
|
||||
import {
|
||||
StreamingState,
|
||||
IndividualToolCallDisplay,
|
||||
ToolCallStatus,
|
||||
HistoryItemWithoutId,
|
||||
HistoryItemToolGroup,
|
||||
|
@ -44,6 +38,7 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
|
|||
import { useStateAndRef } from './useStateAndRef.js';
|
||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { useLogger } from './useLogger.js';
|
||||
import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
|
||||
|
||||
enum StreamProcessingStatus {
|
||||
Completed,
|
||||
|
@ -65,7 +60,6 @@ export const useGeminiStream = (
|
|||
handleSlashCommand: (cmd: PartListUnion) => boolean,
|
||||
shellModeActive: boolean,
|
||||
) => {
|
||||
const toolRegistry = config.getToolRegistry();
|
||||
const [initError, setInitError] = useState<string | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const chatSessionRef = useRef<Chat | null>(null);
|
||||
|
@ -74,6 +68,25 @@ export const useGeminiStream = (
|
|||
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||
const logger = useLogger();
|
||||
const [toolCalls, schedule, cancel] = useToolScheduler((tools) => {
|
||||
if (tools.length) {
|
||||
addItem(mapToDisplay(tools), Date.now());
|
||||
submitQuery(
|
||||
tools
|
||||
.filter(
|
||||
(t) =>
|
||||
t.status === 'error' ||
|
||||
t.status === 'cancelled' ||
|
||||
t.status === 'success',
|
||||
)
|
||||
.map((t) => t.response.responsePart),
|
||||
);
|
||||
}
|
||||
}, config);
|
||||
const pendingToolCalls = useMemo(
|
||||
() => (toolCalls.length ? mapToDisplay(toolCalls) : undefined),
|
||||
[toolCalls],
|
||||
);
|
||||
|
||||
const onExec = useCallback(async (done: Promise<void>) => {
|
||||
setIsResponding(true);
|
||||
|
@ -104,6 +117,7 @@ export const useGeminiStream = (
|
|||
useInput((_input, key) => {
|
||||
if (streamingState !== StreamingState.Idle && key.escape) {
|
||||
abortControllerRef.current?.abort();
|
||||
cancel();
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -215,157 +229,48 @@ export const useGeminiStream = (
|
|||
);
|
||||
};
|
||||
|
||||
const updateConfirmingFunctionStatusUI = (
|
||||
callId: string,
|
||||
confirmationDetails: ToolCallConfirmationDetails | undefined,
|
||||
) => {
|
||||
setPendingHistoryItem((item) =>
|
||||
item?.type === 'tool_group'
|
||||
? {
|
||||
...item,
|
||||
tools: item.tools.map((tool) =>
|
||||
tool.callId === callId
|
||||
? {
|
||||
...tool,
|
||||
status: ToolCallStatus.Confirming,
|
||||
confirmationDetails,
|
||||
}
|
||||
: tool,
|
||||
),
|
||||
}
|
||||
: item,
|
||||
);
|
||||
};
|
||||
|
||||
const wireConfirmationSubmission = (
|
||||
confirmationDetails: ServerToolCallConfirmationDetails,
|
||||
): ToolCallConfirmationDetails => {
|
||||
const originalConfirmationDetails = confirmationDetails.details;
|
||||
const request = confirmationDetails.request;
|
||||
const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => {
|
||||
originalConfirmationDetails.onConfirm(outcome);
|
||||
if (pendingHistoryItemRef?.current?.type === 'tool_group') {
|
||||
setPendingHistoryItem((item) =>
|
||||
item?.type === 'tool_group'
|
||||
? {
|
||||
...item,
|
||||
tools: item.tools.map((tool) =>
|
||||
tool.callId === request.callId
|
||||
? {
|
||||
...tool,
|
||||
confirmationDetails: undefined,
|
||||
status: ToolCallStatus.Executing,
|
||||
}
|
||||
: tool,
|
||||
),
|
||||
}
|
||||
: item,
|
||||
);
|
||||
refreshStatic();
|
||||
}
|
||||
|
||||
if (outcome === ToolConfirmationOutcome.Cancel) {
|
||||
declineToolExecution(
|
||||
'User rejected function call.',
|
||||
ToolCallStatus.Error,
|
||||
request,
|
||||
originalConfirmationDetails,
|
||||
);
|
||||
} else {
|
||||
const tool = toolRegistry.getTool(request.name);
|
||||
if (!tool) {
|
||||
throw new Error(
|
||||
`Tool "${request.name}" not found or is not registered.`,
|
||||
);
|
||||
}
|
||||
try {
|
||||
abortControllerRef.current = new AbortController();
|
||||
const result = await tool.execute(
|
||||
request.args,
|
||||
abortControllerRef.current.signal,
|
||||
);
|
||||
if (abortControllerRef.current.signal.aborted) {
|
||||
declineToolExecution(
|
||||
partListUnionToString(result.llmContent),
|
||||
ToolCallStatus.Canceled,
|
||||
request,
|
||||
originalConfirmationDetails,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
name: request.name,
|
||||
id: request.callId,
|
||||
response: { output: result.llmContent },
|
||||
},
|
||||
};
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: request.callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay: result.returnDisplay,
|
||||
error: undefined,
|
||||
};
|
||||
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, Date.now());
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
setIsResponding(false);
|
||||
await submitQuery(functionResponse); // Recursive call
|
||||
} finally {
|
||||
if (streamingState !== StreamingState.WaitingForConfirmation) {
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
|
||||
// or could be a standalone helper if more params are passed.
|
||||
function declineToolExecution(
|
||||
declineMessage: string,
|
||||
status: ToolCallStatus,
|
||||
request: ServerToolCallConfirmationDetails['request'],
|
||||
originalDetails: ServerToolCallConfirmationDetails['details'],
|
||||
) {
|
||||
let resultDisplay: ToolResultDisplay | undefined;
|
||||
if ('fileDiff' in originalDetails) {
|
||||
resultDisplay = {
|
||||
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
|
||||
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
|
||||
};
|
||||
} else {
|
||||
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
|
||||
}
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
id: request.callId,
|
||||
name: request.name,
|
||||
response: { error: declineMessage },
|
||||
},
|
||||
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
|
||||
// or could be a standalone helper if more params are passed.
|
||||
// TODO: handle file diff result display stuff
|
||||
function _declineToolExecution(
|
||||
declineMessage: string,
|
||||
status: ToolCallStatus,
|
||||
request: ServerToolCallConfirmationDetails['request'],
|
||||
originalDetails: ServerToolCallConfirmationDetails['details'],
|
||||
) {
|
||||
let resultDisplay: ToolResultDisplay | undefined;
|
||||
if ('fileDiff' in originalDetails) {
|
||||
resultDisplay = {
|
||||
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
|
||||
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
|
||||
};
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: request.callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay,
|
||||
error: new Error(declineMessage),
|
||||
};
|
||||
const history = chatSessionRef.current?.getHistory();
|
||||
if (history) {
|
||||
history.push({ role: 'model', parts: [functionResponse] });
|
||||
}
|
||||
updateFunctionResponseUI(responseInfo, status);
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, Date.now());
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
setIsResponding(false);
|
||||
} else {
|
||||
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
|
||||
}
|
||||
|
||||
return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm };
|
||||
};
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
id: request.callId,
|
||||
name: request.name,
|
||||
response: { error: declineMessage },
|
||||
},
|
||||
};
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: request.callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay,
|
||||
error: new Error(declineMessage),
|
||||
};
|
||||
const history = chatSessionRef.current?.getHistory();
|
||||
if (history) {
|
||||
history.push({ role: 'model', parts: [functionResponse] });
|
||||
}
|
||||
updateFunctionResponseUI(responseInfo, status);
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, Date.now());
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
setIsResponding(false);
|
||||
}
|
||||
|
||||
// --- Stream Event Handlers ---
|
||||
const handleContentEvent = (
|
||||
|
@ -419,62 +324,6 @@ export const useGeminiStream = (
|
|||
return newGeminiMessageBuffer;
|
||||
};
|
||||
|
||||
const handleToolCallRequestEvent = (
|
||||
eventValue: ToolCallRequestEvent['value'],
|
||||
userMessageTimestamp: number,
|
||||
) => {
|
||||
const { callId, name, args } = eventValue;
|
||||
const cliTool = toolRegistry.getTool(name);
|
||||
if (!cliTool) {
|
||||
console.error(`CLI Tool "${name}" not found!`);
|
||||
return; // Skip this event if tool is not found
|
||||
}
|
||||
if (pendingHistoryItemRef.current?.type !== 'tool_group') {
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
}
|
||||
setPendingHistoryItem({ type: 'tool_group', tools: [] });
|
||||
}
|
||||
let description: string;
|
||||
try {
|
||||
description = cliTool.getDescription(args);
|
||||
} catch (e) {
|
||||
description = `Error: Unable to get description: ${getErrorMessage(e)}`;
|
||||
}
|
||||
const toolCallDisplay: IndividualToolCallDisplay = {
|
||||
callId,
|
||||
name: cliTool.displayName,
|
||||
description,
|
||||
status: ToolCallStatus.Pending,
|
||||
resultDisplay: undefined,
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
setPendingHistoryItem((pending) =>
|
||||
pending?.type === 'tool_group'
|
||||
? { ...pending, tools: [...pending.tools, toolCallDisplay] }
|
||||
: null,
|
||||
);
|
||||
};
|
||||
|
||||
const handleToolCallResponseEvent = (
|
||||
eventValue: ToolCallResponseEvent['value'],
|
||||
) => {
|
||||
const status = eventValue.error
|
||||
? ToolCallStatus.Error
|
||||
: ToolCallStatus.Success;
|
||||
updateFunctionResponseUI(eventValue, status);
|
||||
};
|
||||
|
||||
const handleToolCallConfirmationEvent = (
|
||||
eventValue: ToolCallConfirmationEvent['value'],
|
||||
) => {
|
||||
const confirmationDetails = wireConfirmationSubmission(eventValue);
|
||||
updateConfirmingFunctionStatusUI(
|
||||
eventValue.request.callId,
|
||||
confirmationDetails,
|
||||
);
|
||||
};
|
||||
|
||||
const handleUserCancelledEvent = (userMessageTimestamp: number) => {
|
||||
if (pendingHistoryItemRef.current) {
|
||||
if (pendingHistoryItemRef.current.type === 'tool_group') {
|
||||
|
@ -500,6 +349,7 @@ export const useGeminiStream = (
|
|||
userMessageTimestamp,
|
||||
);
|
||||
setIsResponding(false);
|
||||
cancel();
|
||||
};
|
||||
|
||||
const handleErrorEvent = (
|
||||
|
@ -521,7 +371,7 @@ export const useGeminiStream = (
|
|||
userMessageTimestamp: number,
|
||||
): Promise<StreamProcessingStatus> => {
|
||||
let geminiMessageBuffer = '';
|
||||
|
||||
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||
for await (const event of stream) {
|
||||
if (event.type === ServerGeminiEventType.Content) {
|
||||
geminiMessageBuffer = handleContentEvent(
|
||||
|
@ -530,12 +380,7 @@ export const useGeminiStream = (
|
|||
userMessageTimestamp,
|
||||
);
|
||||
} else if (event.type === ServerGeminiEventType.ToolCallRequest) {
|
||||
handleToolCallRequestEvent(event.value, userMessageTimestamp);
|
||||
} else if (event.type === ServerGeminiEventType.ToolCallResponse) {
|
||||
handleToolCallResponseEvent(event.value);
|
||||
} else if (event.type === ServerGeminiEventType.ToolCallConfirmation) {
|
||||
handleToolCallConfirmationEvent(event.value);
|
||||
return StreamProcessingStatus.PausedForConfirmation;
|
||||
toolCallRequests.push(event.value);
|
||||
} else if (event.type === ServerGeminiEventType.UserCancelled) {
|
||||
handleUserCancelledEvent(userMessageTimestamp);
|
||||
return StreamProcessingStatus.UserCancelled;
|
||||
|
@ -544,9 +389,18 @@ export const useGeminiStream = (
|
|||
return StreamProcessingStatus.Error;
|
||||
}
|
||||
}
|
||||
schedule(toolCallRequests);
|
||||
return StreamProcessingStatus.Completed;
|
||||
};
|
||||
|
||||
const streamingState: StreamingState = isResponding
|
||||
? StreamingState.Responding
|
||||
: pendingToolCalls?.tools.some(
|
||||
(t) => t.status === ToolCallStatus.Confirming,
|
||||
)
|
||||
? StreamingState.WaitingForConfirmation
|
||||
: StreamingState.Idle;
|
||||
|
||||
const submitQuery = useCallback(
|
||||
async (query: PartListUnion) => {
|
||||
if (isResponding) return;
|
||||
|
@ -625,20 +479,15 @@ export const useGeminiStream = (
|
|||
],
|
||||
);
|
||||
|
||||
const streamingState: StreamingState = isResponding
|
||||
? StreamingState.Responding
|
||||
: pendingConfirmations(pendingHistoryItemRef.current)
|
||||
? StreamingState.WaitingForConfirmation
|
||||
: StreamingState.Idle;
|
||||
const pendingHistoryItems = [
|
||||
pendingHistoryItemRef.current,
|
||||
pendingToolCalls,
|
||||
].filter((i) => i !== undefined && i !== null);
|
||||
|
||||
return {
|
||||
streamingState,
|
||||
submitQuery,
|
||||
initError,
|
||||
pendingHistoryItem: pendingHistoryItemRef.current,
|
||||
pendingHistoryItems,
|
||||
};
|
||||
};
|
||||
|
||||
const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean =>
|
||||
item?.type === 'tool_group' &&
|
||||
item.tools.some((t) => t.status === ToolCallStatus.Confirming);
|
||||
|
|
|
@ -155,10 +155,9 @@ export class GeminiClient {
|
|||
signal?: AbortSignal,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
let turns = 0;
|
||||
const availableTools = this.config.getToolRegistry().getAllTools();
|
||||
while (turns < this.MAX_TURNS) {
|
||||
turns++;
|
||||
const turn = new Turn(chat, availableTools);
|
||||
const turn = new Turn(chat);
|
||||
const resultStream = turn.run(request, signal);
|
||||
let seenError = false;
|
||||
for await (const event of resultStream) {
|
||||
|
|
|
@ -21,18 +21,6 @@ import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
|||
import { reportError } from '../utils/errorReporting.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
|
||||
// --- Types for Server Logic ---
|
||||
|
||||
// Define a simpler structure for Tool execution results within the server
|
||||
interface ServerToolExecutionOutcome {
|
||||
callId: string;
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
result?: ToolResult;
|
||||
error?: Error;
|
||||
confirmationDetails: ToolCallConfirmationDetails | undefined;
|
||||
}
|
||||
|
||||
// Define a structure for tools passed to the server
|
||||
export interface ServerTool {
|
||||
name: string;
|
||||
|
@ -118,7 +106,6 @@ export type ServerGeminiStreamEvent =
|
|||
|
||||
// A turn manages the agentic loop turn within the server context.
|
||||
export class Turn {
|
||||
private readonly availableTools: Map<string, ServerTool>;
|
||||
private pendingToolCalls: Array<{
|
||||
callId: string;
|
||||
name: string;
|
||||
|
@ -128,11 +115,7 @@ export class Turn {
|
|||
private confirmationDetails: ToolCallConfirmationDetails[];
|
||||
private debugResponses: GenerateContentResponse[];
|
||||
|
||||
constructor(
|
||||
private readonly chat: Chat,
|
||||
availableTools: ServerTool[],
|
||||
) {
|
||||
this.availableTools = new Map(availableTools.map((t) => [t.name, t]));
|
||||
constructor(private readonly chat: Chat) {
|
||||
this.pendingToolCalls = [];
|
||||
this.fnResponses = [];
|
||||
this.confirmationDetails = [];
|
||||
|
@ -160,12 +143,9 @@ export class Turn {
|
|||
yield { type: GeminiEventType.Content, value: text };
|
||||
}
|
||||
|
||||
if (!resp.functionCalls) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle function calls (requesting tool execution)
|
||||
for (const fnCall of resp.functionCalls) {
|
||||
const functionCalls = resp.functionCalls ?? [];
|
||||
for (const fnCall of functionCalls) {
|
||||
const event = this.handlePendingFunctionCall(fnCall);
|
||||
if (event) {
|
||||
yield event;
|
||||
|
@ -184,80 +164,6 @@ export class Turn {
|
|||
yield { type: GeminiEventType.Error, value: { message: errorMessage } };
|
||||
return;
|
||||
}
|
||||
|
||||
// Execute pending tool calls
|
||||
const toolPromises = this.pendingToolCalls.map(
|
||||
async (pendingToolCall): Promise<ServerToolExecutionOutcome> => {
|
||||
const tool = this.availableTools.get(pendingToolCall.name);
|
||||
if (!tool) {
|
||||
return {
|
||||
...pendingToolCall,
|
||||
error: new Error(
|
||||
`Tool "${pendingToolCall.name}" not found or not provided to Turn.`,
|
||||
),
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const confirmationDetails = await tool.shouldConfirmExecute(
|
||||
pendingToolCall.args,
|
||||
);
|
||||
if (confirmationDetails) {
|
||||
return { ...pendingToolCall, confirmationDetails };
|
||||
}
|
||||
const result = await tool.execute(pendingToolCall.args, signal);
|
||||
return {
|
||||
...pendingToolCall,
|
||||
result,
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
} catch (execError: unknown) {
|
||||
return {
|
||||
...pendingToolCall,
|
||||
error: new Error(
|
||||
`Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`,
|
||||
),
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
}
|
||||
},
|
||||
);
|
||||
const outcomes = await Promise.all(toolPromises);
|
||||
|
||||
// Process outcomes and prepare function responses
|
||||
this.pendingToolCalls = []; // Clear pending calls for this turn
|
||||
|
||||
for (const outcome of outcomes) {
|
||||
if (outcome.confirmationDetails) {
|
||||
this.confirmationDetails.push(outcome.confirmationDetails);
|
||||
const serverConfirmationetails: ServerToolCallConfirmationDetails = {
|
||||
request: {
|
||||
callId: outcome.callId,
|
||||
name: outcome.name,
|
||||
args: outcome.args,
|
||||
},
|
||||
details: outcome.confirmationDetails,
|
||||
};
|
||||
yield {
|
||||
type: GeminiEventType.ToolCallConfirmation,
|
||||
value: serverConfirmationetails,
|
||||
};
|
||||
}
|
||||
const responsePart = this.buildFunctionResponse(outcome);
|
||||
this.fnResponses.push(responsePart);
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: outcome.callId,
|
||||
responsePart,
|
||||
resultDisplay: outcome.result?.returnDisplay,
|
||||
error: outcome.error,
|
||||
};
|
||||
|
||||
// If aborted we're already yielding the user cancellations elsewhere.
|
||||
if (!signal?.aborted) {
|
||||
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private handlePendingFunctionCall(
|
||||
|
@ -276,30 +182,6 @@ export class Turn {
|
|||
return { type: GeminiEventType.ToolCallRequest, value };
|
||||
}
|
||||
|
||||
// Builds the Part array expected by the Google GenAI API
|
||||
private buildFunctionResponse(outcome: ServerToolExecutionOutcome): Part {
|
||||
const { name, result, error } = outcome;
|
||||
if (error) {
|
||||
// Format error for the LLM
|
||||
const errorMessage = error?.message || String(error);
|
||||
console.error(`[Server Turn] Error executing tool ${name}:`, error);
|
||||
return {
|
||||
functionResponse: {
|
||||
name,
|
||||
id: outcome.callId,
|
||||
response: { error: `Tool execution failed: ${errorMessage}` },
|
||||
},
|
||||
};
|
||||
}
|
||||
return {
|
||||
functionResponse: {
|
||||
name,
|
||||
id: outcome.callId,
|
||||
response: { output: result?.llmContent ?? '' },
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
getConfirmationDetails(): ToolCallConfirmationDetails[] {
|
||||
return this.confirmationDetails;
|
||||
}
|
||||
|
|
|
@ -171,23 +171,28 @@ export interface FileDiff {
|
|||
fileName: string;
|
||||
}
|
||||
|
||||
export interface ToolCallConfirmationDetails {
|
||||
export interface ToolCallConfirmationDetailsDefault {
|
||||
title: string;
|
||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||
}
|
||||
|
||||
export interface ToolEditConfirmationDetails
|
||||
extends ToolCallConfirmationDetails {
|
||||
extends ToolCallConfirmationDetailsDefault {
|
||||
fileName: string;
|
||||
fileDiff: string;
|
||||
}
|
||||
|
||||
export interface ToolExecuteConfirmationDetails
|
||||
extends ToolCallConfirmationDetails {
|
||||
extends ToolCallConfirmationDetailsDefault {
|
||||
command: string;
|
||||
rootCommand: string;
|
||||
}
|
||||
|
||||
export type ToolCallConfirmationDetails =
|
||||
| ToolCallConfirmationDetailsDefault
|
||||
| ToolEditConfirmationDetails
|
||||
| ToolExecuteConfirmationDetails;
|
||||
|
||||
export enum ToolConfirmationOutcome {
|
||||
ProceedOnce,
|
||||
ProceedAlways,
|
||||
|
|
Loading…
Reference in New Issue