diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts index c6b802fc..93e05387 100644 --- a/packages/cli/src/ui/hooks/useReactToolScheduler.ts +++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts @@ -63,7 +63,7 @@ export type TrackedToolCall = | TrackedCancelledToolCall; export function useReactToolScheduler( - onComplete: (tools: CompletedToolCall[]) => void, + onComplete: (tools: CompletedToolCall[]) => Promise, config: Config, setPendingHistoryItem: React.Dispatch< React.SetStateAction @@ -106,8 +106,8 @@ export function useReactToolScheduler( ); const allToolCallsCompleteHandler: AllToolCallsCompleteHandler = useCallback( - (completedToolCalls) => { - onComplete(completedToolCalls); + async (completedToolCalls) => { + await onComplete(completedToolCalls); }, [onComplete], ); @@ -157,7 +157,7 @@ export function useReactToolScheduler( request: ToolCallRequestInfo | ToolCallRequestInfo[], signal: AbortSignal, ) => { - scheduler.schedule(request, signal); + void scheduler.schedule(request, signal); }, [scheduler], ); diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index a65443f8..a3a25707 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -592,3 +592,195 @@ describe('CoreToolScheduler YOLO mode', () => { } }); }); + +describe('CoreToolScheduler request queueing', () => { + it('should queue a request if another is running', async () => { + let resolveFirstCall: (result: ToolResult) => void; + const firstCallPromise = new Promise((resolve) => { + resolveFirstCall = resolve; + }); + + const mockTool = new MockTool(); + mockTool.executeFn.mockImplementation(() => firstCallPromise); + const declarativeTool = mockTool; + + const toolRegistry = { + getTool: () => declarativeTool, + getToolByName: () => declarativeTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {} as any, + registerTool: () => {}, + getToolByDisplayName: () => declarativeTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + }; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts + } as unknown as Config; + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + toolRegistry: Promise.resolve(toolRegistry as any), + onAllToolCallsComplete, + onToolCallsUpdate, + getPreferredEditor: () => 'vscode', + onEditorClose: vi.fn(), + }); + + const abortController = new AbortController(); + const request1 = { + callId: '1', + name: 'mockTool', + args: { a: 1 }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + const request2 = { + callId: '2', + name: 'mockTool', + args: { b: 2 }, + isClientInitiated: false, + prompt_id: 'prompt-2', + }; + + // Schedule the first call, which will pause execution. + scheduler.schedule([request1], abortController.signal); + + // Wait for the first call to be in the 'executing' state. + await vi.waitFor(() => { + const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[]; + expect(calls?.[0]?.status).toBe('executing'); + }); + + // Schedule the second call while the first is "running". + const schedulePromise2 = scheduler.schedule( + [request2], + abortController.signal, + ); + + // Ensure the second tool call hasn't been executed yet. + expect(mockTool.executeFn).toHaveBeenCalledTimes(1); + expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 }); + + // Complete the first tool call. + resolveFirstCall!({ + llmContent: 'First call complete', + returnDisplay: 'First call complete', + }); + + // Wait for the second schedule promise to resolve. + await schedulePromise2; + + // Wait for the second call to be in the 'executing' state. + await vi.waitFor(() => { + const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[]; + expect(calls?.[0]?.status).toBe('executing'); + }); + + // Now the second tool call should have been executed. + expect(mockTool.executeFn).toHaveBeenCalledTimes(2); + expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 }); + + // Let the second call finish. + const secondCallResult = { + llmContent: 'Second call complete', + returnDisplay: 'Second call complete', + }; + // Since the mock is shared, we need to resolve the current promise. + // In a real scenario, a new promise would be created for the second call. + resolveFirstCall!(secondCallResult); + + // Wait for the second completion. + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2); + }); + + // Verify the completion callbacks were called correctly. + expect(onAllToolCallsComplete.mock.calls[0][0][0].status).toBe('success'); + expect(onAllToolCallsComplete.mock.calls[1][0][0].status).toBe('success'); + }); + + it('should handle two synchronous calls to schedule', async () => { + const mockTool = new MockTool(); + const declarativeTool = mockTool; + const toolRegistry = { + getTool: () => declarativeTool, + getToolByName: () => declarativeTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {} as any, + registerTool: () => {}, + getToolByDisplayName: () => declarativeTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + }; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => ApprovalMode.YOLO, + } as unknown as Config; + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + toolRegistry: Promise.resolve(toolRegistry as any), + onAllToolCallsComplete, + onToolCallsUpdate, + getPreferredEditor: () => 'vscode', + onEditorClose: vi.fn(), + }); + + const abortController = new AbortController(); + const request1 = { + callId: '1', + name: 'mockTool', + args: { a: 1 }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + const request2 = { + callId: '2', + name: 'mockTool', + args: { b: 2 }, + isClientInitiated: false, + prompt_id: 'prompt-2', + }; + + // Schedule two calls synchronously. + const schedulePromise1 = scheduler.schedule( + [request1], + abortController.signal, + ); + const schedulePromise2 = scheduler.schedule( + [request2], + abortController.signal, + ); + + // Wait for both promises to resolve. + await Promise.all([schedulePromise1, schedulePromise2]); + + // Ensure the tool was called twice with the correct arguments. + expect(mockTool.executeFn).toHaveBeenCalledTimes(2); + expect(mockTool.executeFn).toHaveBeenCalledWith({ a: 1 }); + expect(mockTool.executeFn).toHaveBeenCalledWith({ b: 2 }); + + // Ensure completion callbacks were called twice. + expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2); + }); +}); diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 6f098ae3..00ff5c55 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -125,7 +125,7 @@ export type OutputUpdateHandler = ( export type AllToolCallsCompleteHandler = ( completedToolCalls: CompletedToolCall[], -) => void; +) => Promise; export type ToolCallsUpdateHandler = (toolCalls: ToolCall[]) => void; @@ -244,6 +244,14 @@ export class CoreToolScheduler { private getPreferredEditor: () => EditorType | undefined; private config: Config; private onEditorClose: () => void; + private isFinalizingToolCalls = false; + private isScheduling = false; + private requestQueue: Array<{ + request: ToolCallRequestInfo | ToolCallRequestInfo[]; + signal: AbortSignal; + resolve: () => void; + reject: (reason?: Error) => void; + }> = []; constructor(options: CoreToolSchedulerOptions) { this.config = options.config; @@ -455,9 +463,12 @@ export class CoreToolScheduler { } private isRunning(): boolean { - return this.toolCalls.some( - (call) => - call.status === 'executing' || call.status === 'awaiting_approval', + return ( + this.isFinalizingToolCalls || + this.toolCalls.some( + (call) => + call.status === 'executing' || call.status === 'awaiting_approval', + ) ); } @@ -475,150 +486,191 @@ export class CoreToolScheduler { } } - async schedule( + schedule( request: ToolCallRequestInfo | ToolCallRequestInfo[], signal: AbortSignal, ): Promise { - if (this.isRunning()) { - throw new Error( - 'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).', - ); + if (this.isRunning() || this.isScheduling) { + return new Promise((resolve, reject) => { + const abortHandler = () => { + // Find and remove the request from the queue + const index = this.requestQueue.findIndex( + (item) => item.request === request, + ); + if (index > -1) { + this.requestQueue.splice(index, 1); + reject(new Error('Tool call cancelled while in queue.')); + } + }; + + signal.addEventListener('abort', abortHandler, { once: true }); + + this.requestQueue.push({ + request, + signal, + resolve: () => { + signal.removeEventListener('abort', abortHandler); + resolve(); + }, + reject: (reason?: Error) => { + signal.removeEventListener('abort', abortHandler); + reject(reason); + }, + }); + }); } - const requestsToProcess = Array.isArray(request) ? request : [request]; - const toolRegistry = await this.toolRegistry; + return this._schedule(request, signal); + } - const newToolCalls: ToolCall[] = requestsToProcess.map( - (reqInfo): ToolCall => { - const toolInstance = toolRegistry.getTool(reqInfo.name); - if (!toolInstance) { - return { - status: 'error', - request: reqInfo, - response: createErrorResponse( - reqInfo, - new Error(`Tool "${reqInfo.name}" not found in registry.`), - ToolErrorType.TOOL_NOT_REGISTERED, - ), - durationMs: 0, - }; - } - - const invocationOrError = this.buildInvocation( - toolInstance, - reqInfo.args, + private async _schedule( + request: ToolCallRequestInfo | ToolCallRequestInfo[], + signal: AbortSignal, + ): Promise { + this.isScheduling = true; + try { + if (this.isRunning()) { + throw new Error( + 'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).', ); - if (invocationOrError instanceof Error) { + } + const requestsToProcess = Array.isArray(request) ? request : [request]; + const toolRegistry = await this.toolRegistry; + + const newToolCalls: ToolCall[] = requestsToProcess.map( + (reqInfo): ToolCall => { + const toolInstance = toolRegistry.getTool(reqInfo.name); + if (!toolInstance) { + return { + status: 'error', + request: reqInfo, + response: createErrorResponse( + reqInfo, + new Error(`Tool "${reqInfo.name}" not found in registry.`), + ToolErrorType.TOOL_NOT_REGISTERED, + ), + durationMs: 0, + }; + } + + const invocationOrError = this.buildInvocation( + toolInstance, + reqInfo.args, + ); + if (invocationOrError instanceof Error) { + return { + status: 'error', + request: reqInfo, + tool: toolInstance, + response: createErrorResponse( + reqInfo, + invocationOrError, + ToolErrorType.INVALID_TOOL_PARAMS, + ), + durationMs: 0, + }; + } + return { - status: 'error', + status: 'validating', request: reqInfo, tool: toolInstance, - response: createErrorResponse( - reqInfo, - invocationOrError, - ToolErrorType.INVALID_TOOL_PARAMS, - ), - durationMs: 0, + invocation: invocationOrError, + startTime: Date.now(), }; + }, + ); + + this.toolCalls = this.toolCalls.concat(newToolCalls); + this.notifyToolCallsUpdate(); + + for (const toolCall of newToolCalls) { + if (toolCall.status !== 'validating') { + continue; } - return { - status: 'validating', - request: reqInfo, - tool: toolInstance, - invocation: invocationOrError, - startTime: Date.now(), - }; - }, - ); + const { request: reqInfo, invocation } = toolCall; - this.toolCalls = this.toolCalls.concat(newToolCalls); - this.notifyToolCallsUpdate(); - - for (const toolCall of newToolCalls) { - if (toolCall.status !== 'validating') { - continue; - } - - const { request: reqInfo, invocation } = toolCall; - - try { - if (this.config.getApprovalMode() === ApprovalMode.YOLO) { - this.setToolCallOutcome( - reqInfo.callId, - ToolConfirmationOutcome.ProceedAlways, - ); - this.setStatusInternal(reqInfo.callId, 'scheduled'); - } else { - const confirmationDetails = - await invocation.shouldConfirmExecute(signal); - - if (confirmationDetails) { - // Allow IDE to resolve confirmation - if ( - confirmationDetails.type === 'edit' && - confirmationDetails.ideConfirmation - ) { - confirmationDetails.ideConfirmation.then((resolution) => { - if (resolution.status === 'accepted') { - this.handleConfirmationResponse( - reqInfo.callId, - confirmationDetails.onConfirm, - ToolConfirmationOutcome.ProceedOnce, - signal, - ); - } else { - this.handleConfirmationResponse( - reqInfo.callId, - confirmationDetails.onConfirm, - ToolConfirmationOutcome.Cancel, - signal, - ); - } - }); - } - - const originalOnConfirm = confirmationDetails.onConfirm; - const wrappedConfirmationDetails: ToolCallConfirmationDetails = { - ...confirmationDetails, - onConfirm: ( - outcome: ToolConfirmationOutcome, - payload?: ToolConfirmationPayload, - ) => - this.handleConfirmationResponse( - reqInfo.callId, - originalOnConfirm, - outcome, - signal, - payload, - ), - }; - this.setStatusInternal( - reqInfo.callId, - 'awaiting_approval', - wrappedConfirmationDetails, - ); - } else { + try { + if (this.config.getApprovalMode() === ApprovalMode.YOLO) { this.setToolCallOutcome( reqInfo.callId, ToolConfirmationOutcome.ProceedAlways, ); this.setStatusInternal(reqInfo.callId, 'scheduled'); + } else { + const confirmationDetails = + await invocation.shouldConfirmExecute(signal); + + if (confirmationDetails) { + // Allow IDE to resolve confirmation + if ( + confirmationDetails.type === 'edit' && + confirmationDetails.ideConfirmation + ) { + confirmationDetails.ideConfirmation.then((resolution) => { + if (resolution.status === 'accepted') { + this.handleConfirmationResponse( + reqInfo.callId, + confirmationDetails.onConfirm, + ToolConfirmationOutcome.ProceedOnce, + signal, + ); + } else { + this.handleConfirmationResponse( + reqInfo.callId, + confirmationDetails.onConfirm, + ToolConfirmationOutcome.Cancel, + signal, + ); + } + }); + } + + const originalOnConfirm = confirmationDetails.onConfirm; + const wrappedConfirmationDetails: ToolCallConfirmationDetails = { + ...confirmationDetails, + onConfirm: ( + outcome: ToolConfirmationOutcome, + payload?: ToolConfirmationPayload, + ) => + this.handleConfirmationResponse( + reqInfo.callId, + originalOnConfirm, + outcome, + signal, + payload, + ), + }; + this.setStatusInternal( + reqInfo.callId, + 'awaiting_approval', + wrappedConfirmationDetails, + ); + } else { + this.setToolCallOutcome( + reqInfo.callId, + ToolConfirmationOutcome.ProceedAlways, + ); + this.setStatusInternal(reqInfo.callId, 'scheduled'); + } } + } catch (error) { + this.setStatusInternal( + reqInfo.callId, + 'error', + createErrorResponse( + reqInfo, + error instanceof Error ? error : new Error(String(error)), + ToolErrorType.UNHANDLED_EXCEPTION, + ), + ); } - } catch (error) { - this.setStatusInternal( - reqInfo.callId, - 'error', - createErrorResponse( - reqInfo, - error instanceof Error ? error : new Error(String(error)), - ToolErrorType.UNHANDLED_EXCEPTION, - ), - ); } + this.attemptExecutionOfScheduledCalls(signal); + void this.checkAndNotifyCompletion(); + } finally { + this.isScheduling = false; } - this.attemptExecutionOfScheduledCalls(signal); - this.checkAndNotifyCompletion(); } async handleConfirmationResponse( @@ -822,7 +874,7 @@ export class CoreToolScheduler { } } - private checkAndNotifyCompletion(): void { + private async checkAndNotifyCompletion(): Promise { const allCallsAreTerminal = this.toolCalls.every( (call) => call.status === 'success' || @@ -839,9 +891,18 @@ export class CoreToolScheduler { } if (this.onAllToolCallsComplete) { - this.onAllToolCallsComplete(completedCalls); + this.isFinalizingToolCalls = true; + await this.onAllToolCallsComplete(completedCalls); + this.isFinalizingToolCalls = false; } this.notifyToolCallsUpdate(); + // After completion, process the next item in the queue. + if (this.requestQueue.length > 0) { + const next = this.requestQueue.shift()!; + this._schedule(next.request, next.signal) + .then(next.resolve) + .catch(next.reject); + } } }