feat: auto-approve compatible pending tools when allow always is selected (#6519)
This commit is contained in:
parent
d587c6f104
commit
2a71c10b8a
|
@ -8,6 +8,7 @@ import { describe, it, expect, vi } from 'vitest';
|
|||
import {
|
||||
CoreToolScheduler,
|
||||
ToolCall,
|
||||
WaitingToolCall,
|
||||
convertToFunctionResponse,
|
||||
} from './coreToolScheduler.js';
|
||||
import {
|
||||
|
@ -26,6 +27,77 @@ import {
|
|||
import { Part, PartListUnion } from '@google/genai';
|
||||
import { MockModifiableTool, MockTool } from '../test-utils/tools.js';
|
||||
|
||||
class TestApprovalTool extends BaseDeclarativeTool<{ id: string }, ToolResult> {
|
||||
static readonly Name = 'testApprovalTool';
|
||||
|
||||
constructor(private config: Config) {
|
||||
super(
|
||||
TestApprovalTool.Name,
|
||||
'TestApprovalTool',
|
||||
'A tool for testing approval logic',
|
||||
Kind.Edit,
|
||||
{
|
||||
properties: { id: { type: 'string' } },
|
||||
required: ['id'],
|
||||
type: 'object',
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
protected createInvocation(params: {
|
||||
id: string;
|
||||
}): ToolInvocation<{ id: string }, ToolResult> {
|
||||
return new TestApprovalInvocation(this.config, params);
|
||||
}
|
||||
}
|
||||
|
||||
class TestApprovalInvocation extends BaseToolInvocation<
|
||||
{ id: string },
|
||||
ToolResult
|
||||
> {
|
||||
constructor(
|
||||
private config: Config,
|
||||
params: { id: string },
|
||||
) {
|
||||
super(params);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
return `Test tool ${this.params.id}`;
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(): Promise<
|
||||
ToolCallConfirmationDetails | false
|
||||
> {
|
||||
// Need confirmation unless approval mode is AUTO_EDIT
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'edit',
|
||||
title: `Confirm Test Tool ${this.params.id}`,
|
||||
fileName: `test-${this.params.id}.txt`,
|
||||
filePath: `/test-${this.params.id}.txt`,
|
||||
fileDiff: 'Test diff content',
|
||||
originalContent: '',
|
||||
newContent: 'Test content',
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async execute(): Promise<ToolResult> {
|
||||
return {
|
||||
llmContent: `Executed test tool ${this.params.id}`,
|
||||
returnDisplay: `Executed test tool ${this.params.id}`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
describe('CoreToolScheduler', () => {
|
||||
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
|
||||
const mockTool = new MockTool();
|
||||
|
@ -759,4 +831,131 @@ describe('CoreToolScheduler request queueing', () => {
|
|||
// Ensure completion callbacks were called twice.
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should auto-approve remaining tool calls when first tool call is approved with ProceedAlways', async () => {
|
||||
let approvalMode = ApprovalMode.DEFAULT;
|
||||
const mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => approvalMode,
|
||||
setApprovalMode: (mode: ApprovalMode) => {
|
||||
approvalMode = mode;
|
||||
},
|
||||
} as unknown as Config;
|
||||
|
||||
const testTool = new TestApprovalTool(mockConfig);
|
||||
const toolRegistry = {
|
||||
getTool: () => testTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
getFunctionDeclarationsFiltered: () => [],
|
||||
registerTool: () => {},
|
||||
discoverAllTools: async () => {},
|
||||
discoverMcpTools: async () => {},
|
||||
discoverToolsForServer: async () => {},
|
||||
removeMcpToolsByServer: () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
tools: new Map(),
|
||||
config: mockConfig,
|
||||
mcpClientManager: undefined,
|
||||
getToolByName: () => testTool,
|
||||
getToolByDisplayName: () => testTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
discovery: {},
|
||||
};
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
const pendingConfirmations: Array<
|
||||
(outcome: ToolConfirmationOutcome) => void
|
||||
> = [];
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
toolRegistry: toolRegistry as unknown as ToolRegistry,
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate: (toolCalls) => {
|
||||
onToolCallsUpdate(toolCalls);
|
||||
// Capture confirmation handlers for awaiting_approval tools
|
||||
toolCalls.forEach((call) => {
|
||||
if (call.status === 'awaiting_approval') {
|
||||
const waitingCall = call as WaitingToolCall;
|
||||
if (waitingCall.confirmationDetails?.onConfirm) {
|
||||
const originalHandler = pendingConfirmations.find(
|
||||
(h) => h === waitingCall.confirmationDetails.onConfirm,
|
||||
);
|
||||
if (!originalHandler) {
|
||||
pendingConfirmations.push(
|
||||
waitingCall.confirmationDetails.onConfirm,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
|
||||
// Schedule multiple tools that need confirmation
|
||||
const requests = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'testApprovalTool',
|
||||
args: { id: 'first' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-1',
|
||||
},
|
||||
{
|
||||
callId: '2',
|
||||
name: 'testApprovalTool',
|
||||
args: { id: 'second' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-2',
|
||||
},
|
||||
{
|
||||
callId: '3',
|
||||
name: 'testApprovalTool',
|
||||
args: { id: 'third' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-3',
|
||||
},
|
||||
];
|
||||
|
||||
await scheduler.schedule(requests, abortController.signal);
|
||||
|
||||
// Wait for all tools to be awaiting approval
|
||||
await vi.waitFor(() => {
|
||||
const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
|
||||
expect(calls?.length).toBe(3);
|
||||
expect(calls?.every((call) => call.status === 'awaiting_approval')).toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
expect(pendingConfirmations.length).toBe(3);
|
||||
|
||||
// Approve the first tool with ProceedAlways
|
||||
const firstConfirmation = pendingConfirmations[0];
|
||||
firstConfirmation(ToolConfirmationOutcome.ProceedAlways);
|
||||
|
||||
// Wait for all tools to be completed
|
||||
await vi.waitFor(() => {
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
const completedCalls = onAllToolCallsComplete.mock.calls.at(
|
||||
-1,
|
||||
)?.[0] as ToolCall[];
|
||||
expect(completedCalls?.length).toBe(3);
|
||||
expect(completedCalls?.every((call) => call.status === 'success')).toBe(
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
// Verify approval mode was changed
|
||||
expect(approvalMode).toBe(ApprovalMode.AUTO_EDIT);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -695,6 +695,10 @@ export class CoreToolScheduler {
|
|||
await originalOnConfirm(outcome);
|
||||
}
|
||||
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
await this.autoApproveCompatiblePendingTools(signal, callId);
|
||||
}
|
||||
|
||||
this.setToolCallOutcome(callId, outcome);
|
||||
|
||||
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
|
||||
|
@ -928,4 +932,35 @@ export class CoreToolScheduler {
|
|||
};
|
||||
});
|
||||
}
|
||||
|
||||
private async autoApproveCompatiblePendingTools(
|
||||
signal: AbortSignal,
|
||||
triggeringCallId: string,
|
||||
): Promise<void> {
|
||||
const pendingTools = this.toolCalls.filter(
|
||||
(call) =>
|
||||
call.status === 'awaiting_approval' &&
|
||||
call.request.callId !== triggeringCallId,
|
||||
) as WaitingToolCall[];
|
||||
|
||||
for (const pendingTool of pendingTools) {
|
||||
try {
|
||||
const stillNeedsConfirmation =
|
||||
await pendingTool.invocation.shouldConfirmExecute(signal);
|
||||
|
||||
if (!stillNeedsConfirmation) {
|
||||
this.setToolCallOutcome(
|
||||
pendingTool.request.callId,
|
||||
ToolConfirmationOutcome.ProceedAlways,
|
||||
);
|
||||
this.setStatusInternal(pendingTool.request.callId, 'scheduled');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error checking confirmation for tool ${pendingTool.request.callId}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue