feat: useToolScheduler hook to manage parallel tool calls (#448)
This commit is contained in:
parent
efee7c6cce
commit
02eec5c8ca
|
@ -134,7 +134,7 @@ export const App = ({
|
||||||
cliVersion,
|
cliVersion,
|
||||||
);
|
);
|
||||||
|
|
||||||
const { streamingState, submitQuery, initError, pendingHistoryItem } =
|
const { streamingState, submitQuery, initError, pendingHistoryItems } =
|
||||||
useGeminiStream(
|
useGeminiStream(
|
||||||
addItem,
|
addItem,
|
||||||
refreshStatic,
|
refreshStatic,
|
||||||
|
@ -209,7 +209,7 @@ export const App = ({
|
||||||
}, [terminalHeight, footerHeight]);
|
}, [terminalHeight, footerHeight]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!pendingHistoryItem) {
|
if (!pendingHistoryItems.length) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -223,7 +223,7 @@ export const App = ({
|
||||||
if (pendingItemDimensions.height > availableTerminalHeight) {
|
if (pendingItemDimensions.height > availableTerminalHeight) {
|
||||||
setStaticNeedsRefresh(true);
|
setStaticNeedsRefresh(true);
|
||||||
}
|
}
|
||||||
}, [pendingHistoryItem, availableTerminalHeight, streamingState]);
|
}, [pendingHistoryItems.length, availableTerminalHeight, streamingState]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (streamingState === StreamingState.Idle && staticNeedsRefresh) {
|
if (streamingState === StreamingState.Idle && staticNeedsRefresh) {
|
||||||
|
@ -264,17 +264,18 @@ export const App = ({
|
||||||
>
|
>
|
||||||
{(item) => item}
|
{(item) => item}
|
||||||
</Static>
|
</Static>
|
||||||
{pendingHistoryItem && (
|
<Box ref={pendingHistoryItemRef}>
|
||||||
<Box ref={pendingHistoryItemRef}>
|
{pendingHistoryItems.map((item, i) => (
|
||||||
<HistoryItemDisplay
|
<HistoryItemDisplay
|
||||||
|
key={i}
|
||||||
availableTerminalHeight={availableTerminalHeight}
|
availableTerminalHeight={availableTerminalHeight}
|
||||||
// TODO(taehykim): It seems like references to ids aren't necessary in
|
// TODO(taehykim): It seems like references to ids aren't necessary in
|
||||||
// HistoryItemDisplay. Refactor later. Use a fake id for now.
|
// HistoryItemDisplay. Refactor later. Use a fake id for now.
|
||||||
item={{ ...pendingHistoryItem, id: 0 }}
|
item={{ ...item, id: 0 }}
|
||||||
isPending={true}
|
isPending={true}
|
||||||
/>
|
/>
|
||||||
</Box>
|
))}
|
||||||
)}
|
</Box>
|
||||||
{showHelp && <Help commands={slashCommands} />}
|
{showHelp && <Help commands={slashCommands} />}
|
||||||
|
|
||||||
<Box flexDirection="column" ref={mainControlsRef}>
|
<Box flexDirection="column" ref={mainControlsRef}>
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import React from 'react';
|
import React, { useMemo } from 'react';
|
||||||
import { Box } from 'ink';
|
import { Box } from 'ink';
|
||||||
import { IndividualToolCallDisplay, ToolCallStatus } from '../../types.js';
|
import { IndividualToolCallDisplay, ToolCallStatus } from '../../types.js';
|
||||||
import { ToolMessage } from './ToolMessage.js';
|
import { ToolMessage } from './ToolMessage.js';
|
||||||
|
@ -19,7 +19,6 @@ interface ToolGroupMessageProps {
|
||||||
|
|
||||||
// Main component renders the border and maps the tools using ToolMessage
|
// Main component renders the border and maps the tools using ToolMessage
|
||||||
export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
|
export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
|
||||||
groupId,
|
|
||||||
toolCalls,
|
toolCalls,
|
||||||
availableTerminalHeight,
|
availableTerminalHeight,
|
||||||
}) => {
|
}) => {
|
||||||
|
@ -30,9 +29,13 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
|
||||||
|
|
||||||
const staticHeight = /* border */ 2 + /* marginBottom */ 1;
|
const staticHeight = /* border */ 2 + /* marginBottom */ 1;
|
||||||
|
|
||||||
|
const toolAwaitingApproval = useMemo(
|
||||||
|
() => toolCalls.find((tc) => tc.status === ToolCallStatus.Confirming),
|
||||||
|
[toolCalls],
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box
|
<Box
|
||||||
key={groupId}
|
|
||||||
flexDirection="column"
|
flexDirection="column"
|
||||||
borderStyle="round"
|
borderStyle="round"
|
||||||
/*
|
/*
|
||||||
|
@ -48,7 +51,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
|
||||||
marginBottom={1}
|
marginBottom={1}
|
||||||
>
|
>
|
||||||
{toolCalls.map((tool) => (
|
{toolCalls.map((tool) => (
|
||||||
<Box key={groupId + '-' + tool.callId} flexDirection="column">
|
<Box key={tool.callId} flexDirection="column">
|
||||||
<ToolMessage
|
<ToolMessage
|
||||||
key={tool.callId}
|
key={tool.callId}
|
||||||
callId={tool.callId}
|
callId={tool.callId}
|
||||||
|
@ -60,6 +63,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
|
||||||
availableTerminalHeight={availableTerminalHeight - staticHeight}
|
availableTerminalHeight={availableTerminalHeight - staticHeight}
|
||||||
/>
|
/>
|
||||||
{tool.status === ToolCallStatus.Confirming &&
|
{tool.status === ToolCallStatus.Confirming &&
|
||||||
|
tool.callId === toolAwaitingApproval?.callId &&
|
||||||
tool.confirmationDetails && (
|
tool.confirmationDetails && (
|
||||||
<ToolConfirmationMessage
|
<ToolConfirmationMessage
|
||||||
confirmationDetails={tool.confirmationDetails}
|
confirmationDetails={tool.confirmationDetails}
|
||||||
|
|
|
@ -4,34 +4,28 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { useState, useRef, useCallback, useEffect } from 'react';
|
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { useInput } from 'ink';
|
import { useInput } from 'ink';
|
||||||
import {
|
import {
|
||||||
GeminiClient,
|
GeminiClient,
|
||||||
GeminiEventType as ServerGeminiEventType,
|
GeminiEventType as ServerGeminiEventType,
|
||||||
ServerGeminiStreamEvent as GeminiEvent,
|
ServerGeminiStreamEvent as GeminiEvent,
|
||||||
ServerGeminiContentEvent as ContentEvent,
|
ServerGeminiContentEvent as ContentEvent,
|
||||||
ServerGeminiToolCallRequestEvent as ToolCallRequestEvent,
|
|
||||||
ServerGeminiToolCallResponseEvent as ToolCallResponseEvent,
|
|
||||||
ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent,
|
|
||||||
ServerGeminiErrorEvent as ErrorEvent,
|
ServerGeminiErrorEvent as ErrorEvent,
|
||||||
getErrorMessage,
|
getErrorMessage,
|
||||||
isNodeError,
|
isNodeError,
|
||||||
Config,
|
Config,
|
||||||
MessageSenderType,
|
MessageSenderType,
|
||||||
ServerToolCallConfirmationDetails,
|
ServerToolCallConfirmationDetails,
|
||||||
ToolCallConfirmationDetails,
|
|
||||||
ToolCallResponseInfo,
|
ToolCallResponseInfo,
|
||||||
ToolConfirmationOutcome,
|
|
||||||
ToolEditConfirmationDetails,
|
ToolEditConfirmationDetails,
|
||||||
ToolExecuteConfirmationDetails,
|
ToolExecuteConfirmationDetails,
|
||||||
ToolResultDisplay,
|
ToolResultDisplay,
|
||||||
partListUnionToString,
|
ToolCallRequestInfo,
|
||||||
} from '@gemini-code/server';
|
} from '@gemini-code/server';
|
||||||
import { type Chat, type PartListUnion, type Part } from '@google/genai';
|
import { type Chat, type PartListUnion, type Part } from '@google/genai';
|
||||||
import {
|
import {
|
||||||
StreamingState,
|
StreamingState,
|
||||||
IndividualToolCallDisplay,
|
|
||||||
ToolCallStatus,
|
ToolCallStatus,
|
||||||
HistoryItemWithoutId,
|
HistoryItemWithoutId,
|
||||||
HistoryItemToolGroup,
|
HistoryItemToolGroup,
|
||||||
|
@ -44,6 +38,7 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
|
||||||
import { useStateAndRef } from './useStateAndRef.js';
|
import { useStateAndRef } from './useStateAndRef.js';
|
||||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||||
import { useLogger } from './useLogger.js';
|
import { useLogger } from './useLogger.js';
|
||||||
|
import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
|
||||||
|
|
||||||
enum StreamProcessingStatus {
|
enum StreamProcessingStatus {
|
||||||
Completed,
|
Completed,
|
||||||
|
@ -65,7 +60,6 @@ export const useGeminiStream = (
|
||||||
handleSlashCommand: (cmd: PartListUnion) => boolean,
|
handleSlashCommand: (cmd: PartListUnion) => boolean,
|
||||||
shellModeActive: boolean,
|
shellModeActive: boolean,
|
||||||
) => {
|
) => {
|
||||||
const toolRegistry = config.getToolRegistry();
|
|
||||||
const [initError, setInitError] = useState<string | null>(null);
|
const [initError, setInitError] = useState<string | null>(null);
|
||||||
const abortControllerRef = useRef<AbortController | null>(null);
|
const abortControllerRef = useRef<AbortController | null>(null);
|
||||||
const chatSessionRef = useRef<Chat | null>(null);
|
const chatSessionRef = useRef<Chat | null>(null);
|
||||||
|
@ -74,6 +68,25 @@ export const useGeminiStream = (
|
||||||
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
||||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||||
const logger = useLogger();
|
const logger = useLogger();
|
||||||
|
const [toolCalls, schedule, cancel] = useToolScheduler((tools) => {
|
||||||
|
if (tools.length) {
|
||||||
|
addItem(mapToDisplay(tools), Date.now());
|
||||||
|
submitQuery(
|
||||||
|
tools
|
||||||
|
.filter(
|
||||||
|
(t) =>
|
||||||
|
t.status === 'error' ||
|
||||||
|
t.status === 'cancelled' ||
|
||||||
|
t.status === 'success',
|
||||||
|
)
|
||||||
|
.map((t) => t.response.responsePart),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}, config);
|
||||||
|
const pendingToolCalls = useMemo(
|
||||||
|
() => (toolCalls.length ? mapToDisplay(toolCalls) : undefined),
|
||||||
|
[toolCalls],
|
||||||
|
);
|
||||||
|
|
||||||
const onExec = useCallback(async (done: Promise<void>) => {
|
const onExec = useCallback(async (done: Promise<void>) => {
|
||||||
setIsResponding(true);
|
setIsResponding(true);
|
||||||
|
@ -104,6 +117,7 @@ export const useGeminiStream = (
|
||||||
useInput((_input, key) => {
|
useInput((_input, key) => {
|
||||||
if (streamingState !== StreamingState.Idle && key.escape) {
|
if (streamingState !== StreamingState.Idle && key.escape) {
|
||||||
abortControllerRef.current?.abort();
|
abortControllerRef.current?.abort();
|
||||||
|
cancel();
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -215,157 +229,48 @@ export const useGeminiStream = (
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
const updateConfirmingFunctionStatusUI = (
|
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
|
||||||
callId: string,
|
// or could be a standalone helper if more params are passed.
|
||||||
confirmationDetails: ToolCallConfirmationDetails | undefined,
|
// TODO: handle file diff result display stuff
|
||||||
) => {
|
function _declineToolExecution(
|
||||||
setPendingHistoryItem((item) =>
|
declineMessage: string,
|
||||||
item?.type === 'tool_group'
|
status: ToolCallStatus,
|
||||||
? {
|
request: ServerToolCallConfirmationDetails['request'],
|
||||||
...item,
|
originalDetails: ServerToolCallConfirmationDetails['details'],
|
||||||
tools: item.tools.map((tool) =>
|
) {
|
||||||
tool.callId === callId
|
let resultDisplay: ToolResultDisplay | undefined;
|
||||||
? {
|
if ('fileDiff' in originalDetails) {
|
||||||
...tool,
|
resultDisplay = {
|
||||||
status: ToolCallStatus.Confirming,
|
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
|
||||||
confirmationDetails,
|
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
|
||||||
}
|
|
||||||
: tool,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
: item,
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const wireConfirmationSubmission = (
|
|
||||||
confirmationDetails: ServerToolCallConfirmationDetails,
|
|
||||||
): ToolCallConfirmationDetails => {
|
|
||||||
const originalConfirmationDetails = confirmationDetails.details;
|
|
||||||
const request = confirmationDetails.request;
|
|
||||||
const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => {
|
|
||||||
originalConfirmationDetails.onConfirm(outcome);
|
|
||||||
if (pendingHistoryItemRef?.current?.type === 'tool_group') {
|
|
||||||
setPendingHistoryItem((item) =>
|
|
||||||
item?.type === 'tool_group'
|
|
||||||
? {
|
|
||||||
...item,
|
|
||||||
tools: item.tools.map((tool) =>
|
|
||||||
tool.callId === request.callId
|
|
||||||
? {
|
|
||||||
...tool,
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
status: ToolCallStatus.Executing,
|
|
||||||
}
|
|
||||||
: tool,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
: item,
|
|
||||||
);
|
|
||||||
refreshStatic();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (outcome === ToolConfirmationOutcome.Cancel) {
|
|
||||||
declineToolExecution(
|
|
||||||
'User rejected function call.',
|
|
||||||
ToolCallStatus.Error,
|
|
||||||
request,
|
|
||||||
originalConfirmationDetails,
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
const tool = toolRegistry.getTool(request.name);
|
|
||||||
if (!tool) {
|
|
||||||
throw new Error(
|
|
||||||
`Tool "${request.name}" not found or is not registered.`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
abortControllerRef.current = new AbortController();
|
|
||||||
const result = await tool.execute(
|
|
||||||
request.args,
|
|
||||||
abortControllerRef.current.signal,
|
|
||||||
);
|
|
||||||
if (abortControllerRef.current.signal.aborted) {
|
|
||||||
declineToolExecution(
|
|
||||||
partListUnionToString(result.llmContent),
|
|
||||||
ToolCallStatus.Canceled,
|
|
||||||
request,
|
|
||||||
originalConfirmationDetails,
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const functionResponse: Part = {
|
|
||||||
functionResponse: {
|
|
||||||
name: request.name,
|
|
||||||
id: request.callId,
|
|
||||||
response: { output: result.llmContent },
|
|
||||||
},
|
|
||||||
};
|
|
||||||
const responseInfo: ToolCallResponseInfo = {
|
|
||||||
callId: request.callId,
|
|
||||||
responsePart: functionResponse,
|
|
||||||
resultDisplay: result.returnDisplay,
|
|
||||||
error: undefined,
|
|
||||||
};
|
|
||||||
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
|
|
||||||
if (pendingHistoryItemRef.current) {
|
|
||||||
addItem(pendingHistoryItemRef.current, Date.now());
|
|
||||||
setPendingHistoryItem(null);
|
|
||||||
}
|
|
||||||
setIsResponding(false);
|
|
||||||
await submitQuery(functionResponse); // Recursive call
|
|
||||||
} finally {
|
|
||||||
if (streamingState !== StreamingState.WaitingForConfirmation) {
|
|
||||||
abortControllerRef.current = null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
|
|
||||||
// or could be a standalone helper if more params are passed.
|
|
||||||
function declineToolExecution(
|
|
||||||
declineMessage: string,
|
|
||||||
status: ToolCallStatus,
|
|
||||||
request: ServerToolCallConfirmationDetails['request'],
|
|
||||||
originalDetails: ServerToolCallConfirmationDetails['details'],
|
|
||||||
) {
|
|
||||||
let resultDisplay: ToolResultDisplay | undefined;
|
|
||||||
if ('fileDiff' in originalDetails) {
|
|
||||||
resultDisplay = {
|
|
||||||
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
|
|
||||||
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
|
|
||||||
}
|
|
||||||
const functionResponse: Part = {
|
|
||||||
functionResponse: {
|
|
||||||
id: request.callId,
|
|
||||||
name: request.name,
|
|
||||||
response: { error: declineMessage },
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
const responseInfo: ToolCallResponseInfo = {
|
} else {
|
||||||
callId: request.callId,
|
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
|
||||||
responsePart: functionResponse,
|
|
||||||
resultDisplay,
|
|
||||||
error: new Error(declineMessage),
|
|
||||||
};
|
|
||||||
const history = chatSessionRef.current?.getHistory();
|
|
||||||
if (history) {
|
|
||||||
history.push({ role: 'model', parts: [functionResponse] });
|
|
||||||
}
|
|
||||||
updateFunctionResponseUI(responseInfo, status);
|
|
||||||
if (pendingHistoryItemRef.current) {
|
|
||||||
addItem(pendingHistoryItemRef.current, Date.now());
|
|
||||||
setPendingHistoryItem(null);
|
|
||||||
}
|
|
||||||
setIsResponding(false);
|
|
||||||
}
|
}
|
||||||
|
const functionResponse: Part = {
|
||||||
return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm };
|
functionResponse: {
|
||||||
};
|
id: request.callId,
|
||||||
|
name: request.name,
|
||||||
|
response: { error: declineMessage },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
const responseInfo: ToolCallResponseInfo = {
|
||||||
|
callId: request.callId,
|
||||||
|
responsePart: functionResponse,
|
||||||
|
resultDisplay,
|
||||||
|
error: new Error(declineMessage),
|
||||||
|
};
|
||||||
|
const history = chatSessionRef.current?.getHistory();
|
||||||
|
if (history) {
|
||||||
|
history.push({ role: 'model', parts: [functionResponse] });
|
||||||
|
}
|
||||||
|
updateFunctionResponseUI(responseInfo, status);
|
||||||
|
if (pendingHistoryItemRef.current) {
|
||||||
|
addItem(pendingHistoryItemRef.current, Date.now());
|
||||||
|
setPendingHistoryItem(null);
|
||||||
|
}
|
||||||
|
setIsResponding(false);
|
||||||
|
}
|
||||||
|
|
||||||
// --- Stream Event Handlers ---
|
// --- Stream Event Handlers ---
|
||||||
const handleContentEvent = (
|
const handleContentEvent = (
|
||||||
|
@ -419,62 +324,6 @@ export const useGeminiStream = (
|
||||||
return newGeminiMessageBuffer;
|
return newGeminiMessageBuffer;
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleToolCallRequestEvent = (
|
|
||||||
eventValue: ToolCallRequestEvent['value'],
|
|
||||||
userMessageTimestamp: number,
|
|
||||||
) => {
|
|
||||||
const { callId, name, args } = eventValue;
|
|
||||||
const cliTool = toolRegistry.getTool(name);
|
|
||||||
if (!cliTool) {
|
|
||||||
console.error(`CLI Tool "${name}" not found!`);
|
|
||||||
return; // Skip this event if tool is not found
|
|
||||||
}
|
|
||||||
if (pendingHistoryItemRef.current?.type !== 'tool_group') {
|
|
||||||
if (pendingHistoryItemRef.current) {
|
|
||||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
|
||||||
}
|
|
||||||
setPendingHistoryItem({ type: 'tool_group', tools: [] });
|
|
||||||
}
|
|
||||||
let description: string;
|
|
||||||
try {
|
|
||||||
description = cliTool.getDescription(args);
|
|
||||||
} catch (e) {
|
|
||||||
description = `Error: Unable to get description: ${getErrorMessage(e)}`;
|
|
||||||
}
|
|
||||||
const toolCallDisplay: IndividualToolCallDisplay = {
|
|
||||||
callId,
|
|
||||||
name: cliTool.displayName,
|
|
||||||
description,
|
|
||||||
status: ToolCallStatus.Pending,
|
|
||||||
resultDisplay: undefined,
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
};
|
|
||||||
setPendingHistoryItem((pending) =>
|
|
||||||
pending?.type === 'tool_group'
|
|
||||||
? { ...pending, tools: [...pending.tools, toolCallDisplay] }
|
|
||||||
: null,
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleToolCallResponseEvent = (
|
|
||||||
eventValue: ToolCallResponseEvent['value'],
|
|
||||||
) => {
|
|
||||||
const status = eventValue.error
|
|
||||||
? ToolCallStatus.Error
|
|
||||||
: ToolCallStatus.Success;
|
|
||||||
updateFunctionResponseUI(eventValue, status);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleToolCallConfirmationEvent = (
|
|
||||||
eventValue: ToolCallConfirmationEvent['value'],
|
|
||||||
) => {
|
|
||||||
const confirmationDetails = wireConfirmationSubmission(eventValue);
|
|
||||||
updateConfirmingFunctionStatusUI(
|
|
||||||
eventValue.request.callId,
|
|
||||||
confirmationDetails,
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleUserCancelledEvent = (userMessageTimestamp: number) => {
|
const handleUserCancelledEvent = (userMessageTimestamp: number) => {
|
||||||
if (pendingHistoryItemRef.current) {
|
if (pendingHistoryItemRef.current) {
|
||||||
if (pendingHistoryItemRef.current.type === 'tool_group') {
|
if (pendingHistoryItemRef.current.type === 'tool_group') {
|
||||||
|
@ -500,6 +349,7 @@ export const useGeminiStream = (
|
||||||
userMessageTimestamp,
|
userMessageTimestamp,
|
||||||
);
|
);
|
||||||
setIsResponding(false);
|
setIsResponding(false);
|
||||||
|
cancel();
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleErrorEvent = (
|
const handleErrorEvent = (
|
||||||
|
@ -521,7 +371,7 @@ export const useGeminiStream = (
|
||||||
userMessageTimestamp: number,
|
userMessageTimestamp: number,
|
||||||
): Promise<StreamProcessingStatus> => {
|
): Promise<StreamProcessingStatus> => {
|
||||||
let geminiMessageBuffer = '';
|
let geminiMessageBuffer = '';
|
||||||
|
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||||
for await (const event of stream) {
|
for await (const event of stream) {
|
||||||
if (event.type === ServerGeminiEventType.Content) {
|
if (event.type === ServerGeminiEventType.Content) {
|
||||||
geminiMessageBuffer = handleContentEvent(
|
geminiMessageBuffer = handleContentEvent(
|
||||||
|
@ -530,12 +380,7 @@ export const useGeminiStream = (
|
||||||
userMessageTimestamp,
|
userMessageTimestamp,
|
||||||
);
|
);
|
||||||
} else if (event.type === ServerGeminiEventType.ToolCallRequest) {
|
} else if (event.type === ServerGeminiEventType.ToolCallRequest) {
|
||||||
handleToolCallRequestEvent(event.value, userMessageTimestamp);
|
toolCallRequests.push(event.value);
|
||||||
} else if (event.type === ServerGeminiEventType.ToolCallResponse) {
|
|
||||||
handleToolCallResponseEvent(event.value);
|
|
||||||
} else if (event.type === ServerGeminiEventType.ToolCallConfirmation) {
|
|
||||||
handleToolCallConfirmationEvent(event.value);
|
|
||||||
return StreamProcessingStatus.PausedForConfirmation;
|
|
||||||
} else if (event.type === ServerGeminiEventType.UserCancelled) {
|
} else if (event.type === ServerGeminiEventType.UserCancelled) {
|
||||||
handleUserCancelledEvent(userMessageTimestamp);
|
handleUserCancelledEvent(userMessageTimestamp);
|
||||||
return StreamProcessingStatus.UserCancelled;
|
return StreamProcessingStatus.UserCancelled;
|
||||||
|
@ -544,9 +389,18 @@ export const useGeminiStream = (
|
||||||
return StreamProcessingStatus.Error;
|
return StreamProcessingStatus.Error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
schedule(toolCallRequests);
|
||||||
return StreamProcessingStatus.Completed;
|
return StreamProcessingStatus.Completed;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const streamingState: StreamingState = isResponding
|
||||||
|
? StreamingState.Responding
|
||||||
|
: pendingToolCalls?.tools.some(
|
||||||
|
(t) => t.status === ToolCallStatus.Confirming,
|
||||||
|
)
|
||||||
|
? StreamingState.WaitingForConfirmation
|
||||||
|
: StreamingState.Idle;
|
||||||
|
|
||||||
const submitQuery = useCallback(
|
const submitQuery = useCallback(
|
||||||
async (query: PartListUnion) => {
|
async (query: PartListUnion) => {
|
||||||
if (isResponding) return;
|
if (isResponding) return;
|
||||||
|
@ -625,20 +479,15 @@ export const useGeminiStream = (
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
const streamingState: StreamingState = isResponding
|
const pendingHistoryItems = [
|
||||||
? StreamingState.Responding
|
pendingHistoryItemRef.current,
|
||||||
: pendingConfirmations(pendingHistoryItemRef.current)
|
pendingToolCalls,
|
||||||
? StreamingState.WaitingForConfirmation
|
].filter((i) => i !== undefined && i !== null);
|
||||||
: StreamingState.Idle;
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
streamingState,
|
streamingState,
|
||||||
submitQuery,
|
submitQuery,
|
||||||
initError,
|
initError,
|
||||||
pendingHistoryItem: pendingHistoryItemRef.current,
|
pendingHistoryItems,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean =>
|
|
||||||
item?.type === 'tool_group' &&
|
|
||||||
item.tools.some((t) => t.status === ToolCallStatus.Confirming);
|
|
||||||
|
|
|
@ -155,10 +155,9 @@ export class GeminiClient {
|
||||||
signal?: AbortSignal,
|
signal?: AbortSignal,
|
||||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||||
let turns = 0;
|
let turns = 0;
|
||||||
const availableTools = this.config.getToolRegistry().getAllTools();
|
|
||||||
while (turns < this.MAX_TURNS) {
|
while (turns < this.MAX_TURNS) {
|
||||||
turns++;
|
turns++;
|
||||||
const turn = new Turn(chat, availableTools);
|
const turn = new Turn(chat);
|
||||||
const resultStream = turn.run(request, signal);
|
const resultStream = turn.run(request, signal);
|
||||||
let seenError = false;
|
let seenError = false;
|
||||||
for await (const event of resultStream) {
|
for await (const event of resultStream) {
|
||||||
|
|
|
@ -21,18 +21,6 @@ import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
||||||
import { reportError } from '../utils/errorReporting.js';
|
import { reportError } from '../utils/errorReporting.js';
|
||||||
import { getErrorMessage } from '../utils/errors.js';
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
|
|
||||||
// --- Types for Server Logic ---
|
|
||||||
|
|
||||||
// Define a simpler structure for Tool execution results within the server
|
|
||||||
interface ServerToolExecutionOutcome {
|
|
||||||
callId: string;
|
|
||||||
name: string;
|
|
||||||
args: Record<string, unknown>;
|
|
||||||
result?: ToolResult;
|
|
||||||
error?: Error;
|
|
||||||
confirmationDetails: ToolCallConfirmationDetails | undefined;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Define a structure for tools passed to the server
|
// Define a structure for tools passed to the server
|
||||||
export interface ServerTool {
|
export interface ServerTool {
|
||||||
name: string;
|
name: string;
|
||||||
|
@ -118,7 +106,6 @@ export type ServerGeminiStreamEvent =
|
||||||
|
|
||||||
// A turn manages the agentic loop turn within the server context.
|
// A turn manages the agentic loop turn within the server context.
|
||||||
export class Turn {
|
export class Turn {
|
||||||
private readonly availableTools: Map<string, ServerTool>;
|
|
||||||
private pendingToolCalls: Array<{
|
private pendingToolCalls: Array<{
|
||||||
callId: string;
|
callId: string;
|
||||||
name: string;
|
name: string;
|
||||||
|
@ -128,11 +115,7 @@ export class Turn {
|
||||||
private confirmationDetails: ToolCallConfirmationDetails[];
|
private confirmationDetails: ToolCallConfirmationDetails[];
|
||||||
private debugResponses: GenerateContentResponse[];
|
private debugResponses: GenerateContentResponse[];
|
||||||
|
|
||||||
constructor(
|
constructor(private readonly chat: Chat) {
|
||||||
private readonly chat: Chat,
|
|
||||||
availableTools: ServerTool[],
|
|
||||||
) {
|
|
||||||
this.availableTools = new Map(availableTools.map((t) => [t.name, t]));
|
|
||||||
this.pendingToolCalls = [];
|
this.pendingToolCalls = [];
|
||||||
this.fnResponses = [];
|
this.fnResponses = [];
|
||||||
this.confirmationDetails = [];
|
this.confirmationDetails = [];
|
||||||
|
@ -160,12 +143,9 @@ export class Turn {
|
||||||
yield { type: GeminiEventType.Content, value: text };
|
yield { type: GeminiEventType.Content, value: text };
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!resp.functionCalls) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle function calls (requesting tool execution)
|
// Handle function calls (requesting tool execution)
|
||||||
for (const fnCall of resp.functionCalls) {
|
const functionCalls = resp.functionCalls ?? [];
|
||||||
|
for (const fnCall of functionCalls) {
|
||||||
const event = this.handlePendingFunctionCall(fnCall);
|
const event = this.handlePendingFunctionCall(fnCall);
|
||||||
if (event) {
|
if (event) {
|
||||||
yield event;
|
yield event;
|
||||||
|
@ -184,80 +164,6 @@ export class Turn {
|
||||||
yield { type: GeminiEventType.Error, value: { message: errorMessage } };
|
yield { type: GeminiEventType.Error, value: { message: errorMessage } };
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute pending tool calls
|
|
||||||
const toolPromises = this.pendingToolCalls.map(
|
|
||||||
async (pendingToolCall): Promise<ServerToolExecutionOutcome> => {
|
|
||||||
const tool = this.availableTools.get(pendingToolCall.name);
|
|
||||||
if (!tool) {
|
|
||||||
return {
|
|
||||||
...pendingToolCall,
|
|
||||||
error: new Error(
|
|
||||||
`Tool "${pendingToolCall.name}" not found or not provided to Turn.`,
|
|
||||||
),
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
const confirmationDetails = await tool.shouldConfirmExecute(
|
|
||||||
pendingToolCall.args,
|
|
||||||
);
|
|
||||||
if (confirmationDetails) {
|
|
||||||
return { ...pendingToolCall, confirmationDetails };
|
|
||||||
}
|
|
||||||
const result = await tool.execute(pendingToolCall.args, signal);
|
|
||||||
return {
|
|
||||||
...pendingToolCall,
|
|
||||||
result,
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
};
|
|
||||||
} catch (execError: unknown) {
|
|
||||||
return {
|
|
||||||
...pendingToolCall,
|
|
||||||
error: new Error(
|
|
||||||
`Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`,
|
|
||||||
),
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
},
|
|
||||||
);
|
|
||||||
const outcomes = await Promise.all(toolPromises);
|
|
||||||
|
|
||||||
// Process outcomes and prepare function responses
|
|
||||||
this.pendingToolCalls = []; // Clear pending calls for this turn
|
|
||||||
|
|
||||||
for (const outcome of outcomes) {
|
|
||||||
if (outcome.confirmationDetails) {
|
|
||||||
this.confirmationDetails.push(outcome.confirmationDetails);
|
|
||||||
const serverConfirmationetails: ServerToolCallConfirmationDetails = {
|
|
||||||
request: {
|
|
||||||
callId: outcome.callId,
|
|
||||||
name: outcome.name,
|
|
||||||
args: outcome.args,
|
|
||||||
},
|
|
||||||
details: outcome.confirmationDetails,
|
|
||||||
};
|
|
||||||
yield {
|
|
||||||
type: GeminiEventType.ToolCallConfirmation,
|
|
||||||
value: serverConfirmationetails,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
const responsePart = this.buildFunctionResponse(outcome);
|
|
||||||
this.fnResponses.push(responsePart);
|
|
||||||
const responseInfo: ToolCallResponseInfo = {
|
|
||||||
callId: outcome.callId,
|
|
||||||
responsePart,
|
|
||||||
resultDisplay: outcome.result?.returnDisplay,
|
|
||||||
error: outcome.error,
|
|
||||||
};
|
|
||||||
|
|
||||||
// If aborted we're already yielding the user cancellations elsewhere.
|
|
||||||
if (!signal?.aborted) {
|
|
||||||
yield { type: GeminiEventType.ToolCallResponse, value: responseInfo };
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private handlePendingFunctionCall(
|
private handlePendingFunctionCall(
|
||||||
|
@ -276,30 +182,6 @@ export class Turn {
|
||||||
return { type: GeminiEventType.ToolCallRequest, value };
|
return { type: GeminiEventType.ToolCallRequest, value };
|
||||||
}
|
}
|
||||||
|
|
||||||
// Builds the Part array expected by the Google GenAI API
|
|
||||||
private buildFunctionResponse(outcome: ServerToolExecutionOutcome): Part {
|
|
||||||
const { name, result, error } = outcome;
|
|
||||||
if (error) {
|
|
||||||
// Format error for the LLM
|
|
||||||
const errorMessage = error?.message || String(error);
|
|
||||||
console.error(`[Server Turn] Error executing tool ${name}:`, error);
|
|
||||||
return {
|
|
||||||
functionResponse: {
|
|
||||||
name,
|
|
||||||
id: outcome.callId,
|
|
||||||
response: { error: `Tool execution failed: ${errorMessage}` },
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
functionResponse: {
|
|
||||||
name,
|
|
||||||
id: outcome.callId,
|
|
||||||
response: { output: result?.llmContent ?? '' },
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
getConfirmationDetails(): ToolCallConfirmationDetails[] {
|
getConfirmationDetails(): ToolCallConfirmationDetails[] {
|
||||||
return this.confirmationDetails;
|
return this.confirmationDetails;
|
||||||
}
|
}
|
||||||
|
|
|
@ -171,23 +171,28 @@ export interface FileDiff {
|
||||||
fileName: string;
|
fileName: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ToolCallConfirmationDetails {
|
export interface ToolCallConfirmationDetailsDefault {
|
||||||
title: string;
|
title: string;
|
||||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ToolEditConfirmationDetails
|
export interface ToolEditConfirmationDetails
|
||||||
extends ToolCallConfirmationDetails {
|
extends ToolCallConfirmationDetailsDefault {
|
||||||
fileName: string;
|
fileName: string;
|
||||||
fileDiff: string;
|
fileDiff: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ToolExecuteConfirmationDetails
|
export interface ToolExecuteConfirmationDetails
|
||||||
extends ToolCallConfirmationDetails {
|
extends ToolCallConfirmationDetailsDefault {
|
||||||
command: string;
|
command: string;
|
||||||
rootCommand: string;
|
rootCommand: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type ToolCallConfirmationDetails =
|
||||||
|
| ToolCallConfirmationDetailsDefault
|
||||||
|
| ToolEditConfirmationDetails
|
||||||
|
| ToolExecuteConfirmationDetails;
|
||||||
|
|
||||||
export enum ToolConfirmationOutcome {
|
export enum ToolConfirmationOutcome {
|
||||||
ProceedOnce,
|
ProceedOnce,
|
||||||
ProceedAlways,
|
ProceedAlways,
|
||||||
|
|
Loading…
Reference in New Issue