feat: create tool scheduler hook (#468)

This commit is contained in:
Brandon Keiji 2025-05-21 17:35:40 +00:00 committed by GitHub
parent 2ad666a484
commit e1a64b41e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 464 additions and 0 deletions

View File

@ -0,0 +1,464 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
Config,
ToolCallRequestInfo,
ToolCallResponseInfo,
ToolConfirmationOutcome,
Tool,
} from '@gemini-code/server';
import { Part } from '@google/genai';
import { useCallback, useEffect, useState } from 'react';
import {
HistoryItemToolGroup,
IndividualToolCallDisplay,
ToolCallStatus,
} from '../types.js';
type ScheduledToolCall = {
status: 'scheduled';
request: ToolCallRequestInfo;
tool: Tool;
};
type ErroredToolCall = {
status: 'error';
request: ToolCallRequestInfo;
response: ToolCallResponseInfo;
};
type SuccessfulToolCall = {
status: 'success';
request: ToolCallRequestInfo;
tool: Tool;
response: ToolCallResponseInfo;
};
type ExecutingToolCall = {
status: 'executing';
request: ToolCallRequestInfo;
tool: Tool;
};
type CancelledToolCall = {
status: 'cancelled';
request: ToolCallRequestInfo;
response: ToolCallResponseInfo;
tool: Tool;
};
type WaitingToolCall = {
status: 'awaiting_approval';
request: ToolCallRequestInfo;
tool: Tool;
confirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
};
export type Status = ToolCall['status'];
export type ToolCall =
| ScheduledToolCall
| ErroredToolCall
| SuccessfulToolCall
| ExecutingToolCall
| CancelledToolCall
| WaitingToolCall;
export type ScheduleFn = (
request: ToolCallRequestInfo | ToolCallRequestInfo[],
) => void;
export type CancelFn = () => void;
export type CompletedToolCall =
| SuccessfulToolCall
| CancelledToolCall
| ErroredToolCall;
export function useToolScheduler(
onComplete: (tools: CompletedToolCall[]) => void,
config: Config,
): [ToolCall[], ScheduleFn, CancelFn] {
const [toolRegistry] = useState(() => config.getToolRegistry());
const [toolCalls, setToolCalls] = useState<ToolCall[]>([]);
const [abortController, setAbortController] = useState<AbortController>(
() => new AbortController(),
);
const isRunning = toolCalls.some(
(t) => t.status === 'executing' || t.status === 'awaiting_approval',
);
// Note: request array[] typically signal pending tool calls
const schedule = useCallback(
async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => {
if (isRunning) {
throw new Error(
'Cannot schedule tool calls while other tool calls are running',
);
}
const requests = Array.isArray(request) ? request : [request];
const newCalls: ToolCall[] = await Promise.all(
requests.map(async (r): Promise<ToolCall> => {
const tool = toolRegistry.getTool(r.name);
if (!tool) {
return {
status: 'error',
request: r,
response: toolErrorResponse(
r,
new Error(`tool ${r.name} does not exist`),
),
};
}
const userApproval = await tool.shouldConfirmExecute(r.args);
if (userApproval) {
return {
status: 'awaiting_approval',
request: r,
tool,
confirm: async (outcome) => {
await userApproval.onConfirm(outcome);
setToolCalls(
outcome === ToolConfirmationOutcome.Cancel
? setStatus(
r.callId,
'cancelled',
'User did not allow tool call',
)
: setStatus(r.callId, 'scheduled'),
);
},
};
}
return {
status: 'scheduled',
request: r,
tool,
};
}),
);
setToolCalls((t) => t.concat(newCalls));
},
[isRunning, setToolCalls, toolRegistry],
);
const cancel = useCallback(
(reason: string = 'unspecified') => {
abortController.abort();
setAbortController(new AbortController());
setToolCalls((tc) =>
tc.map((c) =>
c.status !== 'error'
? {
...c,
status: 'cancelled',
response: {
callId: c.request.callId,
responsePart: {
functionResponse: {
id: c.request.callId,
name: c.request.name,
response: {
error: `[Operation Cancelled] Reason: ${reason}`,
},
},
},
resultDisplay: undefined,
error: undefined,
},
}
: c,
),
);
},
[abortController],
);
useEffect(() => {
// effect for executing scheduled tool calls
const scheduledCalls = toolCalls.filter((t) => t.status === 'scheduled');
const awaitingConfirmation = toolCalls.some(
(t) => t.status === 'awaiting_approval',
);
if (!awaitingConfirmation && scheduledCalls.length) {
scheduledCalls.forEach(async (c) => {
const callId = c.request.callId;
try {
setToolCalls(setStatus(c.request.callId, 'executing'));
const result = await c.tool.execute(
c.request.args,
abortController.signal,
);
const functionResponse: Part = {
functionResponse: {
name: c.request.name,
id: callId,
response: { output: result.llmContent },
},
};
const response: ToolCallResponseInfo = {
callId,
responsePart: functionResponse,
resultDisplay: result.returnDisplay,
error: undefined,
};
setToolCalls(setStatus(callId, 'success', response));
} catch (e: unknown) {
setToolCalls(
setStatus(
callId,
'error',
toolErrorResponse(
c.request,
e instanceof Error ? e : new Error(String(e)),
),
),
);
}
});
}
}, [toolCalls, toolRegistry, abortController.signal]);
useEffect(() => {
const completedTools = toolCalls.filter(
(t) =>
t.status === 'success' ||
t.status === 'error' ||
t.status === 'cancelled',
);
const allDone = completedTools.length === toolCalls.length;
if (toolCalls.length && allDone) {
onComplete(completedTools);
setToolCalls([]);
setAbortController(new AbortController());
}
}, [toolCalls, onComplete]);
return [toolCalls, schedule, cancel];
}
function setStatus(
targetCallId: string,
status: 'success',
response: ToolCallResponseInfo,
): (t: ToolCall[]) => ToolCall[];
function setStatus(
targetCallId: string,
status: 'awaiting_approval',
confirm: (t: ToolConfirmationOutcome) => Promise<void>,
): (t: ToolCall[]) => ToolCall[];
function setStatus(
targetCallId: string,
status: 'error',
response: ToolCallResponseInfo,
): (t: ToolCall[]) => ToolCall[];
function setStatus(
targetCallId: string,
status: 'cancelled',
reason: string,
): (t: ToolCall[]) => ToolCall[];
function setStatus(
targetCallId: string,
status: 'executing' | 'scheduled',
): (t: ToolCall[]) => ToolCall[];
function setStatus(
targetCallId: string,
status: Status,
auxiliaryData?: unknown,
): (t: ToolCall[]) => ToolCall[] {
return function (tc: ToolCall[]): ToolCall[] {
return tc.map((t) => {
if (t.request.callId !== targetCallId || t.status === 'error') {
return t;
}
switch (status) {
case 'success': {
const next: SuccessfulToolCall = {
...t,
status: 'success',
response: auxiliaryData as ToolCallResponseInfo,
};
return next;
}
case 'error': {
const next: ErroredToolCall = {
...t,
status: 'error',
response: auxiliaryData as ToolCallResponseInfo,
};
return next;
}
case 'awaiting_approval': {
const next: WaitingToolCall = {
...t,
status: 'awaiting_approval',
confirm: auxiliaryData as (
o: ToolConfirmationOutcome,
) => Promise<void>,
};
return next;
}
case 'scheduled': {
const next: ScheduledToolCall = {
...t,
status: 'scheduled',
};
return next;
}
case 'cancelled': {
const next: CancelledToolCall = {
...t,
status: 'cancelled',
response: {
callId: t.request.callId,
responsePart: {
functionResponse: {
id: t.request.callId,
name: t.request.name,
response: {
error: `[Operation Cancelled] Reason: ${auxiliaryData}`,
},
},
},
resultDisplay: undefined,
error: undefined,
},
};
return next;
}
case 'executing': {
const next: ExecutingToolCall = {
...t,
status: 'executing',
};
return next;
}
default: {
// ensures every case is checked for above
const exhaustiveCheck: never = status;
return exhaustiveCheck;
}
}
});
};
}
const toolErrorResponse = (
request: ToolCallRequestInfo,
error: Error,
): ToolCallResponseInfo => ({
callId: request.callId,
error,
responsePart: {
functionResponse: {
id: request.callId,
name: request.name,
response: { error: error.message },
},
},
resultDisplay: error.message,
});
function mapStatus(status: Status): ToolCallStatus {
switch (status) {
case 'awaiting_approval':
return ToolCallStatus.Confirming;
case 'executing':
return ToolCallStatus.Executing;
case 'success':
return ToolCallStatus.Success;
case 'cancelled':
return ToolCallStatus.Canceled;
case 'error':
return ToolCallStatus.Error;
case 'scheduled':
return ToolCallStatus.Pending;
default: {
// ensures every case is checked for above
const exhaustiveCheck: never = status;
return exhaustiveCheck;
}
}
}
// convenient function for callers to map ToolCall back to a HistoryItem
export function mapToDisplay(
tool: ToolCall[] | ToolCall,
): HistoryItemToolGroup {
const tools = Array.isArray(tool) ? tool : [tool];
const toolsDisplays = tools.map((t): IndividualToolCallDisplay => {
switch (t.status) {
case 'success':
return {
callId: t.request.callId,
name: t.tool.displayName,
description: t.tool.getDescription(t.request.args),
resultDisplay: t.response.resultDisplay,
status: mapStatus(t.status),
confirmationDetails: undefined,
};
case 'error':
return {
callId: t.request.callId,
name: t.request.name,
description: '',
resultDisplay: t.response.resultDisplay,
status: mapStatus(t.status),
confirmationDetails: undefined,
};
case 'cancelled':
return {
callId: t.request.callId,
name: t.tool.displayName,
description: t.tool.getDescription(t.request.args),
resultDisplay: t.response.resultDisplay,
status: mapStatus(t.status),
confirmationDetails: undefined,
};
case 'awaiting_approval':
return {
callId: t.request.callId,
name: t.tool.displayName,
description: t.tool.getDescription(t.request.args),
resultDisplay: undefined,
status: mapStatus(t.status),
confirmationDetails: {
title: t.request.name,
onConfirm: t.confirm,
},
};
case 'executing':
return {
callId: t.request.callId,
name: t.tool.displayName,
description: t.tool.getDescription(t.request.args),
resultDisplay: undefined,
status: mapStatus(t.status),
confirmationDetails: undefined,
};
case 'scheduled':
return {
callId: t.request.callId,
name: t.tool.displayName,
description: t.tool.getDescription(t.request.args),
resultDisplay: undefined,
status: mapStatus(t.status),
confirmationDetails: undefined,
};
default: {
// ensures every case is checked for above
const exhaustiveCheck: never = t;
return exhaustiveCheck;
}
}
});
const historyItem: HistoryItemToolGroup = {
type: 'tool_group',
tools: toolsDisplays,
};
return historyItem;
}