feat: auto-approve compatible pending tools when allow always is selected (#6519)

This commit is contained in:
Arya Gummadi 2025-08-19 18:22:41 -07:00 committed by GitHub
parent d587c6f104
commit 2a71c10b8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 234 additions and 0 deletions

View File

@ -8,6 +8,7 @@ import { describe, it, expect, vi } from 'vitest';
import {
CoreToolScheduler,
ToolCall,
WaitingToolCall,
convertToFunctionResponse,
} from './coreToolScheduler.js';
import {
@ -26,6 +27,77 @@ import {
import { Part, PartListUnion } from '@google/genai';
import { MockModifiableTool, MockTool } from '../test-utils/tools.js';
class TestApprovalTool extends BaseDeclarativeTool<{ id: string }, ToolResult> {
static readonly Name = 'testApprovalTool';
constructor(private config: Config) {
super(
TestApprovalTool.Name,
'TestApprovalTool',
'A tool for testing approval logic',
Kind.Edit,
{
properties: { id: { type: 'string' } },
required: ['id'],
type: 'object',
},
);
}
protected createInvocation(params: {
id: string;
}): ToolInvocation<{ id: string }, ToolResult> {
return new TestApprovalInvocation(this.config, params);
}
}
class TestApprovalInvocation extends BaseToolInvocation<
{ id: string },
ToolResult
> {
constructor(
private config: Config,
params: { id: string },
) {
super(params);
}
getDescription(): string {
return `Test tool ${this.params.id}`;
}
override async shouldConfirmExecute(): Promise<
ToolCallConfirmationDetails | false
> {
// Need confirmation unless approval mode is AUTO_EDIT
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
return false;
}
return {
type: 'edit',
title: `Confirm Test Tool ${this.params.id}`,
fileName: `test-${this.params.id}.txt`,
filePath: `/test-${this.params.id}.txt`,
fileDiff: 'Test diff content',
originalContent: '',
newContent: 'Test content',
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
}
},
};
}
async execute(): Promise<ToolResult> {
return {
llmContent: `Executed test tool ${this.params.id}`,
returnDisplay: `Executed test tool ${this.params.id}`,
};
}
}
describe('CoreToolScheduler', () => {
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
const mockTool = new MockTool();
@ -759,4 +831,131 @@ describe('CoreToolScheduler request queueing', () => {
// Ensure completion callbacks were called twice.
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
});
it('should auto-approve remaining tool calls when first tool call is approved with ProceedAlways', async () => {
let approvalMode = ApprovalMode.DEFAULT;
const mockConfig = {
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getApprovalMode: () => approvalMode,
setApprovalMode: (mode: ApprovalMode) => {
approvalMode = mode;
},
} as unknown as Config;
const testTool = new TestApprovalTool(mockConfig);
const toolRegistry = {
getTool: () => testTool,
getFunctionDeclarations: () => [],
getFunctionDeclarationsFiltered: () => [],
registerTool: () => {},
discoverAllTools: async () => {},
discoverMcpTools: async () => {},
discoverToolsForServer: async () => {},
removeMcpToolsByServer: () => {},
getAllTools: () => [],
getToolsByServer: () => [],
tools: new Map(),
config: mockConfig,
mcpClientManager: undefined,
getToolByName: () => testTool,
getToolByDisplayName: () => testTool,
getTools: () => [],
discoverTools: async () => {},
discovery: {},
};
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
const pendingConfirmations: Array<
(outcome: ToolConfirmationOutcome) => void
> = [];
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: toolRegistry as unknown as ToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate: (toolCalls) => {
onToolCallsUpdate(toolCalls);
// Capture confirmation handlers for awaiting_approval tools
toolCalls.forEach((call) => {
if (call.status === 'awaiting_approval') {
const waitingCall = call as WaitingToolCall;
if (waitingCall.confirmationDetails?.onConfirm) {
const originalHandler = pendingConfirmations.find(
(h) => h === waitingCall.confirmationDetails.onConfirm,
);
if (!originalHandler) {
pendingConfirmations.push(
waitingCall.confirmationDetails.onConfirm,
);
}
}
}
});
},
getPreferredEditor: () => 'vscode',
onEditorClose: vi.fn(),
});
const abortController = new AbortController();
// Schedule multiple tools that need confirmation
const requests = [
{
callId: '1',
name: 'testApprovalTool',
args: { id: 'first' },
isClientInitiated: false,
prompt_id: 'prompt-1',
},
{
callId: '2',
name: 'testApprovalTool',
args: { id: 'second' },
isClientInitiated: false,
prompt_id: 'prompt-2',
},
{
callId: '3',
name: 'testApprovalTool',
args: { id: 'third' },
isClientInitiated: false,
prompt_id: 'prompt-3',
},
];
await scheduler.schedule(requests, abortController.signal);
// Wait for all tools to be awaiting approval
await vi.waitFor(() => {
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
expect(calls?.length).toBe(3);
expect(calls?.every((call) => call.status === 'awaiting_approval')).toBe(
true,
);
});
expect(pendingConfirmations.length).toBe(3);
// Approve the first tool with ProceedAlways
const firstConfirmation = pendingConfirmations[0];
firstConfirmation(ToolConfirmationOutcome.ProceedAlways);
// Wait for all tools to be completed
await vi.waitFor(() => {
expect(onAllToolCallsComplete).toHaveBeenCalled();
const completedCalls = onAllToolCallsComplete.mock.calls.at(
-1,
)?.[0] as ToolCall[];
expect(completedCalls?.length).toBe(3);
expect(completedCalls?.every((call) => call.status === 'success')).toBe(
true,
);
});
// Verify approval mode was changed
expect(approvalMode).toBe(ApprovalMode.AUTO_EDIT);
});
});

View File

@ -695,6 +695,10 @@ export class CoreToolScheduler {
await originalOnConfirm(outcome);
}
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
await this.autoApproveCompatiblePendingTools(signal, callId);
}
this.setToolCallOutcome(callId, outcome);
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
@ -928,4 +932,35 @@ export class CoreToolScheduler {
};
});
}
private async autoApproveCompatiblePendingTools(
signal: AbortSignal,
triggeringCallId: string,
): Promise<void> {
const pendingTools = this.toolCalls.filter(
(call) =>
call.status === 'awaiting_approval' &&
call.request.callId !== triggeringCallId,
) as WaitingToolCall[];
for (const pendingTool of pendingTools) {
try {
const stillNeedsConfirmation =
await pendingTool.invocation.shouldConfirmExecute(signal);
if (!stillNeedsConfirmation) {
this.setToolCallOutcome(
pendingTool.request.callId,
ToolConfirmationOutcome.ProceedAlways,
);
this.setStatusInternal(pendingTool.request.callId, 'scheduled');
}
} catch (error) {
console.error(
`Error checking confirmation for tool ${pendingTool.request.callId}:`,
error,
);
}
}
}
}