feat: create tool scheduler hook (#468)
This commit is contained in:
parent
2ad666a484
commit
e1a64b41e8
|
@ -0,0 +1,464 @@
|
||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {
|
||||||
|
Config,
|
||||||
|
ToolCallRequestInfo,
|
||||||
|
ToolCallResponseInfo,
|
||||||
|
ToolConfirmationOutcome,
|
||||||
|
Tool,
|
||||||
|
} from '@gemini-code/server';
|
||||||
|
import { Part } from '@google/genai';
|
||||||
|
import { useCallback, useEffect, useState } from 'react';
|
||||||
|
import {
|
||||||
|
HistoryItemToolGroup,
|
||||||
|
IndividualToolCallDisplay,
|
||||||
|
ToolCallStatus,
|
||||||
|
} from '../types.js';
|
||||||
|
|
||||||
|
type ScheduledToolCall = {
|
||||||
|
status: 'scheduled';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
tool: Tool;
|
||||||
|
};
|
||||||
|
|
||||||
|
type ErroredToolCall = {
|
||||||
|
status: 'error';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
response: ToolCallResponseInfo;
|
||||||
|
};
|
||||||
|
|
||||||
|
type SuccessfulToolCall = {
|
||||||
|
status: 'success';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
tool: Tool;
|
||||||
|
response: ToolCallResponseInfo;
|
||||||
|
};
|
||||||
|
|
||||||
|
type ExecutingToolCall = {
|
||||||
|
status: 'executing';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
tool: Tool;
|
||||||
|
};
|
||||||
|
|
||||||
|
type CancelledToolCall = {
|
||||||
|
status: 'cancelled';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
response: ToolCallResponseInfo;
|
||||||
|
tool: Tool;
|
||||||
|
};
|
||||||
|
|
||||||
|
type WaitingToolCall = {
|
||||||
|
status: 'awaiting_approval';
|
||||||
|
request: ToolCallRequestInfo;
|
||||||
|
tool: Tool;
|
||||||
|
confirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type Status = ToolCall['status'];
|
||||||
|
|
||||||
|
export type ToolCall =
|
||||||
|
| ScheduledToolCall
|
||||||
|
| ErroredToolCall
|
||||||
|
| SuccessfulToolCall
|
||||||
|
| ExecutingToolCall
|
||||||
|
| CancelledToolCall
|
||||||
|
| WaitingToolCall;
|
||||||
|
|
||||||
|
export type ScheduleFn = (
|
||||||
|
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||||
|
) => void;
|
||||||
|
export type CancelFn = () => void;
|
||||||
|
export type CompletedToolCall =
|
||||||
|
| SuccessfulToolCall
|
||||||
|
| CancelledToolCall
|
||||||
|
| ErroredToolCall;
|
||||||
|
|
||||||
|
export function useToolScheduler(
|
||||||
|
onComplete: (tools: CompletedToolCall[]) => void,
|
||||||
|
config: Config,
|
||||||
|
): [ToolCall[], ScheduleFn, CancelFn] {
|
||||||
|
const [toolRegistry] = useState(() => config.getToolRegistry());
|
||||||
|
const [toolCalls, setToolCalls] = useState<ToolCall[]>([]);
|
||||||
|
const [abortController, setAbortController] = useState<AbortController>(
|
||||||
|
() => new AbortController(),
|
||||||
|
);
|
||||||
|
|
||||||
|
const isRunning = toolCalls.some(
|
||||||
|
(t) => t.status === 'executing' || t.status === 'awaiting_approval',
|
||||||
|
);
|
||||||
|
// Note: request array[] typically signal pending tool calls
|
||||||
|
const schedule = useCallback(
|
||||||
|
async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => {
|
||||||
|
if (isRunning) {
|
||||||
|
throw new Error(
|
||||||
|
'Cannot schedule tool calls while other tool calls are running',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const requests = Array.isArray(request) ? request : [request];
|
||||||
|
const newCalls: ToolCall[] = await Promise.all(
|
||||||
|
requests.map(async (r): Promise<ToolCall> => {
|
||||||
|
const tool = toolRegistry.getTool(r.name);
|
||||||
|
if (!tool) {
|
||||||
|
return {
|
||||||
|
status: 'error',
|
||||||
|
request: r,
|
||||||
|
response: toolErrorResponse(
|
||||||
|
r,
|
||||||
|
new Error(`tool ${r.name} does not exist`),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const userApproval = await tool.shouldConfirmExecute(r.args);
|
||||||
|
if (userApproval) {
|
||||||
|
return {
|
||||||
|
status: 'awaiting_approval',
|
||||||
|
request: r,
|
||||||
|
tool,
|
||||||
|
confirm: async (outcome) => {
|
||||||
|
await userApproval.onConfirm(outcome);
|
||||||
|
setToolCalls(
|
||||||
|
outcome === ToolConfirmationOutcome.Cancel
|
||||||
|
? setStatus(
|
||||||
|
r.callId,
|
||||||
|
'cancelled',
|
||||||
|
'User did not allow tool call',
|
||||||
|
)
|
||||||
|
: setStatus(r.callId, 'scheduled'),
|
||||||
|
);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
status: 'scheduled',
|
||||||
|
request: r,
|
||||||
|
tool,
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
setToolCalls((t) => t.concat(newCalls));
|
||||||
|
},
|
||||||
|
[isRunning, setToolCalls, toolRegistry],
|
||||||
|
);
|
||||||
|
|
||||||
|
const cancel = useCallback(
|
||||||
|
(reason: string = 'unspecified') => {
|
||||||
|
abortController.abort();
|
||||||
|
setAbortController(new AbortController());
|
||||||
|
setToolCalls((tc) =>
|
||||||
|
tc.map((c) =>
|
||||||
|
c.status !== 'error'
|
||||||
|
? {
|
||||||
|
...c,
|
||||||
|
status: 'cancelled',
|
||||||
|
response: {
|
||||||
|
callId: c.request.callId,
|
||||||
|
responsePart: {
|
||||||
|
functionResponse: {
|
||||||
|
id: c.request.callId,
|
||||||
|
name: c.request.name,
|
||||||
|
response: {
|
||||||
|
error: `[Operation Cancelled] Reason: ${reason}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resultDisplay: undefined,
|
||||||
|
error: undefined,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: c,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[abortController],
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// effect for executing scheduled tool calls
|
||||||
|
const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled');
|
||||||
|
const awaitingConfirmation = toolCalls.some(
|
||||||
|
(t) => t.status === 'awaiting_approval',
|
||||||
|
);
|
||||||
|
if (!awaitingConfirmation && scheduledCalls.length) {
|
||||||
|
scheduledCalls.forEach(async (c) => {
|
||||||
|
const callId = c.request.callId;
|
||||||
|
try {
|
||||||
|
setToolCalls(setStatus(c.request.callId, 'executing'));
|
||||||
|
const result = await c.tool.execute(
|
||||||
|
c.request.args,
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
const functionResponse: Part = {
|
||||||
|
functionResponse: {
|
||||||
|
name: c.request.name,
|
||||||
|
id: callId,
|
||||||
|
response: { output: result.llmContent },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
const response: ToolCallResponseInfo = {
|
||||||
|
callId,
|
||||||
|
responsePart: functionResponse,
|
||||||
|
resultDisplay: result.returnDisplay,
|
||||||
|
error: undefined,
|
||||||
|
};
|
||||||
|
setToolCalls(setStatus(callId, 'success', response));
|
||||||
|
} catch (e: unknown) {
|
||||||
|
setToolCalls(
|
||||||
|
setStatus(
|
||||||
|
callId,
|
||||||
|
'error',
|
||||||
|
toolErrorResponse(
|
||||||
|
c.request,
|
||||||
|
e instanceof Error ? e : new Error(String(e)),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [toolCalls, toolRegistry, abortController.signal]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const completedTools = toolCalls.filter(
|
||||||
|
(t) =>
|
||||||
|
t.status === 'success' ||
|
||||||
|
t.status === 'error' ||
|
||||||
|
t.status === 'cancelled',
|
||||||
|
);
|
||||||
|
const allDone = completedTools.length === toolCalls.length;
|
||||||
|
if (toolCalls.length && allDone) {
|
||||||
|
onComplete(completedTools);
|
||||||
|
setToolCalls([]);
|
||||||
|
setAbortController(new AbortController());
|
||||||
|
}
|
||||||
|
}, [toolCalls, onComplete]);
|
||||||
|
|
||||||
|
return [toolCalls, schedule, cancel];
|
||||||
|
}
|
||||||
|
|
||||||
|
function setStatus(
|
||||||
|
targetCallId: string,
|
||||||
|
status: 'success',
|
||||||
|
response: ToolCallResponseInfo,
|
||||||
|
): (t: ToolCall[]) => ToolCall[];
|
||||||
|
function setStatus(
|
||||||
|
targetCallId: string,
|
||||||
|
status: 'awaiting_approval',
|
||||||
|
confirm: (t: ToolConfirmationOutcome) => Promise<void>,
|
||||||
|
): (t: ToolCall[]) => ToolCall[];
|
||||||
|
function setStatus(
|
||||||
|
targetCallId: string,
|
||||||
|
status: 'error',
|
||||||
|
response: ToolCallResponseInfo,
|
||||||
|
): (t: ToolCall[]) => ToolCall[];
|
||||||
|
function setStatus(
|
||||||
|
targetCallId: string,
|
||||||
|
status: 'cancelled',
|
||||||
|
reason: string,
|
||||||
|
): (t: ToolCall[]) => ToolCall[];
|
||||||
|
function setStatus(
|
||||||
|
targetCallId: string,
|
||||||
|
status: 'executing' | 'scheduled',
|
||||||
|
): (t: ToolCall[]) => ToolCall[];
|
||||||
|
function setStatus(
|
||||||
|
targetCallId: string,
|
||||||
|
status: Status,
|
||||||
|
auxiliaryData?: unknown,
|
||||||
|
): (t: ToolCall[]) => ToolCall[] {
|
||||||
|
return function (tc: ToolCall[]): ToolCall[] {
|
||||||
|
return tc.map((t) => {
|
||||||
|
if (t.request.callId !== targetCallId || t.status === 'error') {
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
switch (status) {
|
||||||
|
case 'success': {
|
||||||
|
const next: SuccessfulToolCall = {
|
||||||
|
...t,
|
||||||
|
status: 'success',
|
||||||
|
response: auxiliaryData as ToolCallResponseInfo,
|
||||||
|
};
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
case 'error': {
|
||||||
|
const next: ErroredToolCall = {
|
||||||
|
...t,
|
||||||
|
status: 'error',
|
||||||
|
response: auxiliaryData as ToolCallResponseInfo,
|
||||||
|
};
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
case 'awaiting_approval': {
|
||||||
|
const next: WaitingToolCall = {
|
||||||
|
...t,
|
||||||
|
status: 'awaiting_approval',
|
||||||
|
confirm: auxiliaryData as (
|
||||||
|
o: ToolConfirmationOutcome,
|
||||||
|
) => Promise<void>,
|
||||||
|
};
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
case 'scheduled': {
|
||||||
|
const next: ScheduledToolCall = {
|
||||||
|
...t,
|
||||||
|
status: 'scheduled',
|
||||||
|
};
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
case 'cancelled': {
|
||||||
|
const next: CancelledToolCall = {
|
||||||
|
...t,
|
||||||
|
status: 'cancelled',
|
||||||
|
response: {
|
||||||
|
callId: t.request.callId,
|
||||||
|
responsePart: {
|
||||||
|
functionResponse: {
|
||||||
|
id: t.request.callId,
|
||||||
|
name: t.request.name,
|
||||||
|
response: {
|
||||||
|
error: `[Operation Cancelled] Reason: ${auxiliaryData}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resultDisplay: undefined,
|
||||||
|
error: undefined,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
case 'executing': {
|
||||||
|
const next: ExecutingToolCall = {
|
||||||
|
...t,
|
||||||
|
status: 'executing',
|
||||||
|
};
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
// ensures every case is checked for above
|
||||||
|
const exhaustiveCheck: never = status;
|
||||||
|
return exhaustiveCheck;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const toolErrorResponse = (
|
||||||
|
request: ToolCallRequestInfo,
|
||||||
|
error: Error,
|
||||||
|
): ToolCallResponseInfo => ({
|
||||||
|
callId: request.callId,
|
||||||
|
error,
|
||||||
|
responsePart: {
|
||||||
|
functionResponse: {
|
||||||
|
id: request.callId,
|
||||||
|
name: request.name,
|
||||||
|
response: { error: error.message },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resultDisplay: error.message,
|
||||||
|
});
|
||||||
|
|
||||||
|
function mapStatus(status: Status): ToolCallStatus {
|
||||||
|
switch (status) {
|
||||||
|
case 'awaiting_approval':
|
||||||
|
return ToolCallStatus.Confirming;
|
||||||
|
case 'executing':
|
||||||
|
return ToolCallStatus.Executing;
|
||||||
|
case 'success':
|
||||||
|
return ToolCallStatus.Success;
|
||||||
|
case 'cancelled':
|
||||||
|
return ToolCallStatus.Canceled;
|
||||||
|
case 'error':
|
||||||
|
return ToolCallStatus.Error;
|
||||||
|
case 'scheduled':
|
||||||
|
return ToolCallStatus.Pending;
|
||||||
|
default: {
|
||||||
|
// ensures every case is checked for above
|
||||||
|
const exhaustiveCheck: never = status;
|
||||||
|
return exhaustiveCheck;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// convenient function for callers to map ToolCall back to a HistoryItem
|
||||||
|
export function mapToDisplay(
|
||||||
|
tool: ToolCall[] | ToolCall,
|
||||||
|
): HistoryItemToolGroup {
|
||||||
|
const tools = Array.isArray(tool) ? tool : [tool];
|
||||||
|
const toolsDisplays = tools.map((t): IndividualToolCallDisplay => {
|
||||||
|
switch (t.status) {
|
||||||
|
case 'success':
|
||||||
|
return {
|
||||||
|
callId: t.request.callId,
|
||||||
|
name: t.tool.displayName,
|
||||||
|
description: t.tool.getDescription(t.request.args),
|
||||||
|
resultDisplay: t.response.resultDisplay,
|
||||||
|
status: mapStatus(t.status),
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
case 'error':
|
||||||
|
return {
|
||||||
|
callId: t.request.callId,
|
||||||
|
name: t.request.name,
|
||||||
|
description: '',
|
||||||
|
resultDisplay: t.response.resultDisplay,
|
||||||
|
status: mapStatus(t.status),
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
case 'cancelled':
|
||||||
|
return {
|
||||||
|
callId: t.request.callId,
|
||||||
|
name: t.tool.displayName,
|
||||||
|
description: t.tool.getDescription(t.request.args),
|
||||||
|
resultDisplay: t.response.resultDisplay,
|
||||||
|
status: mapStatus(t.status),
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
case 'awaiting_approval':
|
||||||
|
return {
|
||||||
|
callId: t.request.callId,
|
||||||
|
name: t.tool.displayName,
|
||||||
|
description: t.tool.getDescription(t.request.args),
|
||||||
|
resultDisplay: undefined,
|
||||||
|
status: mapStatus(t.status),
|
||||||
|
confirmationDetails: {
|
||||||
|
title: t.request.name,
|
||||||
|
onConfirm: t.confirm,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
case 'executing':
|
||||||
|
return {
|
||||||
|
callId: t.request.callId,
|
||||||
|
name: t.tool.displayName,
|
||||||
|
description: t.tool.getDescription(t.request.args),
|
||||||
|
resultDisplay: undefined,
|
||||||
|
status: mapStatus(t.status),
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
case 'scheduled':
|
||||||
|
return {
|
||||||
|
callId: t.request.callId,
|
||||||
|
name: t.tool.displayName,
|
||||||
|
description: t.tool.getDescription(t.request.args),
|
||||||
|
resultDisplay: undefined,
|
||||||
|
status: mapStatus(t.status),
|
||||||
|
confirmationDetails: undefined,
|
||||||
|
};
|
||||||
|
default: {
|
||||||
|
// ensures every case is checked for above
|
||||||
|
const exhaustiveCheck: never = t;
|
||||||
|
return exhaustiveCheck;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
const historyItem: HistoryItemToolGroup = {
|
||||||
|
type: 'tool_group',
|
||||||
|
tools: toolsDisplays,
|
||||||
|
};
|
||||||
|
return historyItem;
|
||||||
|
}
|
Loading…
Reference in New Issue