feat: create tool scheduler hook (#468)
This commit is contained in:
parent
2ad666a484
commit
e1a64b41e8
|
@ -0,0 +1,464 @@
|
|||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
Config,
|
||||
ToolCallRequestInfo,
|
||||
ToolCallResponseInfo,
|
||||
ToolConfirmationOutcome,
|
||||
Tool,
|
||||
} from '@gemini-code/server';
|
||||
import { Part } from '@google/genai';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import {
|
||||
HistoryItemToolGroup,
|
||||
IndividualToolCallDisplay,
|
||||
ToolCallStatus,
|
||||
} from '../types.js';
|
||||
|
||||
type ScheduledToolCall = {
|
||||
status: 'scheduled';
|
||||
request: ToolCallRequestInfo;
|
||||
tool: Tool;
|
||||
};
|
||||
|
||||
type ErroredToolCall = {
|
||||
status: 'error';
|
||||
request: ToolCallRequestInfo;
|
||||
response: ToolCallResponseInfo;
|
||||
};
|
||||
|
||||
type SuccessfulToolCall = {
|
||||
status: 'success';
|
||||
request: ToolCallRequestInfo;
|
||||
tool: Tool;
|
||||
response: ToolCallResponseInfo;
|
||||
};
|
||||
|
||||
type ExecutingToolCall = {
|
||||
status: 'executing';
|
||||
request: ToolCallRequestInfo;
|
||||
tool: Tool;
|
||||
};
|
||||
|
||||
type CancelledToolCall = {
|
||||
status: 'cancelled';
|
||||
request: ToolCallRequestInfo;
|
||||
response: ToolCallResponseInfo;
|
||||
tool: Tool;
|
||||
};
|
||||
|
||||
type WaitingToolCall = {
|
||||
status: 'awaiting_approval';
|
||||
request: ToolCallRequestInfo;
|
||||
tool: Tool;
|
||||
confirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||
};
|
||||
|
||||
export type Status = ToolCall['status'];
|
||||
|
||||
export type ToolCall =
|
||||
| ScheduledToolCall
|
||||
| ErroredToolCall
|
||||
| SuccessfulToolCall
|
||||
| ExecutingToolCall
|
||||
| CancelledToolCall
|
||||
| WaitingToolCall;
|
||||
|
||||
export type ScheduleFn = (
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
) => void;
|
||||
export type CancelFn = () => void;
|
||||
export type CompletedToolCall =
|
||||
| SuccessfulToolCall
|
||||
| CancelledToolCall
|
||||
| ErroredToolCall;
|
||||
|
||||
export function useToolScheduler(
|
||||
onComplete: (tools: CompletedToolCall[]) => void,
|
||||
config: Config,
|
||||
): [ToolCall[], ScheduleFn, CancelFn] {
|
||||
const [toolRegistry] = useState(() => config.getToolRegistry());
|
||||
const [toolCalls, setToolCalls] = useState<ToolCall[]>([]);
|
||||
const [abortController, setAbortController] = useState<AbortController>(
|
||||
() => new AbortController(),
|
||||
);
|
||||
|
||||
const isRunning = toolCalls.some(
|
||||
(t) => t.status === 'executing' || t.status === 'awaiting_approval',
|
||||
);
|
||||
// Note: request array[] typically signal pending tool calls
|
||||
const schedule = useCallback(
|
||||
async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => {
|
||||
if (isRunning) {
|
||||
throw new Error(
|
||||
'Cannot schedule tool calls while other tool calls are running',
|
||||
);
|
||||
}
|
||||
const requests = Array.isArray(request) ? request : [request];
|
||||
const newCalls: ToolCall[] = await Promise.all(
|
||||
requests.map(async (r): Promise<ToolCall> => {
|
||||
const tool = toolRegistry.getTool(r.name);
|
||||
if (!tool) {
|
||||
return {
|
||||
status: 'error',
|
||||
request: r,
|
||||
response: toolErrorResponse(
|
||||
r,
|
||||
new Error(`tool ${r.name} does not exist`),
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
const userApproval = await tool.shouldConfirmExecute(r.args);
|
||||
if (userApproval) {
|
||||
return {
|
||||
status: 'awaiting_approval',
|
||||
request: r,
|
||||
tool,
|
||||
confirm: async (outcome) => {
|
||||
await userApproval.onConfirm(outcome);
|
||||
setToolCalls(
|
||||
outcome === ToolConfirmationOutcome.Cancel
|
||||
? setStatus(
|
||||
r.callId,
|
||||
'cancelled',
|
||||
'User did not allow tool call',
|
||||
)
|
||||
: setStatus(r.callId, 'scheduled'),
|
||||
);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
status: 'scheduled',
|
||||
request: r,
|
||||
tool,
|
||||
};
|
||||
}),
|
||||
);
|
||||
setToolCalls((t) => t.concat(newCalls));
|
||||
},
|
||||
[isRunning, setToolCalls, toolRegistry],
|
||||
);
|
||||
|
||||
const cancel = useCallback(
|
||||
(reason: string = 'unspecified') => {
|
||||
abortController.abort();
|
||||
setAbortController(new AbortController());
|
||||
setToolCalls((tc) =>
|
||||
tc.map((c) =>
|
||||
c.status !== 'error'
|
||||
? {
|
||||
...c,
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
callId: c.request.callId,
|
||||
responsePart: {
|
||||
functionResponse: {
|
||||
id: c.request.callId,
|
||||
name: c.request.name,
|
||||
response: {
|
||||
error: `[Operation Cancelled] Reason: ${reason}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
resultDisplay: undefined,
|
||||
error: undefined,
|
||||
},
|
||||
}
|
||||
: c,
|
||||
),
|
||||
);
|
||||
},
|
||||
[abortController],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
// effect for executing scheduled tool calls
|
||||
const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled');
|
||||
const awaitingConfirmation = toolCalls.some(
|
||||
(t) => t.status === 'awaiting_approval',
|
||||
);
|
||||
if (!awaitingConfirmation && scheduledCalls.length) {
|
||||
scheduledCalls.forEach(async (c) => {
|
||||
const callId = c.request.callId;
|
||||
try {
|
||||
setToolCalls(setStatus(c.request.callId, 'executing'));
|
||||
const result = await c.tool.execute(
|
||||
c.request.args,
|
||||
abortController.signal,
|
||||
);
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
name: c.request.name,
|
||||
id: callId,
|
||||
response: { output: result.llmContent },
|
||||
},
|
||||
};
|
||||
const response: ToolCallResponseInfo = {
|
||||
callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay: result.returnDisplay,
|
||||
error: undefined,
|
||||
};
|
||||
setToolCalls(setStatus(callId, 'success', response));
|
||||
} catch (e: unknown) {
|
||||
setToolCalls(
|
||||
setStatus(
|
||||
callId,
|
||||
'error',
|
||||
toolErrorResponse(
|
||||
c.request,
|
||||
e instanceof Error ? e : new Error(String(e)),
|
||||
),
|
||||
),
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
}, [toolCalls, toolRegistry, abortController.signal]);
|
||||
|
||||
useEffect(() => {
|
||||
const completedTools = toolCalls.filter(
|
||||
(t) =>
|
||||
t.status === 'success' ||
|
||||
t.status === 'error' ||
|
||||
t.status === 'cancelled',
|
||||
);
|
||||
const allDone = completedTools.length === toolCalls.length;
|
||||
if (toolCalls.length && allDone) {
|
||||
onComplete(completedTools);
|
||||
setToolCalls([]);
|
||||
setAbortController(new AbortController());
|
||||
}
|
||||
}, [toolCalls, onComplete]);
|
||||
|
||||
return [toolCalls, schedule, cancel];
|
||||
}
|
||||
|
||||
function setStatus(
|
||||
targetCallId: string,
|
||||
status: 'success',
|
||||
response: ToolCallResponseInfo,
|
||||
): (t: ToolCall[]) => ToolCall[];
|
||||
function setStatus(
|
||||
targetCallId: string,
|
||||
status: 'awaiting_approval',
|
||||
confirm: (t: ToolConfirmationOutcome) => Promise<void>,
|
||||
): (t: ToolCall[]) => ToolCall[];
|
||||
function setStatus(
|
||||
targetCallId: string,
|
||||
status: 'error',
|
||||
response: ToolCallResponseInfo,
|
||||
): (t: ToolCall[]) => ToolCall[];
|
||||
function setStatus(
|
||||
targetCallId: string,
|
||||
status: 'cancelled',
|
||||
reason: string,
|
||||
): (t: ToolCall[]) => ToolCall[];
|
||||
function setStatus(
|
||||
targetCallId: string,
|
||||
status: 'executing' | 'scheduled',
|
||||
): (t: ToolCall[]) => ToolCall[];
|
||||
function setStatus(
|
||||
targetCallId: string,
|
||||
status: Status,
|
||||
auxiliaryData?: unknown,
|
||||
): (t: ToolCall[]) => ToolCall[] {
|
||||
return function (tc: ToolCall[]): ToolCall[] {
|
||||
return tc.map((t) => {
|
||||
if (t.request.callId !== targetCallId || t.status === 'error') {
|
||||
return t;
|
||||
}
|
||||
switch (status) {
|
||||
case 'success': {
|
||||
const next: SuccessfulToolCall = {
|
||||
...t,
|
||||
status: 'success',
|
||||
response: auxiliaryData as ToolCallResponseInfo,
|
||||
};
|
||||
return next;
|
||||
}
|
||||
case 'error': {
|
||||
const next: ErroredToolCall = {
|
||||
...t,
|
||||
status: 'error',
|
||||
response: auxiliaryData as ToolCallResponseInfo,
|
||||
};
|
||||
return next;
|
||||
}
|
||||
case 'awaiting_approval': {
|
||||
const next: WaitingToolCall = {
|
||||
...t,
|
||||
status: 'awaiting_approval',
|
||||
confirm: auxiliaryData as (
|
||||
o: ToolConfirmationOutcome,
|
||||
) => Promise<void>,
|
||||
};
|
||||
return next;
|
||||
}
|
||||
case 'scheduled': {
|
||||
const next: ScheduledToolCall = {
|
||||
...t,
|
||||
status: 'scheduled',
|
||||
};
|
||||
return next;
|
||||
}
|
||||
case 'cancelled': {
|
||||
const next: CancelledToolCall = {
|
||||
...t,
|
||||
status: 'cancelled',
|
||||
response: {
|
||||
callId: t.request.callId,
|
||||
responsePart: {
|
||||
functionResponse: {
|
||||
id: t.request.callId,
|
||||
name: t.request.name,
|
||||
response: {
|
||||
error: `[Operation Cancelled] Reason: ${auxiliaryData}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
resultDisplay: undefined,
|
||||
error: undefined,
|
||||
},
|
||||
};
|
||||
return next;
|
||||
}
|
||||
case 'executing': {
|
||||
const next: ExecutingToolCall = {
|
||||
...t,
|
||||
status: 'executing',
|
||||
};
|
||||
return next;
|
||||
}
|
||||
default: {
|
||||
// ensures every case is checked for above
|
||||
const exhaustiveCheck: never = status;
|
||||
return exhaustiveCheck;
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
const toolErrorResponse = (
|
||||
request: ToolCallRequestInfo,
|
||||
error: Error,
|
||||
): ToolCallResponseInfo => ({
|
||||
callId: request.callId,
|
||||
error,
|
||||
responsePart: {
|
||||
functionResponse: {
|
||||
id: request.callId,
|
||||
name: request.name,
|
||||
response: { error: error.message },
|
||||
},
|
||||
},
|
||||
resultDisplay: error.message,
|
||||
});
|
||||
|
||||
function mapStatus(status: Status): ToolCallStatus {
|
||||
switch (status) {
|
||||
case 'awaiting_approval':
|
||||
return ToolCallStatus.Confirming;
|
||||
case 'executing':
|
||||
return ToolCallStatus.Executing;
|
||||
case 'success':
|
||||
return ToolCallStatus.Success;
|
||||
case 'cancelled':
|
||||
return ToolCallStatus.Canceled;
|
||||
case 'error':
|
||||
return ToolCallStatus.Error;
|
||||
case 'scheduled':
|
||||
return ToolCallStatus.Pending;
|
||||
default: {
|
||||
// ensures every case is checked for above
|
||||
const exhaustiveCheck: never = status;
|
||||
return exhaustiveCheck;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// convenient function for callers to map ToolCall back to a HistoryItem
|
||||
export function mapToDisplay(
|
||||
tool: ToolCall[] | ToolCall,
|
||||
): HistoryItemToolGroup {
|
||||
const tools = Array.isArray(tool) ? tool : [tool];
|
||||
const toolsDisplays = tools.map((t): IndividualToolCallDisplay => {
|
||||
switch (t.status) {
|
||||
case 'success':
|
||||
return {
|
||||
callId: t.request.callId,
|
||||
name: t.tool.displayName,
|
||||
description: t.tool.getDescription(t.request.args),
|
||||
resultDisplay: t.response.resultDisplay,
|
||||
status: mapStatus(t.status),
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
case 'error':
|
||||
return {
|
||||
callId: t.request.callId,
|
||||
name: t.request.name,
|
||||
description: '',
|
||||
resultDisplay: t.response.resultDisplay,
|
||||
status: mapStatus(t.status),
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
case 'cancelled':
|
||||
return {
|
||||
callId: t.request.callId,
|
||||
name: t.tool.displayName,
|
||||
description: t.tool.getDescription(t.request.args),
|
||||
resultDisplay: t.response.resultDisplay,
|
||||
status: mapStatus(t.status),
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
case 'awaiting_approval':
|
||||
return {
|
||||
callId: t.request.callId,
|
||||
name: t.tool.displayName,
|
||||
description: t.tool.getDescription(t.request.args),
|
||||
resultDisplay: undefined,
|
||||
status: mapStatus(t.status),
|
||||
confirmationDetails: {
|
||||
title: t.request.name,
|
||||
onConfirm: t.confirm,
|
||||
},
|
||||
};
|
||||
case 'executing':
|
||||
return {
|
||||
callId: t.request.callId,
|
||||
name: t.tool.displayName,
|
||||
description: t.tool.getDescription(t.request.args),
|
||||
resultDisplay: undefined,
|
||||
status: mapStatus(t.status),
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
case 'scheduled':
|
||||
return {
|
||||
callId: t.request.callId,
|
||||
name: t.tool.displayName,
|
||||
description: t.tool.getDescription(t.request.args),
|
||||
resultDisplay: undefined,
|
||||
status: mapStatus(t.status),
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
default: {
|
||||
// ensures every case is checked for above
|
||||
const exhaustiveCheck: never = t;
|
||||
return exhaustiveCheck;
|
||||
}
|
||||
}
|
||||
});
|
||||
const historyItem: HistoryItemToolGroup = {
|
||||
type: 'tool_group',
|
||||
tools: toolsDisplays,
|
||||
};
|
||||
return historyItem;
|
||||
}
|
Loading…
Reference in New Issue