Add a request queue to the tool scheduler (#5845)

This commit is contained in:
Jacob MacDonald 2025-08-08 14:50:35 -07:00 committed by GitHub
parent 9ac62565a0
commit 69322e12e4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 385 additions and 132 deletions

View File

@ -63,7 +63,7 @@ export type TrackedToolCall =
| TrackedCancelledToolCall; | TrackedCancelledToolCall;
export function useReactToolScheduler( export function useReactToolScheduler(
onComplete: (tools: CompletedToolCall[]) => void, onComplete: (tools: CompletedToolCall[]) => Promise<void>,
config: Config, config: Config,
setPendingHistoryItem: React.Dispatch< setPendingHistoryItem: React.Dispatch<
React.SetStateAction<HistoryItemWithoutId | null> React.SetStateAction<HistoryItemWithoutId | null>
@ -106,8 +106,8 @@ export function useReactToolScheduler(
); );
const allToolCallsCompleteHandler: AllToolCallsCompleteHandler = useCallback( const allToolCallsCompleteHandler: AllToolCallsCompleteHandler = useCallback(
(completedToolCalls) => { async (completedToolCalls) => {
onComplete(completedToolCalls); await onComplete(completedToolCalls);
}, },
[onComplete], [onComplete],
); );
@ -157,7 +157,7 @@ export function useReactToolScheduler(
request: ToolCallRequestInfo | ToolCallRequestInfo[], request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal, signal: AbortSignal,
) => { ) => {
scheduler.schedule(request, signal); void scheduler.schedule(request, signal);
}, },
[scheduler], [scheduler],
); );

View File

@ -592,3 +592,195 @@ describe('CoreToolScheduler YOLO mode', () => {
} }
}); });
}); });
describe('CoreToolScheduler request queueing', () => {
it('should queue a request if another is running', async () => {
let resolveFirstCall: (result: ToolResult) => void;
const firstCallPromise = new Promise<ToolResult>((resolve) => {
resolveFirstCall = resolve;
});
const mockTool = new MockTool();
mockTool.executeFn.mockImplementation(() => firstCallPromise);
const declarativeTool = mockTool;
const toolRegistry = {
getTool: () => declarativeTool,
getToolByName: () => declarativeTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {} as any,
registerTool: () => {},
getToolByDisplayName: () => declarativeTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
};
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
const mockConfig = {
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: Promise.resolve(toolRegistry as any),
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();
const request1 = {
callId: '1',
name: 'mockTool',
args: { a: 1 },
isClientInitiated: false,
prompt_id: 'prompt-1',
};
const request2 = {
callId: '2',
name: 'mockTool',
args: { b: 2 },
isClientInitiated: false,
prompt_id: 'prompt-2',
};
// Schedule the first call, which will pause execution.
scheduler.schedule([request1], abortController.signal);
// Wait for the first call to be in the 'executing' state.
await vi.waitFor(() => {
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
expect(calls?.[0]?.status).toBe('executing');
});
// Schedule the second call while the first is "running".
const schedulePromise2 = scheduler.schedule(
[request2],
abortController.signal,
);
// Ensure the second tool call hasn't been executed yet.
expect(mockTool.executeFn).toHaveBeenCalledTimes(1);
expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 });
// Complete the first tool call.
resolveFirstCall!({
llmContent: 'First call complete',
returnDisplay: 'First call complete',
});
// Wait for the second schedule promise to resolve.
await schedulePromise2;
// Wait for the second call to be in the 'executing' state.
await vi.waitFor(() => {
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
expect(calls?.[0]?.status).toBe('executing');
});
// Now the second tool call should have been executed.
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
// Let the second call finish.
const secondCallResult = {
llmContent: 'Second call complete',
returnDisplay: 'Second call complete',
};
// Since the mock is shared, we need to resolve the current promise.
// In a real scenario, a new promise would be created for the second call.
resolveFirstCall!(secondCallResult);
// Wait for the second completion.
await vi.waitFor(() => {
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
});
// Verify the completion callbacks were called correctly.
expect(onAllToolCallsComplete.mock.calls[0][0][0].status).toBe('success');
expect(onAllToolCallsComplete.mock.calls[1][0][0].status).toBe('success');
});
it('should handle two synchronous calls to schedule', async () => {
const mockTool = new MockTool();
const declarativeTool = mockTool;
const toolRegistry = {
getTool: () => declarativeTool,
getToolByName: () => declarativeTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {} as any,
registerTool: () => {},
getToolByDisplayName: () => declarativeTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
};
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
const mockConfig = {
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => ApprovalMode.YOLO,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: Promise.resolve(toolRegistry as any),
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();
const request1 = {
callId: '1',
name: 'mockTool',
args: { a: 1 },
isClientInitiated: false,
prompt_id: 'prompt-1',
};
const request2 = {
callId: '2',
name: 'mockTool',
args: { b: 2 },
isClientInitiated: false,
prompt_id: 'prompt-2',
};
// Schedule two calls synchronously.
const schedulePromise1 = scheduler.schedule(
[request1],
abortController.signal,
);
const schedulePromise2 = scheduler.schedule(
[request2],
abortController.signal,
);
// Wait for both promises to resolve.
await Promise.all([schedulePromise1, schedulePromise2]);
// Ensure the tool was called twice with the correct arguments.
expect(mockTool.executeFn).toHaveBeenCalledTimes(2);
expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 });
expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 });
// Ensure completion callbacks were called twice.
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
});
});

View File

@ -125,7 +125,7 @@ export type OutputUpdateHandler = (
export type AllToolCallsCompleteHandler = ( export type AllToolCallsCompleteHandler = (
completedToolCalls: CompletedToolCall[], completedToolCalls: CompletedToolCall[],
) => void; ) => Promise<void>;
export type ToolCallsUpdateHandler = (toolCalls: ToolCall[]) => void; export type ToolCallsUpdateHandler = (toolCalls: ToolCall[]) => void;
@ -244,6 +244,14 @@ export class CoreToolScheduler {
private getPreferredEditor: () => EditorType | undefined; private getPreferredEditor: () => EditorType | undefined;
private config: Config; private config: Config;
private onEditorClose: () => void; private onEditorClose: () => void;
private isFinalizingToolCalls = false;
private isScheduling = false;
private requestQueue: Array<{
request: ToolCallRequestInfo | ToolCallRequestInfo[];
signal: AbortSignal;
resolve: () => void;
reject: (reason?: Error) => void;
}> = [];
constructor(options: CoreToolSchedulerOptions) { constructor(options: CoreToolSchedulerOptions) {
this.config = options.config; this.config = options.config;
@ -455,9 +463,12 @@ export class CoreToolScheduler {
} }
private isRunning(): boolean { private isRunning(): boolean {
return this.toolCalls.some( return (
(call) => this.isFinalizingToolCalls ||
call.status === 'executing' || call.status === 'awaiting_approval', this.toolCalls.some(
(call) =>
call.status === 'executing' || call.status === 'awaiting_approval',
)
); );
} }
@ -475,150 +486,191 @@ export class CoreToolScheduler {
} }
} }
async schedule( schedule(
request: ToolCallRequestInfo | ToolCallRequestInfo[], request: ToolCallRequestInfo | ToolCallRequestInfo[],
signal: AbortSignal, signal: AbortSignal,
): Promise<void> { ): Promise<void> {
if (this.isRunning()) { if (this.isRunning() || this.isScheduling) {
throw new Error( return new Promise((resolve, reject) => {
'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).', const abortHandler = () => {
); // Find and remove the request from the queue
const index = this.requestQueue.findIndex(
(item) => item.request === request,
);
if (index > -1) {
this.requestQueue.splice(index, 1);
reject(new Error('Tool call cancelled while in queue.'));
}
};
signal.addEventListener('abort', abortHandler, { once: true });
this.requestQueue.push({
request,
signal,
resolve: () => {
signal.removeEventListener('abort', abortHandler);
resolve();
},
reject: (reason?: Error) => {
signal.removeEventListener('abort', abortHandler);
reject(reason);
},
});
});
} }
const requestsToProcess = Array.isArray(request) ? request : [request]; return this._schedule(request, signal);
const toolRegistry = await this.toolRegistry; }
const newToolCalls: ToolCall[] = requestsToProcess.map( private async _schedule(
(reqInfo): ToolCall => { request: ToolCallRequestInfo | ToolCallRequestInfo[],
const toolInstance = toolRegistry.getTool(reqInfo.name); signal: AbortSignal,
if (!toolInstance) { ): Promise<void> {
return { this.isScheduling = true;
status: 'error', try {
request: reqInfo, if (this.isRunning()) {
response: createErrorResponse( throw new Error(
reqInfo, 'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).',
new Error(`Tool "${reqInfo.name}" not found in registry.`),
ToolErrorType.TOOL_NOT_REGISTERED,
),
durationMs: 0,
};
}
const invocationOrError = this.buildInvocation(
toolInstance,
reqInfo.args,
); );
if (invocationOrError instanceof Error) { }
const requestsToProcess = Array.isArray(request) ? request : [request];
const toolRegistry = await this.toolRegistry;
const newToolCalls: ToolCall[] = requestsToProcess.map(
(reqInfo): ToolCall => {
const toolInstance = toolRegistry.getTool(reqInfo.name);
if (!toolInstance) {
return {
status: 'error',
request: reqInfo,
response: createErrorResponse(
reqInfo,
new Error(`Tool "${reqInfo.name}" not found in registry.`),
ToolErrorType.TOOL_NOT_REGISTERED,
),
durationMs: 0,
};
}
const invocationOrError = this.buildInvocation(
toolInstance,
reqInfo.args,
);
if (invocationOrError instanceof Error) {
return {
status: 'error',
request: reqInfo,
tool: toolInstance,
response: createErrorResponse(
reqInfo,
invocationOrError,
ToolErrorType.INVALID_TOOL_PARAMS,
),
durationMs: 0,
};
}
return { return {
status: 'error', status: 'validating',
request: reqInfo, request: reqInfo,
tool: toolInstance, tool: toolInstance,
response: createErrorResponse( invocation: invocationOrError,
reqInfo, startTime: Date.now(),
invocationOrError,
ToolErrorType.INVALID_TOOL_PARAMS,
),
durationMs: 0,
}; };
},
);
this.toolCalls = this.toolCalls.concat(newToolCalls);
this.notifyToolCallsUpdate();
for (const toolCall of newToolCalls) {
if (toolCall.status !== 'validating') {
continue;
} }
return { const { request: reqInfo, invocation } = toolCall;
status: 'validating',
request: reqInfo,
tool: toolInstance,
invocation: invocationOrError,
startTime: Date.now(),
};
},
);
this.toolCalls = this.toolCalls.concat(newToolCalls); try {
this.notifyToolCallsUpdate(); if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
for (const toolCall of newToolCalls) {
if (toolCall.status !== 'validating') {
continue;
}
const { request: reqInfo, invocation } = toolCall;
try {
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
} else {
const confirmationDetails =
await invocation.shouldConfirmExecute(signal);
if (confirmationDetails) {
// Allow IDE to resolve confirmation
if (
confirmationDetails.type === 'edit' &&
confirmationDetails.ideConfirmation
) {
confirmationDetails.ideConfirmation.then((resolution) => {
if (resolution.status === 'accepted') {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.ProceedOnce,
signal,
);
} else {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.Cancel,
signal,
);
}
});
}
const originalOnConfirm = confirmationDetails.onConfirm;
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
...confirmationDetails,
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) =>
this.handleConfirmationResponse(
reqInfo.callId,
originalOnConfirm,
outcome,
signal,
payload,
),
};
this.setStatusInternal(
reqInfo.callId,
'awaiting_approval',
wrappedConfirmationDetails,
);
} else {
this.setToolCallOutcome( this.setToolCallOutcome(
reqInfo.callId, reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways, ToolConfirmationOutcome.ProceedAlways,
); );
this.setStatusInternal(reqInfo.callId, 'scheduled'); this.setStatusInternal(reqInfo.callId, 'scheduled');
} else {
const confirmationDetails =
await invocation.shouldConfirmExecute(signal);
if (confirmationDetails) {
// Allow IDE to resolve confirmation
if (
confirmationDetails.type === 'edit' &&
confirmationDetails.ideConfirmation
) {
confirmationDetails.ideConfirmation.then((resolution) => {
if (resolution.status === 'accepted') {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.ProceedOnce,
signal,
);
} else {
this.handleConfirmationResponse(
reqInfo.callId,
confirmationDetails.onConfirm,
ToolConfirmationOutcome.Cancel,
signal,
);
}
});
}
const originalOnConfirm = confirmationDetails.onConfirm;
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
...confirmationDetails,
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) =>
this.handleConfirmationResponse(
reqInfo.callId,
originalOnConfirm,
outcome,
signal,
payload,
),
};
this.setStatusInternal(
reqInfo.callId,
'awaiting_approval',
wrappedConfirmationDetails,
);
} else {
this.setToolCallOutcome(
reqInfo.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(reqInfo.callId, 'scheduled');
}
} }
} catch (error) {
this.setStatusInternal(
reqInfo.callId,
'error',
createErrorResponse(
reqInfo,
error instanceof Error ? error : new Error(String(error)),
ToolErrorType.UNHANDLED_EXCEPTION,
),
);
} }
} catch (error) {
this.setStatusInternal(
reqInfo.callId,
'error',
createErrorResponse(
reqInfo,
error instanceof Error ? error : new Error(String(error)),
ToolErrorType.UNHANDLED_EXCEPTION,
),
);
} }
this.attemptExecutionOfScheduledCalls(signal);
void this.checkAndNotifyCompletion();
} finally {
this.isScheduling = false;
} }
this.attemptExecutionOfScheduledCalls(signal);
this.checkAndNotifyCompletion();
} }
async handleConfirmationResponse( async handleConfirmationResponse(
@ -822,7 +874,7 @@ export class CoreToolScheduler {
} }
} }
private checkAndNotifyCompletion(): void { private async checkAndNotifyCompletion(): Promise<void> {
const allCallsAreTerminal = this.toolCalls.every( const allCallsAreTerminal = this.toolCalls.every(
(call) => (call) =>
call.status === 'success' || call.status === 'success' ||
@ -839,9 +891,18 @@ export class CoreToolScheduler {
} }
if (this.onAllToolCallsComplete) { if (this.onAllToolCallsComplete) {
this.onAllToolCallsComplete(completedCalls); this.isFinalizingToolCalls = true;
await this.onAllToolCallsComplete(completedCalls);
this.isFinalizingToolCalls = false;
} }
this.notifyToolCallsUpdate(); this.notifyToolCallsUpdate();
// After completion, process the next item in the queue.
if (this.requestQueue.length > 0) {
const next = this.requestQueue.shift()!;
this._schedule(next.request, next.signal)
.then(next.resolve)
.catch(next.reject);
}
} }
} }