refactor: Centralize tool scheduling logic and simplify React hook (#670)
This commit is contained in:
parent
edc12e416d
commit
f2a8d39f42
|
@ -200,7 +200,6 @@ export const App = ({
|
||||||
const { streamingState, submitQuery, initError, pendingHistoryItems } =
|
const { streamingState, submitQuery, initError, pendingHistoryItems } =
|
||||||
useGeminiStream(
|
useGeminiStream(
|
||||||
addItem,
|
addItem,
|
||||||
refreshStatic,
|
|
||||||
setShowHelp,
|
setShowHelp,
|
||||||
config,
|
config,
|
||||||
setDebugMessage,
|
setDebugMessage,
|
||||||
|
|
|
@ -9,11 +9,11 @@ import { mergePartListUnions } from './useGeminiStream.js';
|
||||||
import { Part, PartListUnion } from '@google/genai';
|
import { Part, PartListUnion } from '@google/genai';
|
||||||
|
|
||||||
// Mock useToolScheduler
|
// Mock useToolScheduler
|
||||||
vi.mock('./useToolScheduler', async () => {
|
vi.mock('./useReactToolScheduler', async () => {
|
||||||
const actual = await vi.importActual('./useToolScheduler');
|
const actual = await vi.importActual('./useReactToolScheduler');
|
||||||
return {
|
return {
|
||||||
...actual, // We need mapToDisplay from actual
|
...actual, // We need mapToDisplay from actual
|
||||||
useToolScheduler: vi.fn(),
|
useReactToolScheduler: vi.fn(),
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|
|
@ -16,20 +16,15 @@ import {
|
||||||
isNodeError,
|
isNodeError,
|
||||||
Config,
|
Config,
|
||||||
MessageSenderType,
|
MessageSenderType,
|
||||||
ServerToolCallConfirmationDetails,
|
|
||||||
ToolCallResponseInfo,
|
|
||||||
ToolEditConfirmationDetails,
|
|
||||||
ToolExecuteConfirmationDetails,
|
|
||||||
ToolResultDisplay,
|
|
||||||
ToolCallRequestInfo,
|
ToolCallRequestInfo,
|
||||||
} from '@gemini-code/core';
|
} from '@gemini-code/core';
|
||||||
import { type PartListUnion, type Part } from '@google/genai';
|
import { type PartListUnion } from '@google/genai';
|
||||||
import {
|
import {
|
||||||
StreamingState,
|
StreamingState,
|
||||||
ToolCallStatus,
|
|
||||||
HistoryItemWithoutId,
|
HistoryItemWithoutId,
|
||||||
HistoryItemToolGroup,
|
HistoryItemToolGroup,
|
||||||
MessageType,
|
MessageType,
|
||||||
|
ToolCallStatus,
|
||||||
} from '../types.js';
|
} from '../types.js';
|
||||||
import { isAtCommand } from '../utils/commandUtils.js';
|
import { isAtCommand } from '../utils/commandUtils.js';
|
||||||
import { useShellCommandProcessor } from './shellCommandProcessor.js';
|
import { useShellCommandProcessor } from './shellCommandProcessor.js';
|
||||||
|
@ -38,7 +33,13 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
|
||||||
import { useStateAndRef } from './useStateAndRef.js';
|
import { useStateAndRef } from './useStateAndRef.js';
|
||||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||||
import { useLogger } from './useLogger.js';
|
import { useLogger } from './useLogger.js';
|
||||||
import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
|
import {
|
||||||
|
useReactToolScheduler,
|
||||||
|
mapToDisplay as mapTrackedToolCallsToDisplay,
|
||||||
|
TrackedToolCall,
|
||||||
|
TrackedCompletedToolCall,
|
||||||
|
TrackedCancelledToolCall,
|
||||||
|
} from './useReactToolScheduler.js';
|
||||||
import { GeminiChat } from '@gemini-code/core/src/core/geminiChat.js';
|
import { GeminiChat } from '@gemini-code/core/src/core/geminiChat.js';
|
||||||
|
|
||||||
export function mergePartListUnions(list: PartListUnion[]): PartListUnion {
|
export function mergePartListUnions(list: PartListUnion[]): PartListUnion {
|
||||||
|
@ -60,12 +61,11 @@ enum StreamProcessingStatus {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Hook to manage the Gemini stream, handle user input, process commands,
|
* Manages the Gemini stream, including user input, command processing,
|
||||||
* and interact with the Gemini API and history manager.
|
* API interaction, and tool call lifecycle.
|
||||||
*/
|
*/
|
||||||
export const useGeminiStream = (
|
export const useGeminiStream = (
|
||||||
addItem: UseHistoryManagerReturn['addItem'],
|
addItem: UseHistoryManagerReturn['addItem'],
|
||||||
refreshStatic: () => void,
|
|
||||||
setShowHelp: React.Dispatch<React.SetStateAction<boolean>>,
|
setShowHelp: React.Dispatch<React.SetStateAction<boolean>>,
|
||||||
config: Config,
|
config: Config,
|
||||||
onDebugMessage: (message: string) => void,
|
onDebugMessage: (message: string) => void,
|
||||||
|
@ -82,27 +82,33 @@ export const useGeminiStream = (
|
||||||
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
||||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||||
const logger = useLogger();
|
const logger = useLogger();
|
||||||
const [toolCalls, schedule, cancel] = useToolScheduler(
|
|
||||||
(tools) => {
|
|
||||||
if (tools.length) {
|
|
||||||
addItem(mapToDisplay(tools), Date.now());
|
|
||||||
const toolResponses = tools
|
|
||||||
.filter(
|
|
||||||
(t) =>
|
|
||||||
t.status === 'error' ||
|
|
||||||
t.status === 'cancelled' ||
|
|
||||||
t.status === 'success',
|
|
||||||
)
|
|
||||||
.map((t) => t.response.responseParts);
|
|
||||||
|
|
||||||
submitQuery(mergePartListUnions(toolResponses));
|
const [
|
||||||
|
toolCalls,
|
||||||
|
scheduleToolCalls,
|
||||||
|
cancelAllToolCalls,
|
||||||
|
markToolsAsSubmitted,
|
||||||
|
] = useReactToolScheduler(
|
||||||
|
(completedToolCallsFromScheduler) => {
|
||||||
|
// This onComplete is called when ALL scheduled tools for a given batch are done.
|
||||||
|
if (completedToolCallsFromScheduler.length > 0) {
|
||||||
|
// Add the final state of these tools to the history for display.
|
||||||
|
// The new useEffect will handle submitting their responses.
|
||||||
|
addItem(
|
||||||
|
mapTrackedToolCallsToDisplay(
|
||||||
|
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||||
|
),
|
||||||
|
Date.now(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
config,
|
config,
|
||||||
setPendingHistoryItem,
|
setPendingHistoryItem,
|
||||||
);
|
);
|
||||||
const pendingToolCalls = useMemo(
|
|
||||||
() => (toolCalls.length ? mapToDisplay(toolCalls) : undefined),
|
const pendingToolCallGroupDisplay = useMemo(
|
||||||
|
() =>
|
||||||
|
toolCalls.length ? mapTrackedToolCallsToDisplay(toolCalls) : undefined,
|
||||||
[toolCalls],
|
[toolCalls],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -120,16 +126,16 @@ export const useGeminiStream = (
|
||||||
);
|
);
|
||||||
|
|
||||||
const streamingState = useMemo(() => {
|
const streamingState = useMemo(() => {
|
||||||
if (toolCalls.some((t) => t.status === 'awaiting_approval')) {
|
if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) {
|
||||||
return StreamingState.WaitingForConfirmation;
|
return StreamingState.WaitingForConfirmation;
|
||||||
}
|
}
|
||||||
if (
|
if (
|
||||||
isResponding ||
|
isResponding ||
|
||||||
toolCalls.some(
|
toolCalls.some(
|
||||||
(t) =>
|
(tc) =>
|
||||||
t.status === 'executing' ||
|
tc.status === 'executing' ||
|
||||||
t.status === 'scheduled' ||
|
tc.status === 'scheduled' ||
|
||||||
t.status === 'validating',
|
tc.status === 'validating',
|
||||||
)
|
)
|
||||||
) {
|
) {
|
||||||
return StreamingState.Responding;
|
return StreamingState.Responding;
|
||||||
|
@ -153,7 +159,7 @@ export const useGeminiStream = (
|
||||||
useInput((_input, key) => {
|
useInput((_input, key) => {
|
||||||
if (streamingState !== StreamingState.Idle && key.escape) {
|
if (streamingState !== StreamingState.Idle && key.escape) {
|
||||||
abortControllerRef.current?.abort();
|
abortControllerRef.current?.abort();
|
||||||
cancel();
|
cancelAllToolCalls(); // Also cancel any pending/executing tool calls
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -194,7 +200,7 @@ export const useGeminiStream = (
|
||||||
name: toolName,
|
name: toolName,
|
||||||
args: toolArgs,
|
args: toolArgs,
|
||||||
};
|
};
|
||||||
schedule([toolCallRequest]); // schedule expects an array or single object
|
scheduleToolCalls([toolCallRequest]);
|
||||||
}
|
}
|
||||||
return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool
|
return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool
|
||||||
}
|
}
|
||||||
|
@ -246,7 +252,7 @@ export const useGeminiStream = (
|
||||||
handleSlashCommand,
|
handleSlashCommand,
|
||||||
logger,
|
logger,
|
||||||
shellModeActive,
|
shellModeActive,
|
||||||
schedule,
|
scheduleToolCalls,
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -275,73 +281,6 @@ export const useGeminiStream = (
|
||||||
return { client: currentClient, chat: chatSessionRef.current };
|
return { client: currentClient, chat: chatSessionRef.current };
|
||||||
}, [addItem]);
|
}, [addItem]);
|
||||||
|
|
||||||
// --- UI Helper Functions (used by event handlers) ---
|
|
||||||
const updateFunctionResponseUI = (
|
|
||||||
toolResponse: ToolCallResponseInfo,
|
|
||||||
status: ToolCallStatus,
|
|
||||||
) => {
|
|
||||||
setPendingHistoryItem((item) =>
|
|
||||||
item?.type === 'tool_group'
|
|
||||||
? {
|
|
||||||
...item,
|
|
||||||
tools: item.tools.map((tool) =>
|
|
||||||
tool.callId === toolResponse.callId
|
|
||||||
? {
|
|
||||||
...tool,
|
|
||||||
status,
|
|
||||||
resultDisplay: toolResponse.resultDisplay,
|
|
||||||
}
|
|
||||||
: tool,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
: item,
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
|
|
||||||
// or could be a standalone helper if more params are passed.
|
|
||||||
// TODO: handle file diff result display stuff
|
|
||||||
function _declineToolExecution(
|
|
||||||
declineMessage: string,
|
|
||||||
status: ToolCallStatus,
|
|
||||||
request: ServerToolCallConfirmationDetails['request'],
|
|
||||||
originalDetails: ServerToolCallConfirmationDetails['details'],
|
|
||||||
) {
|
|
||||||
let resultDisplay: ToolResultDisplay | undefined;
|
|
||||||
if ('fileDiff' in originalDetails) {
|
|
||||||
resultDisplay = {
|
|
||||||
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
|
|
||||||
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
|
|
||||||
}
|
|
||||||
const functionResponse: Part = {
|
|
||||||
functionResponse: {
|
|
||||||
id: request.callId,
|
|
||||||
name: request.name,
|
|
||||||
response: { error: declineMessage },
|
|
||||||
},
|
|
||||||
};
|
|
||||||
const responseInfo: ToolCallResponseInfo = {
|
|
||||||
callId: request.callId,
|
|
||||||
responseParts: functionResponse,
|
|
||||||
resultDisplay,
|
|
||||||
error: new Error(declineMessage),
|
|
||||||
};
|
|
||||||
const history = chatSessionRef.current?.getHistory();
|
|
||||||
if (history) {
|
|
||||||
history.push({ role: 'model', parts: [functionResponse] });
|
|
||||||
}
|
|
||||||
updateFunctionResponseUI(responseInfo, status);
|
|
||||||
|
|
||||||
if (pendingHistoryItemRef.current) {
|
|
||||||
addItem(pendingHistoryItemRef.current, Date.now());
|
|
||||||
setPendingHistoryItem(null);
|
|
||||||
}
|
|
||||||
setIsResponding(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Stream Event Handlers ---
|
// --- Stream Event Handlers ---
|
||||||
|
|
||||||
const handleContentEvent = useCallback(
|
const handleContentEvent = useCallback(
|
||||||
|
@ -425,9 +364,9 @@ export const useGeminiStream = (
|
||||||
userMessageTimestamp,
|
userMessageTimestamp,
|
||||||
);
|
);
|
||||||
setIsResponding(false);
|
setIsResponding(false);
|
||||||
cancel();
|
cancelAllToolCalls();
|
||||||
},
|
},
|
||||||
[addItem, pendingHistoryItemRef, setPendingHistoryItem, cancel],
|
[addItem, pendingHistoryItemRef, setPendingHistoryItem, cancelAllToolCalls],
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleErrorEvent = useCallback(
|
const handleErrorEvent = useCallback(
|
||||||
|
@ -462,22 +401,22 @@ export const useGeminiStream = (
|
||||||
toolCallRequests.push(event.value);
|
toolCallRequests.push(event.value);
|
||||||
} else if (event.type === ServerGeminiEventType.UserCancelled) {
|
} else if (event.type === ServerGeminiEventType.UserCancelled) {
|
||||||
handleUserCancelledEvent(userMessageTimestamp);
|
handleUserCancelledEvent(userMessageTimestamp);
|
||||||
cancel();
|
|
||||||
return StreamProcessingStatus.UserCancelled;
|
return StreamProcessingStatus.UserCancelled;
|
||||||
} else if (event.type === ServerGeminiEventType.Error) {
|
} else if (event.type === ServerGeminiEventType.Error) {
|
||||||
handleErrorEvent(event.value, userMessageTimestamp);
|
handleErrorEvent(event.value, userMessageTimestamp);
|
||||||
return StreamProcessingStatus.Error;
|
return StreamProcessingStatus.Error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
schedule(toolCallRequests);
|
if (toolCallRequests.length > 0) {
|
||||||
|
scheduleToolCalls(toolCallRequests);
|
||||||
|
}
|
||||||
return StreamProcessingStatus.Completed;
|
return StreamProcessingStatus.Completed;
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
handleContentEvent,
|
handleContentEvent,
|
||||||
handleUserCancelledEvent,
|
handleUserCancelledEvent,
|
||||||
cancel,
|
|
||||||
handleErrorEvent,
|
handleErrorEvent,
|
||||||
schedule,
|
scheduleToolCalls,
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -545,21 +484,69 @@ export const useGeminiStream = (
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
setShowHelp,
|
|
||||||
addItem,
|
|
||||||
setInitError,
|
|
||||||
ensureChatSession,
|
|
||||||
prepareQueryForGemini,
|
|
||||||
processGeminiStreamEvents,
|
|
||||||
setPendingHistoryItem,
|
|
||||||
pendingHistoryItemRef,
|
|
||||||
streamingState,
|
streamingState,
|
||||||
|
setShowHelp,
|
||||||
|
prepareQueryForGemini,
|
||||||
|
ensureChatSession,
|
||||||
|
processGeminiStreamEvents,
|
||||||
|
pendingHistoryItemRef,
|
||||||
|
addItem,
|
||||||
|
setPendingHistoryItem,
|
||||||
|
setInitError,
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Automatically submits responses for completed tool calls.
|
||||||
|
* This effect runs when `toolCalls` or `isResponding` changes.
|
||||||
|
* It ensures that tool responses are sent back to Gemini only when
|
||||||
|
* all processing for a given set of tools is finished and Gemini
|
||||||
|
* is not already generating a response.
|
||||||
|
*/
|
||||||
|
useEffect(() => {
|
||||||
|
if (isResponding) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const completedAndReadyToSubmitTools = toolCalls.filter(
|
||||||
|
(
|
||||||
|
tc: TrackedToolCall,
|
||||||
|
): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => {
|
||||||
|
const isTerminalState =
|
||||||
|
tc.status === 'success' ||
|
||||||
|
tc.status === 'error' ||
|
||||||
|
tc.status === 'cancelled';
|
||||||
|
|
||||||
|
if (isTerminalState) {
|
||||||
|
const completedOrCancelledCall = tc as
|
||||||
|
| TrackedCompletedToolCall
|
||||||
|
| TrackedCancelledToolCall;
|
||||||
|
return (
|
||||||
|
!completedOrCancelledCall.responseSubmittedToGemini &&
|
||||||
|
completedOrCancelledCall.response?.responseParts !== undefined
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
if (completedAndReadyToSubmitTools.length > 0) {
|
||||||
|
const responsesToSend: PartListUnion[] =
|
||||||
|
completedAndReadyToSubmitTools.map(
|
||||||
|
(toolCall) => toolCall.response.responseParts,
|
||||||
|
);
|
||||||
|
const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map(
|
||||||
|
(toolCall) => toolCall.request.callId,
|
||||||
|
);
|
||||||
|
|
||||||
|
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
|
||||||
|
submitQuery(mergePartListUnions(responsesToSend));
|
||||||
|
}
|
||||||
|
}, [toolCalls, isResponding, submitQuery, markToolsAsSubmitted, addItem]);
|
||||||
|
|
||||||
const pendingHistoryItems = [
|
const pendingHistoryItems = [
|
||||||
pendingHistoryItemRef.current,
|
pendingHistoryItemRef.current,
|
||||||
pendingToolCalls,
|
pendingToolCallGroupDisplay,
|
||||||
].filter((i) => i !== undefined && i !== null);
|
].filter((i) => i !== undefined && i !== null);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -0,0 +1,301 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {
|
||||||
|
Config,
|
||||||
|
ToolCallRequestInfo,
|
||||||
|
ExecutingToolCall,
|
||||||
|
ScheduledToolCall,
|
||||||
|
ValidatingToolCall,
|
||||||
|
WaitingToolCall,
|
||||||
|
CompletedToolCall,
|
||||||
|
CancelledToolCall,
|
||||||
|
CoreToolScheduler,
|
||||||
|
OutputUpdateHandler,
|
||||||
|
AllToolCallsCompleteHandler,
|
||||||
|
ToolCallsUpdateHandler,
|
||||||
|
Tool,
|
||||||
|
ToolCall,
|
||||||
|
Status as CoreStatus,
|
||||||
|
} from '@gemini-code/core';
|
||||||
|
import { useCallback, useEffect, useState, useRef } from 'react';
|
||||||
|
import {
|
||||||
|
HistoryItemToolGroup,
|
||||||
|
IndividualToolCallDisplay,
|
||||||
|
ToolCallStatus,
|
||||||
|
HistoryItemWithoutId,
|
||||||
|
} from '../types.js';
|
||||||
|
|
||||||
|
export type ScheduleFn = (
|
||||||
|
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||||
|
) => void;
|
||||||
|
export type CancelFn = (reason?: string) => void;
|
||||||
|
export type MarkToolsAsSubmittedFn = (callIds: string[]) => void;
|
||||||
|
|
||||||
|
export type TrackedScheduledToolCall = ScheduledToolCall & {
|
||||||
|
responseSubmittedToGemini?: boolean;
|
||||||
|
};
|
||||||
|
export type TrackedValidatingToolCall = ValidatingToolCall & {
|
||||||
|
responseSubmittedToGemini?: boolean;
|
||||||
|
};
|
||||||
|
export type TrackedWaitingToolCall = WaitingToolCall & {
|
||||||
|
responseSubmittedToGemini?: boolean;
|
||||||
|
};
|
||||||
|
export type TrackedExecutingToolCall = ExecutingToolCall & {
|
||||||
|
responseSubmittedToGemini?: boolean;
|
||||||
|
};
|
||||||
|
export type TrackedCompletedToolCall = CompletedToolCall & {
|
||||||
|
responseSubmittedToGemini?: boolean;
|
||||||
|
};
|
||||||
|
export type TrackedCancelledToolCall = CancelledToolCall & {
|
||||||
|
responseSubmittedToGemini?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type TrackedToolCall =
|
||||||
|
| TrackedScheduledToolCall
|
||||||
|
| TrackedValidatingToolCall
|
||||||
|
| TrackedWaitingToolCall
|
||||||
|
| TrackedExecutingToolCall
|
||||||
|
| TrackedCompletedToolCall
|
||||||
|
| TrackedCancelledToolCall;
|
||||||
|
|
||||||
|
export function useReactToolScheduler(
|
||||||
|
onComplete: (tools: CompletedToolCall[]) => void,
|
||||||
|
config: Config,
|
||||||
|
setPendingHistoryItem: React.Dispatch<
|
||||||
|
React.SetStateAction<HistoryItemWithoutId | null>
|
||||||
|
>,
|
||||||
|
): [TrackedToolCall[], ScheduleFn, CancelFn, MarkToolsAsSubmittedFn] {
|
||||||
|
const [toolCallsForDisplay, setToolCallsForDisplay] = useState<
|
||||||
|
TrackedToolCall[]
|
||||||
|
>([]);
|
||||||
|
const schedulerRef = useRef<CoreToolScheduler | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const outputUpdateHandler: OutputUpdateHandler = (
|
||||||
|
toolCallId,
|
||||||
|
outputChunk,
|
||||||
|
) => {
|
||||||
|
setPendingHistoryItem((prevItem) => {
|
||||||
|
if (prevItem?.type === 'tool_group') {
|
||||||
|
return {
|
||||||
|
...prevItem,
|
||||||
|
tools: prevItem.tools.map((toolDisplay) =>
|
||||||
|
toolDisplay.callId === toolCallId &&
|
||||||
|
toolDisplay.status === ToolCallStatus.Executing
|
||||||
|
? { ...toolDisplay, resultDisplay: outputChunk }
|
||||||
|
: toolDisplay,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return prevItem;
|
||||||
|
});
|
||||||
|
|
||||||
|
setToolCallsForDisplay((prevCalls) =>
|
||||||
|
prevCalls.map((tc) => {
|
||||||
|
if (tc.request.callId === toolCallId && tc.status === 'executing') {
|
||||||
|
const executingTc = tc as TrackedExecutingToolCall;
|
||||||
|
return { ...executingTc, liveOutput: outputChunk };
|
||||||
|
}
|
||||||
|
return tc;
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const allToolCallsCompleteHandler: AllToolCallsCompleteHandler = (
|
||||||
|
completedToolCalls,
|
||||||
|
) => {
|
||||||
|
onComplete(completedToolCalls);
|
||||||
|
};
|
||||||
|
|
||||||
|
const toolCallsUpdateHandler: ToolCallsUpdateHandler = (
|
||||||
|
updatedCoreToolCalls: ToolCall[],
|
||||||
|
) => {
|
||||||
|
setToolCallsForDisplay((prevTrackedCalls) =>
|
||||||
|
updatedCoreToolCalls.map((coreTc) => {
|
||||||
|
const existingTrackedCall = prevTrackedCalls.find(
|
||||||
|
(ptc) => ptc.request.callId === coreTc.request.callId,
|
||||||
|
);
|
||||||
|
const newTrackedCall: TrackedToolCall = {
|
||||||
|
...coreTc,
|
||||||
|
responseSubmittedToGemini:
|
||||||
|
existingTrackedCall?.responseSubmittedToGemini ?? false,
|
||||||
|
} as TrackedToolCall;
|
||||||
|
return newTrackedCall;
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
schedulerRef.current = new CoreToolScheduler({
|
||||||
|
toolRegistry: config.getToolRegistry(),
|
||||||
|
outputUpdateHandler,
|
||||||
|
onAllToolCallsComplete: allToolCallsCompleteHandler,
|
||||||
|
onToolCallsUpdate: toolCallsUpdateHandler,
|
||||||
|
});
|
||||||
|
}, [config, onComplete, setPendingHistoryItem]);
|
||||||
|
|
||||||
|
const schedule: ScheduleFn = useCallback(
|
||||||
|
async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => {
|
||||||
|
schedulerRef.current?.schedule(request);
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
);
|
||||||
|
|
||||||
|
const cancel: CancelFn = useCallback((reason: string = 'unspecified') => {
|
||||||
|
schedulerRef.current?.cancelAll(reason);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const markToolsAsSubmitted: MarkToolsAsSubmittedFn = useCallback(
|
||||||
|
(callIdsToMark: string[]) => {
|
||||||
|
setToolCallsForDisplay((prevCalls) =>
|
||||||
|
prevCalls.map((tc) =>
|
||||||
|
callIdsToMark.includes(tc.request.callId)
|
||||||
|
? { ...tc, responseSubmittedToGemini: true }
|
||||||
|
: tc,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[],
|
||||||
|
);
|
||||||
|
|
||||||
|
return [toolCallsForDisplay, schedule, cancel, markToolsAsSubmitted];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Maps a CoreToolScheduler status to the UI's ToolCallStatus enum.
|
||||||
|
*/
|
||||||
|
function mapCoreStatusToDisplayStatus(coreStatus: CoreStatus): ToolCallStatus {
|
||||||
|
switch (coreStatus) {
|
||||||
|
case 'validating':
|
||||||
|
return ToolCallStatus.Executing;
|
||||||
|
case 'awaiting_approval':
|
||||||
|
return ToolCallStatus.Confirming;
|
||||||
|
case 'executing':
|
||||||
|
return ToolCallStatus.Executing;
|
||||||
|
case 'success':
|
||||||
|
return ToolCallStatus.Success;
|
||||||
|
case 'cancelled':
|
||||||
|
return ToolCallStatus.Canceled;
|
||||||
|
case 'error':
|
||||||
|
return ToolCallStatus.Error;
|
||||||
|
case 'scheduled':
|
||||||
|
return ToolCallStatus.Pending;
|
||||||
|
default: {
|
||||||
|
const exhaustiveCheck: never = coreStatus;
|
||||||
|
console.warn(`Unknown core status encountered: ${exhaustiveCheck}`);
|
||||||
|
return ToolCallStatus.Error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Transforms `TrackedToolCall` objects into `HistoryItemToolGroup` objects for UI display.
|
||||||
|
*/
|
||||||
|
export function mapToDisplay(
|
||||||
|
toolOrTools: TrackedToolCall[] | TrackedToolCall,
|
||||||
|
): HistoryItemToolGroup {
|
||||||
|
const toolCalls = Array.isArray(toolOrTools) ? toolOrTools : [toolOrTools];
|
||||||
|
|
||||||
|
const toolDisplays = toolCalls.map(
|
||||||
|
(trackedCall): IndividualToolCallDisplay => {
|
||||||
|
let displayName = trackedCall.request.name;
|
||||||
|
let description = '';
|
||||||
|
let renderOutputAsMarkdown = false;
|
||||||
|
|
||||||
|
const currentToolInstance =
|
||||||
|
'tool' in trackedCall && trackedCall.tool
|
||||||
|
? (trackedCall as { tool: Tool }).tool
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
if (currentToolInstance) {
|
||||||
|
displayName = currentToolInstance.displayName;
|
||||||
|
description = currentToolInstance.getDescription(
|
||||||
|
trackedCall.request.args,
|
||||||
|
);
|
||||||
|
renderOutputAsMarkdown = currentToolInstance.isOutputMarkdown;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (trackedCall.status === 'error') {
|
||||||
|
description = '';
|
||||||
|
}
|
||||||
|
|
||||||
|
const baseDisplayProperties: Omit<
|
||||||
|
IndividualToolCallDisplay,
|
||||||
|
'status' | 'resultDisplay' | 'confirmationDetails'
|
||||||
|
> = {
|
||||||
|
callId: trackedCall.request.callId,
|
||||||
|
name: displayName,
|
||||||
|
description,
|
||||||
|
renderOutputAsMarkdown,
|
||||||
|
};
|
||||||
|
|
||||||
|
switch (trackedCall.status) {
|
||||||
|
case 'success':
|
||||||
|
return {
|
||||||
|
...baseDisplayProperties,
|
||||||
|
status: mapCoreStatusToDisplayStatus(trackedCall.status),
|
||||||
|
resultDisplay: trackedCall.response.resultDisplay,
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
case 'error':
|
||||||
|
return {
|
||||||
|
...baseDisplayProperties,
|
||||||
|
name: currentToolInstance?.displayName ?? trackedCall.request.name,
|
||||||
|
status: mapCoreStatusToDisplayStatus(trackedCall.status),
|
||||||
|
resultDisplay: trackedCall.response.resultDisplay,
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
case 'cancelled':
|
||||||
|
return {
|
||||||
|
...baseDisplayProperties,
|
||||||
|
status: mapCoreStatusToDisplayStatus(trackedCall.status),
|
||||||
|
resultDisplay: trackedCall.response.resultDisplay,
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
case 'awaiting_approval':
|
||||||
|
return {
|
||||||
|
...baseDisplayProperties,
|
||||||
|
status: mapCoreStatusToDisplayStatus(trackedCall.status),
|
||||||
|
resultDisplay: undefined,
|
||||||
|
confirmationDetails: trackedCall.confirmationDetails,
|
||||||
|
};
|
||||||
|
case 'executing':
|
||||||
|
return {
|
||||||
|
...baseDisplayProperties,
|
||||||
|
status: mapCoreStatusToDisplayStatus(trackedCall.status),
|
||||||
|
resultDisplay:
|
||||||
|
(trackedCall as TrackedExecutingToolCall).liveOutput ?? undefined,
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
case 'validating': // Fallthrough
|
||||||
|
case 'scheduled':
|
||||||
|
return {
|
||||||
|
...baseDisplayProperties,
|
||||||
|
status: mapCoreStatusToDisplayStatus(trackedCall.status),
|
||||||
|
resultDisplay: undefined,
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
default: {
|
||||||
|
const exhaustiveCheck: never = trackedCall;
|
||||||
|
return {
|
||||||
|
callId: (exhaustiveCheck as TrackedToolCall).request.callId,
|
||||||
|
name: 'Unknown Tool',
|
||||||
|
description: 'Encountered an unknown tool call state.',
|
||||||
|
status: ToolCallStatus.Error,
|
||||||
|
resultDisplay: 'Unknown tool call state',
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
renderOutputAsMarkdown: false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
return {
|
||||||
|
type: 'tool_group',
|
||||||
|
tools: toolDisplays,
|
||||||
|
};
|
||||||
|
}
|
|
@ -8,12 +8,9 @@
|
||||||
import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest';
|
import { describe, it, expect, vi, beforeEach, afterEach, Mock } from 'vitest';
|
||||||
import { renderHook, act } from '@testing-library/react';
|
import { renderHook, act } from '@testing-library/react';
|
||||||
import {
|
import {
|
||||||
useToolScheduler,
|
useReactToolScheduler,
|
||||||
formatLlmContentForFunctionResponse,
|
|
||||||
mapToDisplay,
|
mapToDisplay,
|
||||||
ToolCall,
|
} from './useReactToolScheduler.js';
|
||||||
Status as ToolCallStatusType, // Renamed to avoid conflict
|
|
||||||
} from './useToolScheduler.js';
|
|
||||||
import {
|
import {
|
||||||
Part,
|
Part,
|
||||||
PartListUnion,
|
PartListUnion,
|
||||||
|
@ -29,6 +26,9 @@ import {
|
||||||
ToolCallConfirmationDetails,
|
ToolCallConfirmationDetails,
|
||||||
ToolConfirmationOutcome,
|
ToolConfirmationOutcome,
|
||||||
ToolCallResponseInfo,
|
ToolCallResponseInfo,
|
||||||
|
formatLlmContentForFunctionResponse, // Import from core
|
||||||
|
ToolCall, // Import from core
|
||||||
|
Status as ToolCallStatusType, // Import from core
|
||||||
} from '@gemini-code/core';
|
} from '@gemini-code/core';
|
||||||
import {
|
import {
|
||||||
HistoryItemWithoutId,
|
HistoryItemWithoutId,
|
||||||
|
@ -205,7 +205,7 @@ describe('formatLlmContentForFunctionResponse', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('useToolScheduler', () => {
|
describe('useReactToolScheduler', () => {
|
||||||
// TODO(ntaylormullen): The following tests are skipped due to difficulties in
|
// TODO(ntaylormullen): The following tests are skipped due to difficulties in
|
||||||
// reliably testing the asynchronous state updates and interactions with timers.
|
// reliably testing the asynchronous state updates and interactions with timers.
|
||||||
// These tests involve complex sequences of events, including confirmations,
|
// These tests involve complex sequences of events, including confirmations,
|
||||||
|
@ -276,7 +276,7 @@ describe('useToolScheduler', () => {
|
||||||
|
|
||||||
const renderScheduler = () =>
|
const renderScheduler = () =>
|
||||||
renderHook(() =>
|
renderHook(() =>
|
||||||
useToolScheduler(
|
useReactToolScheduler(
|
||||||
onComplete,
|
onComplete,
|
||||||
mockConfig as unknown as Config,
|
mockConfig as unknown as Config,
|
||||||
setPendingHistoryItem,
|
setPendingHistoryItem,
|
||||||
|
@ -367,7 +367,7 @@ describe('useToolScheduler', () => {
|
||||||
request,
|
request,
|
||||||
response: expect.objectContaining({
|
response: expect.objectContaining({
|
||||||
error: expect.objectContaining({
|
error: expect.objectContaining({
|
||||||
message: 'tool nonExistentTool does not exist',
|
message: 'Tool "nonExistentTool" not found in registry.',
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
@ -1050,7 +1050,7 @@ describe('mapToDisplay', () => {
|
||||||
},
|
},
|
||||||
expectedStatus: ToolCallStatus.Error,
|
expectedStatus: ToolCallStatus.Error,
|
||||||
expectedResultDisplay: 'Execution failed display',
|
expectedResultDisplay: 'Execution failed display',
|
||||||
expectedName: baseTool.name,
|
expectedName: baseTool.displayName, // Changed from baseTool.name
|
||||||
expectedDescription: '',
|
expectedDescription: '',
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -1,626 +0,0 @@
|
||||||
/**
|
|
||||||
* @license
|
|
||||||
* Copyright 2025 Google LLC
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
*/
|
|
||||||
|
|
||||||
import {
|
|
||||||
Config,
|
|
||||||
ToolCallRequestInfo,
|
|
||||||
ToolCallResponseInfo,
|
|
||||||
ToolConfirmationOutcome,
|
|
||||||
Tool,
|
|
||||||
ToolCallConfirmationDetails,
|
|
||||||
ToolResult,
|
|
||||||
} from '@gemini-code/core';
|
|
||||||
import { Part, PartUnion, PartListUnion } from '@google/genai';
|
|
||||||
import { useCallback, useEffect, useState } from 'react';
|
|
||||||
import {
|
|
||||||
HistoryItemToolGroup,
|
|
||||||
IndividualToolCallDisplay,
|
|
||||||
ToolCallStatus,
|
|
||||||
HistoryItemWithoutId,
|
|
||||||
} from '../types.js';
|
|
||||||
|
|
||||||
type ValidatingToolCall = {
|
|
||||||
status: 'validating';
|
|
||||||
request: ToolCallRequestInfo;
|
|
||||||
tool: Tool;
|
|
||||||
};
|
|
||||||
|
|
||||||
type ScheduledToolCall = {
|
|
||||||
status: 'scheduled';
|
|
||||||
request: ToolCallRequestInfo;
|
|
||||||
tool: Tool;
|
|
||||||
};
|
|
||||||
|
|
||||||
type ErroredToolCall = {
|
|
||||||
status: 'error';
|
|
||||||
request: ToolCallRequestInfo;
|
|
||||||
response: ToolCallResponseInfo;
|
|
||||||
};
|
|
||||||
|
|
||||||
type SuccessfulToolCall = {
|
|
||||||
status: 'success';
|
|
||||||
request: ToolCallRequestInfo;
|
|
||||||
tool: Tool;
|
|
||||||
response: ToolCallResponseInfo;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type ExecutingToolCall = {
|
|
||||||
status: 'executing';
|
|
||||||
request: ToolCallRequestInfo;
|
|
||||||
tool: Tool;
|
|
||||||
liveOutput?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
type CancelledToolCall = {
|
|
||||||
status: 'cancelled';
|
|
||||||
request: ToolCallRequestInfo;
|
|
||||||
response: ToolCallResponseInfo;
|
|
||||||
tool: Tool;
|
|
||||||
};
|
|
||||||
|
|
||||||
type WaitingToolCall = {
|
|
||||||
status: 'awaiting_approval';
|
|
||||||
request: ToolCallRequestInfo;
|
|
||||||
tool: Tool;
|
|
||||||
confirmationDetails: ToolCallConfirmationDetails;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type Status = ToolCall['status'];
|
|
||||||
|
|
||||||
export type ToolCall =
|
|
||||||
| ValidatingToolCall
|
|
||||||
| ScheduledToolCall
|
|
||||||
| ErroredToolCall
|
|
||||||
| SuccessfulToolCall
|
|
||||||
| ExecutingToolCall
|
|
||||||
| CancelledToolCall
|
|
||||||
| WaitingToolCall;
|
|
||||||
|
|
||||||
export type ScheduleFn = (
|
|
||||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
|
||||||
) => void;
|
|
||||||
export type CancelFn = () => void;
|
|
||||||
export type CompletedToolCall =
|
|
||||||
| SuccessfulToolCall
|
|
||||||
| CancelledToolCall
|
|
||||||
| ErroredToolCall;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Formats a PartListUnion response from a tool into JSON suitable for a Gemini
|
|
||||||
* FunctionResponse and additional Parts to include after that response.
|
|
||||||
*
|
|
||||||
* This is required because FunctionReponse appears to only support JSON
|
|
||||||
* and not arbitrary parts. Including parts like inlineData or fileData
|
|
||||||
* directly in a FunctionResponse confuses the model resulting in a failure
|
|
||||||
* to interpret the multimodal content and context window exceeded errors.
|
|
||||||
*/
|
|
||||||
|
|
||||||
export function formatLlmContentForFunctionResponse(
|
|
||||||
llmContent: PartListUnion,
|
|
||||||
): {
|
|
||||||
functionResponseJson: Record<string, string>;
|
|
||||||
additionalParts: PartUnion[];
|
|
||||||
} {
|
|
||||||
const additionalParts: PartUnion[] = [];
|
|
||||||
let functionResponseJson: Record<string, string>;
|
|
||||||
|
|
||||||
if (Array.isArray(llmContent) && llmContent.length === 1) {
|
|
||||||
// Ensure that length 1 arrays are treated as a single Part.
|
|
||||||
llmContent = llmContent[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (typeof llmContent === 'string') {
|
|
||||||
functionResponseJson = { output: llmContent };
|
|
||||||
} else if (Array.isArray(llmContent)) {
|
|
||||||
functionResponseJson = { status: 'Tool execution succeeded.' };
|
|
||||||
additionalParts.push(...llmContent);
|
|
||||||
} else {
|
|
||||||
if (
|
|
||||||
llmContent.inlineData !== undefined ||
|
|
||||||
llmContent.fileData !== undefined
|
|
||||||
) {
|
|
||||||
// For Parts like inlineData or fileData, use the returnDisplay as the textual output for the functionResponse.
|
|
||||||
// The actual Part will be added to additionalParts.
|
|
||||||
functionResponseJson = {
|
|
||||||
status: `Binary content of type ${llmContent.inlineData?.mimeType || llmContent.fileData?.mimeType || 'unknown'} was processed.`,
|
|
||||||
};
|
|
||||||
additionalParts.push(llmContent);
|
|
||||||
} else if (llmContent.text !== undefined) {
|
|
||||||
functionResponseJson = { output: llmContent.text };
|
|
||||||
} else {
|
|
||||||
functionResponseJson = { status: 'Tool execution succeeded.' };
|
|
||||||
additionalParts.push(llmContent);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
functionResponseJson,
|
|
||||||
additionalParts,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
export function useToolScheduler(
|
|
||||||
onComplete: (tools: CompletedToolCall[]) => void,
|
|
||||||
config: Config,
|
|
||||||
setPendingHistoryItem: React.Dispatch<
|
|
||||||
React.SetStateAction<HistoryItemWithoutId | null>
|
|
||||||
>,
|
|
||||||
): [ToolCall[], ScheduleFn, CancelFn] {
|
|
||||||
const [toolRegistry] = useState(() => config.getToolRegistry());
|
|
||||||
const [toolCalls, setToolCalls] = useState<ToolCall[]>([]);
|
|
||||||
const [abortController, setAbortController] = useState<AbortController>(
|
|
||||||
() => new AbortController(),
|
|
||||||
);
|
|
||||||
|
|
||||||
const isRunning = toolCalls.some(
|
|
||||||
(t) => t.status === 'executing' || t.status === 'awaiting_approval',
|
|
||||||
);
|
|
||||||
// Note: request array[] typically signal pending tool calls
|
|
||||||
const schedule = useCallback(
|
|
||||||
async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => {
|
|
||||||
if (isRunning) {
|
|
||||||
throw new Error(
|
|
||||||
'Cannot schedule tool calls while other tool calls are running',
|
|
||||||
);
|
|
||||||
}
|
|
||||||
const requestsToProcess = Array.isArray(request) ? request : [request];
|
|
||||||
|
|
||||||
// Step 1: Create initial calls with 'validating' status (or 'error' if tool not found)
|
|
||||||
// and add them to the state immediately to make the UI busy.
|
|
||||||
const initialNewCalls: ToolCall[] = requestsToProcess.map(
|
|
||||||
(r): ToolCall => {
|
|
||||||
const tool = toolRegistry.getTool(r.name);
|
|
||||||
if (!tool) {
|
|
||||||
return {
|
|
||||||
status: 'error',
|
|
||||||
request: r,
|
|
||||||
response: toolErrorResponse(
|
|
||||||
r,
|
|
||||||
new Error(`tool ${r.name} does not exist`),
|
|
||||||
),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
// Set to 'validating' immediately. This will make streamingState 'Responding'.
|
|
||||||
return { status: 'validating', request: r, tool };
|
|
||||||
},
|
|
||||||
);
|
|
||||||
setToolCalls((prevCalls) => prevCalls.concat(initialNewCalls));
|
|
||||||
|
|
||||||
// Step 2: Asynchronously check for confirmation and update status for each new call.
|
|
||||||
initialNewCalls.forEach(async (initialCall) => {
|
|
||||||
// If the call was already marked as an error (tool not found), skip further processing.
|
|
||||||
if (initialCall.status !== 'validating') return;
|
|
||||||
|
|
||||||
const { request: r, tool } = initialCall;
|
|
||||||
try {
|
|
||||||
const userApproval = await tool.shouldConfirmExecute(
|
|
||||||
r.args,
|
|
||||||
abortController.signal,
|
|
||||||
);
|
|
||||||
if (userApproval) {
|
|
||||||
// Confirmation is needed. Update status to 'awaiting_approval'.
|
|
||||||
setToolCalls(
|
|
||||||
setStatus(r.callId, 'awaiting_approval', {
|
|
||||||
...userApproval,
|
|
||||||
onConfirm: async (outcome) => {
|
|
||||||
// This onConfirm is triggered by user interaction later.
|
|
||||||
await userApproval.onConfirm(outcome);
|
|
||||||
setToolCalls(
|
|
||||||
outcome === ToolConfirmationOutcome.Cancel
|
|
||||||
? setStatus(
|
|
||||||
r.callId,
|
|
||||||
'cancelled',
|
|
||||||
'User did not allow tool call',
|
|
||||||
)
|
|
||||||
: // If confirmed, it goes to 'scheduled' to be picked up by the execution effect.
|
|
||||||
setStatus(r.callId, 'scheduled'),
|
|
||||||
);
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
// No confirmation needed, move to 'scheduled' for execution.
|
|
||||||
setToolCalls(setStatus(r.callId, 'scheduled'));
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
// Handle errors from tool.shouldConfirmExecute() itself.
|
|
||||||
setToolCalls(
|
|
||||||
setStatus(
|
|
||||||
r.callId,
|
|
||||||
'error',
|
|
||||||
toolErrorResponse(
|
|
||||||
r,
|
|
||||||
e instanceof Error ? e : new Error(String(e)),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[isRunning, setToolCalls, toolRegistry, abortController.signal],
|
|
||||||
);
|
|
||||||
|
|
||||||
const cancel = useCallback(
|
|
||||||
(reason: string = 'unspecified') => {
|
|
||||||
abortController.abort();
|
|
||||||
setAbortController(new AbortController());
|
|
||||||
setToolCalls((tc) =>
|
|
||||||
tc.map((c) =>
|
|
||||||
c.status !== 'error' && c.status !== 'executing'
|
|
||||||
? {
|
|
||||||
...c,
|
|
||||||
status: 'cancelled',
|
|
||||||
response: {
|
|
||||||
callId: c.request.callId,
|
|
||||||
responseParts: {
|
|
||||||
functionResponse: {
|
|
||||||
id: c.request.callId,
|
|
||||||
name: c.request.name,
|
|
||||||
response: {
|
|
||||||
error: `[Operation Cancelled] Reason: ${reason}`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
resultDisplay: undefined,
|
|
||||||
error: undefined,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
: c,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
},
|
|
||||||
[abortController],
|
|
||||||
);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
// effect for executing scheduled tool calls
|
|
||||||
const allToolsConfirmed = toolCalls.every(
|
|
||||||
(t) => t.status === 'scheduled' || t.status === 'cancelled',
|
|
||||||
);
|
|
||||||
if (allToolsConfirmed) {
|
|
||||||
const signal = abortController.signal;
|
|
||||||
toolCalls
|
|
||||||
.filter((t) => t.status === 'scheduled')
|
|
||||||
.forEach((t) => {
|
|
||||||
const callId = t.request.callId;
|
|
||||||
setToolCalls(setStatus(t.request.callId, 'executing'));
|
|
||||||
|
|
||||||
const updateOutput = t.tool.canUpdateOutput
|
|
||||||
? (output: string) => {
|
|
||||||
setPendingHistoryItem(
|
|
||||||
(prevItem: HistoryItemWithoutId | null) => {
|
|
||||||
if (prevItem?.type === 'tool_group') {
|
|
||||||
return {
|
|
||||||
...prevItem,
|
|
||||||
tools: prevItem.tools.map(
|
|
||||||
(toolDisplay: IndividualToolCallDisplay) =>
|
|
||||||
toolDisplay.callId === callId &&
|
|
||||||
toolDisplay.status === ToolCallStatus.Executing
|
|
||||||
? {
|
|
||||||
...toolDisplay,
|
|
||||||
resultDisplay: output,
|
|
||||||
}
|
|
||||||
: toolDisplay,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
return prevItem;
|
|
||||||
},
|
|
||||||
);
|
|
||||||
// Also update the toolCall itself so that mapToDisplay
|
|
||||||
// can pick up the live output if the item is not pending
|
|
||||||
// (e.g. if it's being re-rendered from history)
|
|
||||||
setToolCalls((prevToolCalls) =>
|
|
||||||
prevToolCalls.map((tc) =>
|
|
||||||
tc.request.callId === callId && tc.status === 'executing'
|
|
||||||
? { ...tc, liveOutput: output }
|
|
||||||
: tc,
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
: undefined;
|
|
||||||
|
|
||||||
t.tool
|
|
||||||
.execute(t.request.args, signal, updateOutput)
|
|
||||||
.then((result: ToolResult) => {
|
|
||||||
if (signal.aborted) {
|
|
||||||
// TODO(jacobr): avoid stringifying the LLM content.
|
|
||||||
setToolCalls(
|
|
||||||
setStatus(callId, 'cancelled', String(result.llmContent)),
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const { functionResponseJson, additionalParts } =
|
|
||||||
formatLlmContentForFunctionResponse(result.llmContent);
|
|
||||||
const functionResponse: Part = {
|
|
||||||
functionResponse: {
|
|
||||||
name: t.request.name,
|
|
||||||
id: callId,
|
|
||||||
response: functionResponseJson,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
const response: ToolCallResponseInfo = {
|
|
||||||
callId,
|
|
||||||
responseParts: [functionResponse, ...additionalParts],
|
|
||||||
resultDisplay: result.returnDisplay,
|
|
||||||
error: undefined,
|
|
||||||
};
|
|
||||||
setToolCalls(setStatus(callId, 'success', response));
|
|
||||||
})
|
|
||||||
.catch((e: Error) =>
|
|
||||||
setToolCalls(
|
|
||||||
setStatus(
|
|
||||||
callId,
|
|
||||||
'error',
|
|
||||||
toolErrorResponse(
|
|
||||||
t.request,
|
|
||||||
e instanceof Error ? e : new Error(String(e)),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, [toolCalls, toolRegistry, abortController.signal, setPendingHistoryItem]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const allDone = toolCalls.every(
|
|
||||||
(t) =>
|
|
||||||
t.status === 'success' ||
|
|
||||||
t.status === 'error' ||
|
|
||||||
t.status === 'cancelled',
|
|
||||||
);
|
|
||||||
if (toolCalls.length && allDone) {
|
|
||||||
setToolCalls([]);
|
|
||||||
onComplete(toolCalls);
|
|
||||||
setAbortController(() => new AbortController());
|
|
||||||
}
|
|
||||||
}, [toolCalls, onComplete]);
|
|
||||||
|
|
||||||
return [toolCalls, schedule, cancel];
|
|
||||||
}
|
|
||||||
|
|
||||||
function setStatus(
|
|
||||||
targetCallId: string,
|
|
||||||
status: 'success',
|
|
||||||
response: ToolCallResponseInfo,
|
|
||||||
): (t: ToolCall[]) => ToolCall[];
|
|
||||||
function setStatus(
|
|
||||||
targetCallId: string,
|
|
||||||
status: 'awaiting_approval',
|
|
||||||
confirm: ToolCallConfirmationDetails,
|
|
||||||
): (t: ToolCall[]) => ToolCall[];
|
|
||||||
function setStatus(
|
|
||||||
targetCallId: string,
|
|
||||||
status: 'error',
|
|
||||||
response: ToolCallResponseInfo,
|
|
||||||
): (t: ToolCall[]) => ToolCall[];
|
|
||||||
function setStatus(
|
|
||||||
targetCallId: string,
|
|
||||||
status: 'cancelled',
|
|
||||||
reason: string,
|
|
||||||
): (t: ToolCall[]) => ToolCall[];
|
|
||||||
function setStatus(
|
|
||||||
targetCallId: string,
|
|
||||||
status: 'executing' | 'scheduled' | 'validating',
|
|
||||||
): (t: ToolCall[]) => ToolCall[];
|
|
||||||
function setStatus(
|
|
||||||
targetCallId: string,
|
|
||||||
status: Status,
|
|
||||||
auxiliaryData?: unknown,
|
|
||||||
): (t: ToolCall[]) => ToolCall[] {
|
|
||||||
return function (tc: ToolCall[]): ToolCall[] {
|
|
||||||
return tc.map((t) => {
|
|
||||||
if (t.request.callId !== targetCallId || t.status === 'error') {
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
switch (status) {
|
|
||||||
case 'success': {
|
|
||||||
const next: SuccessfulToolCall = {
|
|
||||||
...t,
|
|
||||||
status: 'success',
|
|
||||||
response: auxiliaryData as ToolCallResponseInfo,
|
|
||||||
};
|
|
||||||
return next;
|
|
||||||
}
|
|
||||||
case 'error': {
|
|
||||||
const next: ErroredToolCall = {
|
|
||||||
...t,
|
|
||||||
status: 'error',
|
|
||||||
response: auxiliaryData as ToolCallResponseInfo,
|
|
||||||
};
|
|
||||||
return next;
|
|
||||||
}
|
|
||||||
case 'awaiting_approval': {
|
|
||||||
const next: WaitingToolCall = {
|
|
||||||
...t,
|
|
||||||
status: 'awaiting_approval',
|
|
||||||
confirmationDetails: auxiliaryData as ToolCallConfirmationDetails,
|
|
||||||
};
|
|
||||||
return next;
|
|
||||||
}
|
|
||||||
case 'scheduled': {
|
|
||||||
const next: ScheduledToolCall = {
|
|
||||||
...t,
|
|
||||||
status: 'scheduled',
|
|
||||||
};
|
|
||||||
return next;
|
|
||||||
}
|
|
||||||
case 'cancelled': {
|
|
||||||
const next: CancelledToolCall = {
|
|
||||||
...t,
|
|
||||||
status: 'cancelled',
|
|
||||||
response: {
|
|
||||||
callId: t.request.callId,
|
|
||||||
responseParts: {
|
|
||||||
functionResponse: {
|
|
||||||
id: t.request.callId,
|
|
||||||
name: t.request.name,
|
|
||||||
response: {
|
|
||||||
error: `[Operation Cancelled] Reason: ${auxiliaryData}`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
resultDisplay: undefined,
|
|
||||||
error: undefined,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
return next;
|
|
||||||
}
|
|
||||||
case 'validating': {
|
|
||||||
const next: ValidatingToolCall = {
|
|
||||||
...(t as ValidatingToolCall), // Added type assertion for safety
|
|
||||||
status: 'validating',
|
|
||||||
};
|
|
||||||
return next;
|
|
||||||
}
|
|
||||||
case 'executing': {
|
|
||||||
const next: ExecutingToolCall = {
|
|
||||||
...t,
|
|
||||||
status: 'executing',
|
|
||||||
};
|
|
||||||
return next;
|
|
||||||
}
|
|
||||||
default: {
|
|
||||||
// ensures every case is checked for above
|
|
||||||
const exhaustiveCheck: never = status;
|
|
||||||
return exhaustiveCheck;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const toolErrorResponse = (
|
|
||||||
request: ToolCallRequestInfo,
|
|
||||||
error: Error,
|
|
||||||
): ToolCallResponseInfo => ({
|
|
||||||
callId: request.callId,
|
|
||||||
error,
|
|
||||||
responseParts: {
|
|
||||||
functionResponse: {
|
|
||||||
id: request.callId,
|
|
||||||
name: request.name,
|
|
||||||
response: { error: error.message },
|
|
||||||
},
|
|
||||||
},
|
|
||||||
resultDisplay: error.message,
|
|
||||||
});
|
|
||||||
|
|
||||||
function mapStatus(status: Status): ToolCallStatus {
|
|
||||||
switch (status) {
|
|
||||||
case 'validating':
|
|
||||||
return ToolCallStatus.Executing;
|
|
||||||
case 'awaiting_approval':
|
|
||||||
return ToolCallStatus.Confirming;
|
|
||||||
case 'executing':
|
|
||||||
return ToolCallStatus.Executing;
|
|
||||||
case 'success':
|
|
||||||
return ToolCallStatus.Success;
|
|
||||||
case 'cancelled':
|
|
||||||
return ToolCallStatus.Canceled;
|
|
||||||
case 'error':
|
|
||||||
return ToolCallStatus.Error;
|
|
||||||
case 'scheduled':
|
|
||||||
return ToolCallStatus.Pending;
|
|
||||||
default: {
|
|
||||||
// ensures every case is checked for above
|
|
||||||
const exhaustiveCheck: never = status;
|
|
||||||
return exhaustiveCheck;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// convenient function for callers to map ToolCall back to a HistoryItem
|
|
||||||
export function mapToDisplay(
|
|
||||||
tool: ToolCall[] | ToolCall,
|
|
||||||
): HistoryItemToolGroup {
|
|
||||||
const tools = Array.isArray(tool) ? tool : [tool];
|
|
||||||
const toolsDisplays = tools.map((t): IndividualToolCallDisplay => {
|
|
||||||
switch (t.status) {
|
|
||||||
case 'success':
|
|
||||||
return {
|
|
||||||
callId: t.request.callId,
|
|
||||||
name: t.tool.displayName,
|
|
||||||
description: t.tool.getDescription(t.request.args),
|
|
||||||
resultDisplay: t.response.resultDisplay,
|
|
||||||
status: mapStatus(t.status),
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
renderOutputAsMarkdown: t.tool.isOutputMarkdown,
|
|
||||||
};
|
|
||||||
case 'error':
|
|
||||||
return {
|
|
||||||
callId: t.request.callId,
|
|
||||||
name: t.request.name, // Use request.name as tool might be undefined
|
|
||||||
description: '', // No description available if tool is undefined
|
|
||||||
resultDisplay: t.response.resultDisplay,
|
|
||||||
status: mapStatus(t.status),
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
renderOutputAsMarkdown: false,
|
|
||||||
};
|
|
||||||
case 'cancelled':
|
|
||||||
return {
|
|
||||||
callId: t.request.callId,
|
|
||||||
name: t.tool.displayName,
|
|
||||||
description: t.tool.getDescription(t.request.args),
|
|
||||||
resultDisplay: t.response.resultDisplay,
|
|
||||||
status: mapStatus(t.status),
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
renderOutputAsMarkdown: t.tool.isOutputMarkdown,
|
|
||||||
};
|
|
||||||
case 'awaiting_approval':
|
|
||||||
return {
|
|
||||||
callId: t.request.callId,
|
|
||||||
name: t.tool.displayName,
|
|
||||||
description: t.tool.getDescription(t.request.args),
|
|
||||||
resultDisplay: undefined,
|
|
||||||
status: mapStatus(t.status),
|
|
||||||
confirmationDetails: t.confirmationDetails,
|
|
||||||
renderOutputAsMarkdown: t.tool.isOutputMarkdown,
|
|
||||||
};
|
|
||||||
case 'executing':
|
|
||||||
return {
|
|
||||||
callId: t.request.callId,
|
|
||||||
name: t.tool.displayName,
|
|
||||||
description: t.tool.getDescription(t.request.args),
|
|
||||||
resultDisplay: t.liveOutput ?? undefined,
|
|
||||||
status: mapStatus(t.status),
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
renderOutputAsMarkdown: t.tool.isOutputMarkdown,
|
|
||||||
};
|
|
||||||
case 'validating': // Add this case
|
|
||||||
return {
|
|
||||||
callId: t.request.callId,
|
|
||||||
name: t.tool.displayName,
|
|
||||||
description: t.tool.getDescription(t.request.args),
|
|
||||||
resultDisplay: undefined,
|
|
||||||
status: mapStatus(t.status),
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
renderOutputAsMarkdown: t.tool.isOutputMarkdown,
|
|
||||||
};
|
|
||||||
case 'scheduled':
|
|
||||||
return {
|
|
||||||
callId: t.request.callId,
|
|
||||||
name: t.tool.displayName,
|
|
||||||
description: t.tool.getDescription(t.request.args),
|
|
||||||
resultDisplay: undefined,
|
|
||||||
status: mapStatus(t.status),
|
|
||||||
confirmationDetails: undefined,
|
|
||||||
renderOutputAsMarkdown: t.tool.isOutputMarkdown,
|
|
||||||
};
|
|
||||||
default: {
|
|
||||||
// ensures every case is checked for above
|
|
||||||
const exhaustiveCheck: never = t;
|
|
||||||
return exhaustiveCheck;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
const historyItem: HistoryItemToolGroup = {
|
|
||||||
type: 'tool_group',
|
|
||||||
tools: toolsDisplays,
|
|
||||||
};
|
|
||||||
return historyItem;
|
|
||||||
}
|
|
|
@ -0,0 +1,520 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {
|
||||||
|
ToolCallRequestInfo,
|
||||||
|
ToolCallResponseInfo,
|
||||||
|
ToolConfirmationOutcome,
|
||||||
|
Tool,
|
||||||
|
ToolCallConfirmationDetails,
|
||||||
|
ToolResult,
|
||||||
|
ToolRegistry,
|
||||||
|
} from '../index.js';
|
||||||
|
import { Part, PartUnion, PartListUnion } from '@google/genai';
|
||||||
|
|
||||||
|
export type ValidatingToolCall = {
|
||||||
|
status: 'validating';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
tool: Tool;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ScheduledToolCall = {
|
||||||
|
status: 'scheduled';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
tool: Tool;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ErroredToolCall = {
|
||||||
|
status: 'error';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
response: ToolCallResponseInfo;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type SuccessfulToolCall = {
|
||||||
|
status: 'success';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
tool: Tool;
|
||||||
|
response: ToolCallResponseInfo;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ExecutingToolCall = {
|
||||||
|
status: 'executing';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
tool: Tool;
|
||||||
|
liveOutput?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type CancelledToolCall = {
|
||||||
|
status: 'cancelled';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
response: ToolCallResponseInfo;
|
||||||
|
tool: Tool;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type WaitingToolCall = {
|
||||||
|
status: 'awaiting_approval';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
tool: Tool;
|
||||||
|
confirmationDetails: ToolCallConfirmationDetails;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type Status = ToolCall['status'];
|
||||||
|
|
||||||
|
export type ToolCall =
|
||||||
|
| ValidatingToolCall
|
||||||
|
| ScheduledToolCall
|
||||||
|
| ErroredToolCall
|
||||||
|
| SuccessfulToolCall
|
||||||
|
| ExecutingToolCall
|
||||||
|
| CancelledToolCall
|
||||||
|
| WaitingToolCall;
|
||||||
|
|
||||||
|
export type CompletedToolCall =
|
||||||
|
| SuccessfulToolCall
|
||||||
|
| CancelledToolCall
|
||||||
|
| ErroredToolCall;
|
||||||
|
|
||||||
|
export type ConfirmHandler = (
|
||||||
|
toolCall: WaitingToolCall,
|
||||||
|
) => Promise<ToolConfirmationOutcome>;
|
||||||
|
|
||||||
|
export type OutputUpdateHandler = (
|
||||||
|
toolCallId: string,
|
||||||
|
outputChunk: string,
|
||||||
|
) => void;
|
||||||
|
|
||||||
|
export type AllToolCallsCompleteHandler = (
|
||||||
|
completedToolCalls: CompletedToolCall[],
|
||||||
|
) => void;
|
||||||
|
|
||||||
|
export type ToolCallsUpdateHandler = (toolCalls: ToolCall[]) => void;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Formats tool output for a Gemini FunctionResponse.
|
||||||
|
*/
|
||||||
|
export function formatLlmContentForFunctionResponse(
|
||||||
|
llmContent: PartListUnion,
|
||||||
|
): {
|
||||||
|
functionResponseJson: Record<string, string>;
|
||||||
|
additionalParts: PartUnion[];
|
||||||
|
} {
|
||||||
|
const additionalParts: PartUnion[] = [];
|
||||||
|
let functionResponseJson: Record<string, string>;
|
||||||
|
|
||||||
|
const contentToProcess =
|
||||||
|
Array.isArray(llmContent) && llmContent.length === 1
|
||||||
|
? llmContent[0]
|
||||||
|
: llmContent;
|
||||||
|
|
||||||
|
if (typeof contentToProcess === 'string') {
|
||||||
|
functionResponseJson = { output: contentToProcess };
|
||||||
|
} else if (Array.isArray(contentToProcess)) {
|
||||||
|
functionResponseJson = {
|
||||||
|
status: 'Tool execution succeeded.',
|
||||||
|
};
|
||||||
|
additionalParts.push(...contentToProcess);
|
||||||
|
} else if (contentToProcess.inlineData || contentToProcess.fileData) {
|
||||||
|
const mimeType =
|
||||||
|
contentToProcess.inlineData?.mimeType ||
|
||||||
|
contentToProcess.fileData?.mimeType ||
|
||||||
|
'unknown';
|
||||||
|
functionResponseJson = {
|
||||||
|
status: `Binary content of type ${mimeType} was processed.`,
|
||||||
|
};
|
||||||
|
additionalParts.push(contentToProcess);
|
||||||
|
} else if (contentToProcess.text !== undefined) {
|
||||||
|
functionResponseJson = { output: contentToProcess.text };
|
||||||
|
} else {
|
||||||
|
functionResponseJson = { status: 'Tool execution succeeded.' };
|
||||||
|
additionalParts.push(contentToProcess);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
functionResponseJson,
|
||||||
|
additionalParts,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const createErrorResponse = (
|
||||||
|
request: ToolCallRequestInfo,
|
||||||
|
error: Error,
|
||||||
|
): ToolCallResponseInfo => ({
|
||||||
|
callId: request.callId,
|
||||||
|
error,
|
||||||
|
responseParts: {
|
||||||
|
functionResponse: {
|
||||||
|
id: request.callId,
|
||||||
|
name: request.name,
|
||||||
|
response: { error: error.message },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resultDisplay: error.message,
|
||||||
|
});
|
||||||
|
|
||||||
|
interface CoreToolSchedulerOptions {
|
||||||
|
toolRegistry: ToolRegistry;
|
||||||
|
outputUpdateHandler?: OutputUpdateHandler;
|
||||||
|
onAllToolCallsComplete?: AllToolCallsCompleteHandler;
|
||||||
|
onToolCallsUpdate?: ToolCallsUpdateHandler;
|
||||||
|
}
|
||||||
|
|
||||||
|
export class CoreToolScheduler {
|
||||||
|
private toolRegistry: ToolRegistry;
|
||||||
|
private toolCalls: ToolCall[] = [];
|
||||||
|
private abortController: AbortController;
|
||||||
|
private outputUpdateHandler?: OutputUpdateHandler;
|
||||||
|
private onAllToolCallsComplete?: AllToolCallsCompleteHandler;
|
||||||
|
private onToolCallsUpdate?: ToolCallsUpdateHandler;
|
||||||
|
|
||||||
|
constructor(options: CoreToolSchedulerOptions) {
|
||||||
|
this.toolRegistry = options.toolRegistry;
|
||||||
|
this.outputUpdateHandler = options.outputUpdateHandler;
|
||||||
|
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
|
||||||
|
this.onToolCallsUpdate = options.onToolCallsUpdate;
|
||||||
|
this.abortController = new AbortController();
|
||||||
|
}
|
||||||
|
|
||||||
|
private setStatusInternal(
|
||||||
|
targetCallId: string,
|
||||||
|
status: 'success',
|
||||||
|
response: ToolCallResponseInfo,
|
||||||
|
): void;
|
||||||
|
private setStatusInternal(
|
||||||
|
targetCallId: string,
|
||||||
|
status: 'awaiting_approval',
|
||||||
|
confirmationDetails: ToolCallConfirmationDetails,
|
||||||
|
): void;
|
||||||
|
private setStatusInternal(
|
||||||
|
targetCallId: string,
|
||||||
|
status: 'error',
|
||||||
|
response: ToolCallResponseInfo,
|
||||||
|
): void;
|
||||||
|
private setStatusInternal(
|
||||||
|
targetCallId: string,
|
||||||
|
status: 'cancelled',
|
||||||
|
reason: string,
|
||||||
|
): void;
|
||||||
|
private setStatusInternal(
|
||||||
|
targetCallId: string,
|
||||||
|
status: 'executing' | 'scheduled' | 'validating',
|
||||||
|
): void;
|
||||||
|
private setStatusInternal(
|
||||||
|
targetCallId: string,
|
||||||
|
newStatus: Status,
|
||||||
|
auxiliaryData?: unknown,
|
||||||
|
): void {
|
||||||
|
this.toolCalls = this.toolCalls.map((currentCall) => {
|
||||||
|
if (
|
||||||
|
currentCall.request.callId !== targetCallId ||
|
||||||
|
currentCall.status === 'error'
|
||||||
|
) {
|
||||||
|
return currentCall;
|
||||||
|
}
|
||||||
|
|
||||||
|
const callWithToolContext = currentCall as ToolCall & { tool: Tool };
|
||||||
|
|
||||||
|
switch (newStatus) {
|
||||||
|
case 'success':
|
||||||
|
return {
|
||||||
|
...callWithToolContext,
|
||||||
|
status: 'success',
|
||||||
|
response: auxiliaryData as ToolCallResponseInfo,
|
||||||
|
} as SuccessfulToolCall;
|
||||||
|
case 'error':
|
||||||
|
return {
|
||||||
|
request: currentCall.request,
|
||||||
|
status: 'error',
|
||||||
|
response: auxiliaryData as ToolCallResponseInfo,
|
||||||
|
} as ErroredToolCall;
|
||||||
|
case 'awaiting_approval':
|
||||||
|
return {
|
||||||
|
...callWithToolContext,
|
||||||
|
status: 'awaiting_approval',
|
||||||
|
confirmationDetails: auxiliaryData as ToolCallConfirmationDetails,
|
||||||
|
} as WaitingToolCall;
|
||||||
|
case 'scheduled':
|
||||||
|
return {
|
||||||
|
...callWithToolContext,
|
||||||
|
status: 'scheduled',
|
||||||
|
} as ScheduledToolCall;
|
||||||
|
case 'cancelled':
|
||||||
|
return {
|
||||||
|
...callWithToolContext,
|
||||||
|
status: 'cancelled',
|
||||||
|
response: {
|
||||||
|
callId: currentCall.request.callId,
|
||||||
|
responseParts: {
|
||||||
|
functionResponse: {
|
||||||
|
id: currentCall.request.callId,
|
||||||
|
name: currentCall.request.name,
|
||||||
|
response: {
|
||||||
|
error: `[Operation Cancelled] Reason: ${auxiliaryData}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resultDisplay: undefined,
|
||||||
|
error: undefined,
|
||||||
|
},
|
||||||
|
} as CancelledToolCall;
|
||||||
|
case 'validating':
|
||||||
|
return {
|
||||||
|
...(currentCall as ValidatingToolCall),
|
||||||
|
status: 'validating',
|
||||||
|
} as ValidatingToolCall;
|
||||||
|
case 'executing':
|
||||||
|
return {
|
||||||
|
...callWithToolContext,
|
||||||
|
status: 'executing',
|
||||||
|
} as ExecutingToolCall;
|
||||||
|
default: {
|
||||||
|
const exhaustiveCheck: never = newStatus;
|
||||||
|
return exhaustiveCheck;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
this.notifyToolCallsUpdate();
|
||||||
|
this.checkAndNotifyCompletion();
|
||||||
|
}
|
||||||
|
|
||||||
|
private isRunning(): boolean {
|
||||||
|
return this.toolCalls.some(
|
||||||
|
(call) =>
|
||||||
|
call.status === 'executing' || call.status === 'awaiting_approval',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
async schedule(
|
||||||
|
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||||
|
): Promise<void> {
|
||||||
|
if (this.isRunning()) {
|
||||||
|
throw new Error(
|
||||||
|
'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const requestsToProcess = Array.isArray(request) ? request : [request];
|
||||||
|
|
||||||
|
const newToolCalls: ToolCall[] = requestsToProcess.map(
|
||||||
|
(reqInfo): ToolCall => {
|
||||||
|
const toolInstance = this.toolRegistry.getTool(reqInfo.name);
|
||||||
|
if (!toolInstance) {
|
||||||
|
return {
|
||||||
|
status: 'error',
|
||||||
|
request: reqInfo,
|
||||||
|
response: createErrorResponse(
|
||||||
|
reqInfo,
|
||||||
|
new Error(`Tool "${reqInfo.name}" not found in registry.`),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return { status: 'validating', request: reqInfo, tool: toolInstance };
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
this.toolCalls = this.toolCalls.concat(newToolCalls);
|
||||||
|
this.notifyToolCallsUpdate();
|
||||||
|
|
||||||
|
for (const toolCall of newToolCalls) {
|
||||||
|
if (toolCall.status !== 'validating') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { request: reqInfo, tool: toolInstance } = toolCall;
|
||||||
|
try {
|
||||||
|
const confirmationDetails = await toolInstance.shouldConfirmExecute(
|
||||||
|
reqInfo.args,
|
||||||
|
this.abortController.signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (confirmationDetails) {
|
||||||
|
const originalOnConfirm = confirmationDetails.onConfirm;
|
||||||
|
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
|
||||||
|
...confirmationDetails,
|
||||||
|
onConfirm: (outcome: ToolConfirmationOutcome) =>
|
||||||
|
this.handleConfirmationResponse(
|
||||||
|
reqInfo.callId,
|
||||||
|
originalOnConfirm,
|
||||||
|
outcome,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
this.setStatusInternal(
|
||||||
|
reqInfo.callId,
|
||||||
|
'awaiting_approval',
|
||||||
|
wrappedConfirmationDetails,
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
this.setStatusInternal(
|
||||||
|
reqInfo.callId,
|
||||||
|
'error',
|
||||||
|
createErrorResponse(
|
||||||
|
reqInfo,
|
||||||
|
error instanceof Error ? error : new Error(String(error)),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.attemptExecutionOfScheduledCalls();
|
||||||
|
this.checkAndNotifyCompletion();
|
||||||
|
}
|
||||||
|
|
||||||
|
async handleConfirmationResponse(
|
||||||
|
callId: string,
|
||||||
|
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
|
||||||
|
outcome: ToolConfirmationOutcome,
|
||||||
|
): Promise<void> {
|
||||||
|
const toolCall = this.toolCalls.find(
|
||||||
|
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
|
||||||
|
);
|
||||||
|
|
||||||
|
if (toolCall && toolCall.status === 'awaiting_approval') {
|
||||||
|
await originalOnConfirm(outcome);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (outcome === ToolConfirmationOutcome.Cancel) {
|
||||||
|
this.setStatusInternal(
|
||||||
|
callId,
|
||||||
|
'cancelled',
|
||||||
|
'User did not allow tool call',
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
this.setStatusInternal(callId, 'scheduled');
|
||||||
|
}
|
||||||
|
this.attemptExecutionOfScheduledCalls();
|
||||||
|
}
|
||||||
|
|
||||||
|
private attemptExecutionOfScheduledCalls(): void {
|
||||||
|
const allCallsFinalOrScheduled = this.toolCalls.every(
|
||||||
|
(call) =>
|
||||||
|
call.status === 'scheduled' ||
|
||||||
|
call.status === 'cancelled' ||
|
||||||
|
call.status === 'success' ||
|
||||||
|
call.status === 'error',
|
||||||
|
);
|
||||||
|
|
||||||
|
if (allCallsFinalOrScheduled) {
|
||||||
|
const callsToExecute = this.toolCalls.filter(
|
||||||
|
(call) => call.status === 'scheduled',
|
||||||
|
);
|
||||||
|
|
||||||
|
callsToExecute.forEach((toolCall) => {
|
||||||
|
if (toolCall.status !== 'scheduled') return;
|
||||||
|
|
||||||
|
const scheduledCall = toolCall as ScheduledToolCall;
|
||||||
|
const { callId, name: toolName } = scheduledCall.request;
|
||||||
|
this.setStatusInternal(callId, 'executing');
|
||||||
|
|
||||||
|
const liveOutputCallback =
|
||||||
|
scheduledCall.tool.canUpdateOutput && this.outputUpdateHandler
|
||||||
|
? (outputChunk: string) => {
|
||||||
|
if (this.outputUpdateHandler) {
|
||||||
|
this.outputUpdateHandler(callId, outputChunk);
|
||||||
|
}
|
||||||
|
this.toolCalls = this.toolCalls.map((tc) =>
|
||||||
|
tc.request.callId === callId && tc.status === 'executing'
|
||||||
|
? { ...(tc as ExecutingToolCall), liveOutput: outputChunk }
|
||||||
|
: tc,
|
||||||
|
);
|
||||||
|
this.notifyToolCallsUpdate();
|
||||||
|
}
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
scheduledCall.tool
|
||||||
|
.execute(
|
||||||
|
scheduledCall.request.args,
|
||||||
|
this.abortController.signal,
|
||||||
|
liveOutputCallback,
|
||||||
|
)
|
||||||
|
.then((toolResult: ToolResult) => {
|
||||||
|
if (this.abortController.signal.aborted) {
|
||||||
|
this.setStatusInternal(
|
||||||
|
callId,
|
||||||
|
'cancelled',
|
||||||
|
this.abortController.signal.reason || 'Execution aborted.',
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { functionResponseJson, additionalParts } =
|
||||||
|
formatLlmContentForFunctionResponse(toolResult.llmContent);
|
||||||
|
|
||||||
|
const functionResponsePart: Part = {
|
||||||
|
functionResponse: {
|
||||||
|
name: toolName,
|
||||||
|
id: callId,
|
||||||
|
response: functionResponseJson,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const successResponse: ToolCallResponseInfo = {
|
||||||
|
callId,
|
||||||
|
responseParts: [functionResponsePart, ...additionalParts],
|
||||||
|
resultDisplay: toolResult.returnDisplay,
|
||||||
|
error: undefined,
|
||||||
|
};
|
||||||
|
this.setStatusInternal(callId, 'success', successResponse);
|
||||||
|
})
|
||||||
|
.catch((executionError: Error) => {
|
||||||
|
this.setStatusInternal(
|
||||||
|
callId,
|
||||||
|
'error',
|
||||||
|
createErrorResponse(
|
||||||
|
scheduledCall.request,
|
||||||
|
executionError instanceof Error
|
||||||
|
? executionError
|
||||||
|
: new Error(String(executionError)),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private checkAndNotifyCompletion(): void {
|
||||||
|
const allCallsAreTerminal = this.toolCalls.every(
|
||||||
|
(call) =>
|
||||||
|
call.status === 'success' ||
|
||||||
|
call.status === 'error' ||
|
||||||
|
call.status === 'cancelled',
|
||||||
|
);
|
||||||
|
|
||||||
|
if (this.toolCalls.length > 0 && allCallsAreTerminal) {
|
||||||
|
const completedCalls = [...this.toolCalls] as CompletedToolCall[];
|
||||||
|
this.toolCalls = [];
|
||||||
|
|
||||||
|
if (this.onAllToolCallsComplete) {
|
||||||
|
this.onAllToolCallsComplete(completedCalls);
|
||||||
|
}
|
||||||
|
this.abortController = new AbortController();
|
||||||
|
this.notifyToolCallsUpdate();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cancelAll(reason: string = 'User initiated cancellation.'): void {
|
||||||
|
if (!this.abortController.signal.aborted) {
|
||||||
|
this.abortController.abort(reason);
|
||||||
|
}
|
||||||
|
this.abortController = new AbortController();
|
||||||
|
|
||||||
|
const callsToCancel = [...this.toolCalls];
|
||||||
|
callsToCancel.forEach((call) => {
|
||||||
|
if (
|
||||||
|
call.status !== 'error' &&
|
||||||
|
call.status !== 'success' &&
|
||||||
|
call.status !== 'cancelled'
|
||||||
|
) {
|
||||||
|
this.setStatusInternal(call.request.callId, 'cancelled', reason);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private notifyToolCallsUpdate(): void {
|
||||||
|
if (this.onToolCallsUpdate) {
|
||||||
|
this.onToolCallsUpdate([...this.toolCalls]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -13,8 +13,7 @@ export * from './core/logger.js';
|
||||||
export * from './core/prompts.js';
|
export * from './core/prompts.js';
|
||||||
export * from './core/turn.js';
|
export * from './core/turn.js';
|
||||||
export * from './core/geminiRequest.js';
|
export * from './core/geminiRequest.js';
|
||||||
// Potentially export types from turn.ts if needed externally
|
export * from './core/coreToolScheduler.js';
|
||||||
// export { GeminiEventType } from './core/turn.js'; // Example
|
|
||||||
|
|
||||||
// Export utilities
|
// Export utilities
|
||||||
export * from './utils/paths.js';
|
export * from './utils/paths.js';
|
||||||
|
|
|
@ -218,7 +218,7 @@ export interface ToolMcpConfirmationDetails {
|
||||||
serverName: string;
|
serverName: string;
|
||||||
toolName: string;
|
toolName: string;
|
||||||
toolDisplayName: string;
|
toolDisplayName: string;
|
||||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void> | void;
|
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ToolCallConfirmationDetails =
|
export type ToolCallConfirmationDetails =
|
||||||
|
|
Loading…
Reference in New Issue