Add a request queue to the tool scheduler (#5845)
This commit is contained in:
parent
9ac62565a0
commit
69322e12e4
|
@ -63,7 +63,7 @@ export type TrackedToolCall =
|
|||
| TrackedCancelledToolCall;
|
||||
|
||||
export function useReactToolScheduler(
|
||||
onComplete: (tools: CompletedToolCall[]) => void,
|
||||
onComplete: (tools: CompletedToolCall[]) => Promise<void>,
|
||||
config: Config,
|
||||
setPendingHistoryItem: React.Dispatch<
|
||||
React.SetStateAction<HistoryItemWithoutId | null>
|
||||
|
@ -106,8 +106,8 @@ export function useReactToolScheduler(
|
|||
);
|
||||
|
||||
const allToolCallsCompleteHandler: AllToolCallsCompleteHandler = useCallback(
|
||||
(completedToolCalls) => {
|
||||
onComplete(completedToolCalls);
|
||||
async (completedToolCalls) => {
|
||||
await onComplete(completedToolCalls);
|
||||
},
|
||||
[onComplete],
|
||||
);
|
||||
|
@ -157,7 +157,7 @@ export function useReactToolScheduler(
|
|||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
) => {
|
||||
scheduler.schedule(request, signal);
|
||||
void scheduler.schedule(request, signal);
|
||||
},
|
||||
[scheduler],
|
||||
);
|
||||
|
|
|
@ -592,3 +592,195 @@ describe('CoreToolScheduler YOLO mode', () => {
|
|||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('CoreToolScheduler request queueing', () => {
|
||||
it('should queue a request if another is running', async () => {
|
||||
let resolveFirstCall: (result: ToolResult) => void;
|
||||
const firstCallPromise = new Promise<ToolResult>((resolve) => {
|
||||
resolveFirstCall = resolve;
|
||||
});
|
||||
|
||||
const mockTool = new MockTool();
|
||||
mockTool.executeFn.mockImplementation(() => firstCallPromise);
|
||||
const declarativeTool = mockTool;
|
||||
|
||||
const toolRegistry = {
|
||||
getTool: () => declarativeTool,
|
||||
getToolByName: () => declarativeTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {} as any,
|
||||
registerTool: () => {},
|
||||
getToolByDisplayName: () => declarativeTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
};
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
const mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
toolRegistry: Promise.resolve(toolRegistry as any),
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
const request1 = {
|
||||
callId: '1',
|
||||
name: 'mockTool',
|
||||
args: { a: 1 },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-1',
|
||||
};
|
||||
const request2 = {
|
||||
callId: '2',
|
||||
name: 'mockTool',
|
||||
args: { b: 2 },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-2',
|
||||
};
|
||||
|
||||
// Schedule the first call, which will pause execution.
|
||||
scheduler.schedule([request1], abortController.signal);
|
||||
|
||||
// Wait for the first call to be in the 'executing' state.
|
||||
await vi.waitFor(() => {
|
||||
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
|
||||
expect(calls?.[0]?.status).toBe('executing');
|
||||
});
|
||||
|
||||
// Schedule the second call while the first is "running".
|
||||
const schedulePromise2 = scheduler.schedule(
|
||||
[request2],
|
||||
abortController.signal,
|
||||
);
|
||||
|
||||
// Ensure the second tool call hasn't been executed yet.
|
||||
expect(mockTool.executeFn).toHaveBeenCalledTimes(1);
|
||||
expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 });
|
||||
|
||||
// Complete the first tool call.
|
||||
resolveFirstCall!({
|
||||
llmContent: 'First call complete',
|
||||
returnDisplay: 'First call complete',
|
||||
});
|
||||
|
||||
// Wait for the second schedule promise to resolve.
|
||||
await schedulePromise2;
|
||||
|
||||
// Wait for the second call to be in the 'executing' state.
|
||||
await vi.waitFor(() => {
|
||||
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
|
||||
expect(calls?.[0]?.status).toBe('executing');
|
||||
});
|
||||
|
||||
// Now the second tool call should have been executed.
|
||||
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
|
||||
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
|
||||
|
||||
// Let the second call finish.
|
||||
const secondCallResult = {
|
||||
llmContent: 'Second call complete',
|
||||
returnDisplay: 'Second call complete',
|
||||
};
|
||||
// Since the mock is shared, we need to resolve the current promise.
|
||||
// In a real scenario, a new promise would be created for the second call.
|
||||
resolveFirstCall!(secondCallResult);
|
||||
|
||||
// Wait for the second completion.
|
||||
await vi.waitFor(() => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
// Verify the completion callbacks were called correctly.
|
||||
expect(onAllToolCallsComplete.mock.calls[0][0][0].status).toBe('success');
|
||||
expect(onAllToolCallsComplete.mock.calls[1][0][0].status).toBe('success');
|
||||
});
|
||||
|
||||
it('should handle two synchronous calls to schedule', async () => {
|
||||
const mockTool = new MockTool();
|
||||
const declarativeTool = mockTool;
|
||||
const toolRegistry = {
|
||||
getTool: () => declarativeTool,
|
||||
getToolByName: () => declarativeTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {} as any,
|
||||
registerTool: () => {},
|
||||
getToolByDisplayName: () => declarativeTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
};
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
const mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.YOLO,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
toolRegistry: Promise.resolve(toolRegistry as any),
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
const request1 = {
|
||||
callId: '1',
|
||||
name: 'mockTool',
|
||||
args: { a: 1 },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-1',
|
||||
};
|
||||
const request2 = {
|
||||
callId: '2',
|
||||
name: 'mockTool',
|
||||
args: { b: 2 },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-2',
|
||||
};
|
||||
|
||||
// Schedule two calls synchronously.
|
||||
const schedulePromise1 = scheduler.schedule(
|
||||
[request1],
|
||||
abortController.signal,
|
||||
);
|
||||
const schedulePromise2 = scheduler.schedule(
|
||||
[request2],
|
||||
abortController.signal,
|
||||
);
|
||||
|
||||
// Wait for both promises to resolve.
|
||||
await Promise.all([schedulePromise1, schedulePromise2]);
|
||||
|
||||
// Ensure the tool was called twice with the correct arguments.
|
||||
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
|
||||
expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 });
|
||||
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
|
||||
|
||||
// Ensure completion callbacks were called twice.
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -125,7 +125,7 @@ export type OutputUpdateHandler = (
|
|||
|
||||
export type AllToolCallsCompleteHandler = (
|
||||
completedToolCalls: CompletedToolCall[],
|
||||
) => void;
|
||||
) => Promise<void>;
|
||||
|
||||
export type ToolCallsUpdateHandler = (toolCalls: ToolCall[]) => void;
|
||||
|
||||
|
@ -244,6 +244,14 @@ export class CoreToolScheduler {
|
|||
private getPreferredEditor: () => EditorType | undefined;
|
||||
private config: Config;
|
||||
private onEditorClose: () => void;
|
||||
private isFinalizingToolCalls = false;
|
||||
private isScheduling = false;
|
||||
private requestQueue: Array<{
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[];
|
||||
signal: AbortSignal;
|
||||
resolve: () => void;
|
||||
reject: (reason?: Error) => void;
|
||||
}> = [];
|
||||
|
||||
constructor(options: CoreToolSchedulerOptions) {
|
||||
this.config = options.config;
|
||||
|
@ -455,9 +463,12 @@ export class CoreToolScheduler {
|
|||
}
|
||||
|
||||
private isRunning(): boolean {
|
||||
return this.toolCalls.some(
|
||||
(call) =>
|
||||
call.status === 'executing' || call.status === 'awaiting_approval',
|
||||
return (
|
||||
this.isFinalizingToolCalls ||
|
||||
this.toolCalls.some(
|
||||
(call) =>
|
||||
call.status === 'executing' || call.status === 'awaiting_approval',
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -475,150 +486,191 @@ export class CoreToolScheduler {
|
|||
}
|
||||
}
|
||||
|
||||
async schedule(
|
||||
schedule(
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
): Promise<void> {
|
||||
if (this.isRunning()) {
|
||||
throw new Error(
|
||||
'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).',
|
||||
);
|
||||
if (this.isRunning() || this.isScheduling) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const abortHandler = () => {
|
||||
// Find and remove the request from the queue
|
||||
const index = this.requestQueue.findIndex(
|
||||
(item) => item.request === request,
|
||||
);
|
||||
if (index > -1) {
|
||||
this.requestQueue.splice(index, 1);
|
||||
reject(new Error('Tool call cancelled while in queue.'));
|
||||
}
|
||||
};
|
||||
|
||||
signal.addEventListener('abort', abortHandler, { once: true });
|
||||
|
||||
this.requestQueue.push({
|
||||
request,
|
||||
signal,
|
||||
resolve: () => {
|
||||
signal.removeEventListener('abort', abortHandler);
|
||||
resolve();
|
||||
},
|
||||
reject: (reason?: Error) => {
|
||||
signal.removeEventListener('abort', abortHandler);
|
||||
reject(reason);
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
||||
const requestsToProcess = Array.isArray(request) ? request : [request];
|
||||
const toolRegistry = await this.toolRegistry;
|
||||
return this._schedule(request, signal);
|
||||
}
|
||||
|
||||
const newToolCalls: ToolCall[] = requestsToProcess.map(
|
||||
(reqInfo): ToolCall => {
|
||||
const toolInstance = toolRegistry.getTool(reqInfo.name);
|
||||
if (!toolInstance) {
|
||||
return {
|
||||
status: 'error',
|
||||
request: reqInfo,
|
||||
response: createErrorResponse(
|
||||
reqInfo,
|
||||
new Error(`Tool "${reqInfo.name}" not found in registry.`),
|
||||
ToolErrorType.TOOL_NOT_REGISTERED,
|
||||
),
|
||||
durationMs: 0,
|
||||
};
|
||||
}
|
||||
|
||||
const invocationOrError = this.buildInvocation(
|
||||
toolInstance,
|
||||
reqInfo.args,
|
||||
private async _schedule(
|
||||
request: ToolCallRequestInfo | ToolCallRequestInfo[],
|
||||
signal: AbortSignal,
|
||||
): Promise<void> {
|
||||
this.isScheduling = true;
|
||||
try {
|
||||
if (this.isRunning()) {
|
||||
throw new Error(
|
||||
'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).',
|
||||
);
|
||||
if (invocationOrError instanceof Error) {
|
||||
}
|
||||
const requestsToProcess = Array.isArray(request) ? request : [request];
|
||||
const toolRegistry = await this.toolRegistry;
|
||||
|
||||
const newToolCalls: ToolCall[] = requestsToProcess.map(
|
||||
(reqInfo): ToolCall => {
|
||||
const toolInstance = toolRegistry.getTool(reqInfo.name);
|
||||
if (!toolInstance) {
|
||||
return {
|
||||
status: 'error',
|
||||
request: reqInfo,
|
||||
response: createErrorResponse(
|
||||
reqInfo,
|
||||
new Error(`Tool "${reqInfo.name}" not found in registry.`),
|
||||
ToolErrorType.TOOL_NOT_REGISTERED,
|
||||
),
|
||||
durationMs: 0,
|
||||
};
|
||||
}
|
||||
|
||||
const invocationOrError = this.buildInvocation(
|
||||
toolInstance,
|
||||
reqInfo.args,
|
||||
);
|
||||
if (invocationOrError instanceof Error) {
|
||||
return {
|
||||
status: 'error',
|
||||
request: reqInfo,
|
||||
tool: toolInstance,
|
||||
response: createErrorResponse(
|
||||
reqInfo,
|
||||
invocationOrError,
|
||||
ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
),
|
||||
durationMs: 0,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
status: 'error',
|
||||
status: 'validating',
|
||||
request: reqInfo,
|
||||
tool: toolInstance,
|
||||
response: createErrorResponse(
|
||||
reqInfo,
|
||||
invocationOrError,
|
||||
ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
),
|
||||
durationMs: 0,
|
||||
invocation: invocationOrError,
|
||||
startTime: Date.now(),
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
this.toolCalls = this.toolCalls.concat(newToolCalls);
|
||||
this.notifyToolCallsUpdate();
|
||||
|
||||
for (const toolCall of newToolCalls) {
|
||||
if (toolCall.status !== 'validating') {
|
||||
continue;
|
||||
}
|
||||
|
||||
return {
|
||||
status: 'validating',
|
||||
request: reqInfo,
|
||||
tool: toolInstance,
|
||||
invocation: invocationOrError,
|
||||
startTime: Date.now(),
|
||||
};
|
||||
},
|
||||
);
|
||||
const { request: reqInfo, invocation } = toolCall;
|
||||
|
||||
this.toolCalls = this.toolCalls.concat(newToolCalls);
|
||||
this.notifyToolCallsUpdate();
|
||||
|
||||
for (const toolCall of newToolCalls) {
|
||||
if (toolCall.status !== 'validating') {
|
||||
continue;
|
||||
}
|
||||
|
||||
const { request: reqInfo, invocation } = toolCall;
|
||||
|
||||
try {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
} else {
|
||||
const confirmationDetails =
|
||||
await invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (confirmationDetails) {
|
||||
// Allow IDE to resolve confirmation
|
||||
if (
|
||||
confirmationDetails.type === 'edit' &&
|
||||
confirmationDetails.ideConfirmation
|
||||
) {
|
||||
confirmationDetails.ideConfirmation.then((resolution) => {
|
||||
if (resolution.status === 'accepted') {
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
confirmationDetails.onConfirm,
|
||||
ToolConfirmationOutcome.ProceedOnce,
|
||||
signal,
|
||||
);
|
||||
} else {
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
confirmationDetails.onConfirm,
|
||||
ToolConfirmationOutcome.Cancel,
|
||||
signal,
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const originalOnConfirm = confirmationDetails.onConfirm;
|
||||
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
|
||||
...confirmationDetails,
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
payload?: ToolConfirmationPayload,
|
||||
) =>
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
originalOnConfirm,
|
||||
outcome,
|
||||
signal,
|
||||
payload,
|
||||
),
|
||||
};
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'awaiting_approval',
|
||||
wrappedConfirmationDetails,
|
||||
);
|
||||
} else {
|
||||
try {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
} else {
|
||||
const confirmationDetails =
|
||||
await invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (confirmationDetails) {
|
||||
// Allow IDE to resolve confirmation
|
||||
if (
|
||||
confirmationDetails.type === 'edit' &&
|
||||
confirmationDetails.ideConfirmation
|
||||
) {
|
||||
confirmationDetails.ideConfirmation.then((resolution) => {
|
||||
if (resolution.status === 'accepted') {
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
confirmationDetails.onConfirm,
|
||||
ToolConfirmationOutcome.ProceedOnce,
|
||||
signal,
|
||||
);
|
||||
} else {
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
confirmationDetails.onConfirm,
|
||||
ToolConfirmationOutcome.Cancel,
|
||||
signal,
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const originalOnConfirm = confirmationDetails.onConfirm;
|
||||
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
|
||||
...confirmationDetails,
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
payload?: ToolConfirmationPayload,
|
||||
) =>
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
originalOnConfirm,
|
||||
outcome,
|
||||
signal,
|
||||
payload,
|
||||
),
|
||||
};
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'awaiting_approval',
|
||||
wrappedConfirmationDetails,
|
||||
);
|
||||
} else {
|
||||
this.setToolCallOutcome(
|
||||
reqInfo.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(reqInfo.callId, 'scheduled');
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'error',
|
||||
createErrorResponse(
|
||||
reqInfo,
|
||||
error instanceof Error ? error : new Error(String(error)),
|
||||
ToolErrorType.UNHANDLED_EXCEPTION,
|
||||
),
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'error',
|
||||
createErrorResponse(
|
||||
reqInfo,
|
||||
error instanceof Error ? error : new Error(String(error)),
|
||||
ToolErrorType.UNHANDLED_EXCEPTION,
|
||||
),
|
||||
);
|
||||
}
|
||||
this.attemptExecutionOfScheduledCalls(signal);
|
||||
void this.checkAndNotifyCompletion();
|
||||
} finally {
|
||||
this.isScheduling = false;
|
||||
}
|
||||
this.attemptExecutionOfScheduledCalls(signal);
|
||||
this.checkAndNotifyCompletion();
|
||||
}
|
||||
|
||||
async handleConfirmationResponse(
|
||||
|
@ -822,7 +874,7 @@ export class CoreToolScheduler {
|
|||
}
|
||||
}
|
||||
|
||||
private checkAndNotifyCompletion(): void {
|
||||
private async checkAndNotifyCompletion(): Promise<void> {
|
||||
const allCallsAreTerminal = this.toolCalls.every(
|
||||
(call) =>
|
||||
call.status === 'success' ||
|
||||
|
@ -839,9 +891,18 @@ export class CoreToolScheduler {
|
|||
}
|
||||
|
||||
if (this.onAllToolCallsComplete) {
|
||||
this.onAllToolCallsComplete(completedCalls);
|
||||
this.isFinalizingToolCalls = true;
|
||||
await this.onAllToolCallsComplete(completedCalls);
|
||||
this.isFinalizingToolCalls = false;
|
||||
}
|
||||
this.notifyToolCallsUpdate();
|
||||
// After completion, process the next item in the queue.
|
||||
if (this.requestQueue.length > 0) {
|
||||
const next = this.requestQueue.shift()!;
|
||||
this._schedule(next.request, next.signal)
|
||||
.then(next.resolve)
|
||||
.catch(next.reject);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue