feat: Handle inline content modification in tool scheduler (#2883)
This commit is contained in:
parent
2b8a565f89
commit
9211905ff1
|
@ -16,11 +16,14 @@ import {
|
||||||
BaseTool,
|
BaseTool,
|
||||||
ToolCallConfirmationDetails,
|
ToolCallConfirmationDetails,
|
||||||
ToolConfirmationOutcome,
|
ToolConfirmationOutcome,
|
||||||
|
ToolConfirmationPayload,
|
||||||
ToolResult,
|
ToolResult,
|
||||||
Config,
|
Config,
|
||||||
} from '../index.js';
|
} from '../index.js';
|
||||||
import { Part, PartListUnion } from '@google/genai';
|
import { Part, PartListUnion } from '@google/genai';
|
||||||
|
|
||||||
|
import { ModifiableTool, ModifyContext } from '../tools/modifiable-tool.js';
|
||||||
|
|
||||||
class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
|
class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
|
||||||
shouldConfirm = false;
|
shouldConfirm = false;
|
||||||
executeFn = vi.fn();
|
executeFn = vi.fn();
|
||||||
|
@ -54,6 +57,47 @@ class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class MockModifiableTool
|
||||||
|
extends MockTool
|
||||||
|
implements ModifiableTool<Record<string, unknown>>
|
||||||
|
{
|
||||||
|
constructor(name = 'mockModifiableTool') {
|
||||||
|
super(name);
|
||||||
|
this.shouldConfirm = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
getModifyContext(
|
||||||
|
_abortSignal: AbortSignal,
|
||||||
|
): ModifyContext<Record<string, unknown>> {
|
||||||
|
return {
|
||||||
|
getFilePath: () => 'test.txt',
|
||||||
|
getCurrentContent: async () => 'old content',
|
||||||
|
getProposedContent: async () => 'new content',
|
||||||
|
createUpdatedParams: (
|
||||||
|
_oldContent: string,
|
||||||
|
modifiedProposedContent: string,
|
||||||
|
_originalParams: Record<string, unknown>,
|
||||||
|
) => ({ newContent: modifiedProposedContent }),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async shouldConfirmExecute(
|
||||||
|
_params: Record<string, unknown>,
|
||||||
|
_abortSignal: AbortSignal,
|
||||||
|
): Promise<ToolCallConfirmationDetails | false> {
|
||||||
|
if (this.shouldConfirm) {
|
||||||
|
return {
|
||||||
|
type: 'edit',
|
||||||
|
title: 'Confirm Mock Tool',
|
||||||
|
fileName: 'test.txt',
|
||||||
|
fileDiff: 'diff',
|
||||||
|
onConfirm: async () => {},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
describe('CoreToolScheduler', () => {
|
describe('CoreToolScheduler', () => {
|
||||||
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
|
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
|
||||||
const mockTool = new MockTool();
|
const mockTool = new MockTool();
|
||||||
|
@ -122,6 +166,76 @@ describe('CoreToolScheduler', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('CoreToolScheduler with payload', () => {
|
||||||
|
it('should update args and diff and execute tool when payload is provided', async () => {
|
||||||
|
const mockTool = new MockModifiableTool();
|
||||||
|
const toolRegistry = {
|
||||||
|
getTool: () => mockTool,
|
||||||
|
getFunctionDeclarations: () => [],
|
||||||
|
tools: new Map(),
|
||||||
|
discovery: {} as any,
|
||||||
|
registerTool: () => {},
|
||||||
|
getToolByName: () => mockTool,
|
||||||
|
getToolByDisplayName: () => mockTool,
|
||||||
|
getTools: () => [],
|
||||||
|
discoverTools: async () => {},
|
||||||
|
getAllTools: () => [],
|
||||||
|
getToolsByServer: () => [],
|
||||||
|
};
|
||||||
|
|
||||||
|
const onAllToolCallsComplete = vi.fn();
|
||||||
|
const onToolCallsUpdate = vi.fn();
|
||||||
|
|
||||||
|
const mockConfig = {
|
||||||
|
getSessionId: () => 'test-session-id',
|
||||||
|
getUsageStatisticsEnabled: () => true,
|
||||||
|
getDebugMode: () => false,
|
||||||
|
} as unknown as Config;
|
||||||
|
|
||||||
|
const scheduler = new CoreToolScheduler({
|
||||||
|
config: mockConfig,
|
||||||
|
toolRegistry: Promise.resolve(toolRegistry as any),
|
||||||
|
onAllToolCallsComplete,
|
||||||
|
onToolCallsUpdate,
|
||||||
|
getPreferredEditor: () => 'vscode',
|
||||||
|
});
|
||||||
|
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const request = {
|
||||||
|
callId: '1',
|
||||||
|
name: 'mockModifiableTool',
|
||||||
|
args: {},
|
||||||
|
isClientInitiated: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
await scheduler.schedule([request], abortController.signal);
|
||||||
|
|
||||||
|
const confirmationDetails = await mockTool.shouldConfirmExecute(
|
||||||
|
{},
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (confirmationDetails) {
|
||||||
|
const payload: ToolConfirmationPayload = { newContent: 'final version' };
|
||||||
|
await scheduler.handleConfirmationResponse(
|
||||||
|
'1',
|
||||||
|
confirmationDetails.onConfirm,
|
||||||
|
ToolConfirmationOutcome.ProceedOnce,
|
||||||
|
abortController.signal,
|
||||||
|
payload,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||||
|
const completedCalls = onAllToolCallsComplete.mock
|
||||||
|
.calls[0][0] as ToolCall[];
|
||||||
|
expect(completedCalls[0].status).toBe('success');
|
||||||
|
expect(mockTool.executeFn).toHaveBeenCalledWith({
|
||||||
|
newContent: 'final version',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('convertToFunctionResponse', () => {
|
describe('convertToFunctionResponse', () => {
|
||||||
const toolName = 'testTool';
|
const toolName = 'testTool';
|
||||||
const callId = 'call1';
|
const callId = 'call1';
|
||||||
|
|
|
@ -17,6 +17,7 @@ import {
|
||||||
Config,
|
Config,
|
||||||
logToolCall,
|
logToolCall,
|
||||||
ToolCallEvent,
|
ToolCallEvent,
|
||||||
|
ToolConfirmationPayload,
|
||||||
} from '../index.js';
|
} from '../index.js';
|
||||||
import { Part, PartListUnion } from '@google/genai';
|
import { Part, PartListUnion } from '@google/genai';
|
||||||
import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js';
|
import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js';
|
||||||
|
@ -25,6 +26,7 @@ import {
|
||||||
ModifyContext,
|
ModifyContext,
|
||||||
modifyWithEditor,
|
modifyWithEditor,
|
||||||
} from '../tools/modifiable-tool.js';
|
} from '../tools/modifiable-tool.js';
|
||||||
|
import * as Diff from 'diff';
|
||||||
|
|
||||||
export type ValidatingToolCall = {
|
export type ValidatingToolCall = {
|
||||||
status: 'validating';
|
status: 'validating';
|
||||||
|
@ -455,12 +457,16 @@ export class CoreToolScheduler {
|
||||||
const originalOnConfirm = confirmationDetails.onConfirm;
|
const originalOnConfirm = confirmationDetails.onConfirm;
|
||||||
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
|
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
|
||||||
...confirmationDetails,
|
...confirmationDetails,
|
||||||
onConfirm: (outcome: ToolConfirmationOutcome) =>
|
onConfirm: (
|
||||||
|
outcome: ToolConfirmationOutcome,
|
||||||
|
payload?: ToolConfirmationPayload,
|
||||||
|
) =>
|
||||||
this.handleConfirmationResponse(
|
this.handleConfirmationResponse(
|
||||||
reqInfo.callId,
|
reqInfo.callId,
|
||||||
originalOnConfirm,
|
originalOnConfirm,
|
||||||
outcome,
|
outcome,
|
||||||
signal,
|
signal,
|
||||||
|
payload,
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
this.setStatusInternal(
|
this.setStatusInternal(
|
||||||
|
@ -492,6 +498,7 @@ export class CoreToolScheduler {
|
||||||
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
|
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
|
||||||
outcome: ToolConfirmationOutcome,
|
outcome: ToolConfirmationOutcome,
|
||||||
signal: AbortSignal,
|
signal: AbortSignal,
|
||||||
|
payload?: ToolConfirmationPayload,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
const toolCall = this.toolCalls.find(
|
const toolCall = this.toolCalls.find(
|
||||||
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
|
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
|
||||||
|
@ -545,11 +552,62 @@ export class CoreToolScheduler {
|
||||||
} as ToolCallConfirmationDetails);
|
} as ToolCallConfirmationDetails);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// If the client provided new content, apply it before scheduling.
|
||||||
|
if (payload?.newContent && toolCall) {
|
||||||
|
await this._applyInlineModify(
|
||||||
|
toolCall as WaitingToolCall,
|
||||||
|
payload,
|
||||||
|
signal,
|
||||||
|
);
|
||||||
|
}
|
||||||
this.setStatusInternal(callId, 'scheduled');
|
this.setStatusInternal(callId, 'scheduled');
|
||||||
}
|
}
|
||||||
this.attemptExecutionOfScheduledCalls(signal);
|
this.attemptExecutionOfScheduledCalls(signal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies user-provided content changes to a tool call that is awaiting confirmation.
|
||||||
|
* This method updates the tool's arguments and refreshes the confirmation prompt with a new diff
|
||||||
|
* before the tool is scheduled for execution.
|
||||||
|
* @private
|
||||||
|
*/
|
||||||
|
private async _applyInlineModify(
|
||||||
|
toolCall: WaitingToolCall,
|
||||||
|
payload: ToolConfirmationPayload,
|
||||||
|
signal: AbortSignal,
|
||||||
|
): Promise<void> {
|
||||||
|
if (
|
||||||
|
toolCall.confirmationDetails.type !== 'edit' ||
|
||||||
|
!isModifiableTool(toolCall.tool)
|
||||||
|
) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const modifyContext = toolCall.tool.getModifyContext(signal);
|
||||||
|
const currentContent = await modifyContext.getCurrentContent(
|
||||||
|
toolCall.request.args,
|
||||||
|
);
|
||||||
|
|
||||||
|
const updatedParams = modifyContext.createUpdatedParams(
|
||||||
|
currentContent,
|
||||||
|
payload.newContent,
|
||||||
|
toolCall.request.args,
|
||||||
|
);
|
||||||
|
const updatedDiff = Diff.createPatch(
|
||||||
|
modifyContext.getFilePath(toolCall.request.args),
|
||||||
|
currentContent,
|
||||||
|
payload.newContent,
|
||||||
|
'Current',
|
||||||
|
'Proposed',
|
||||||
|
);
|
||||||
|
|
||||||
|
this.setArgsInternal(toolCall.request.callId, updatedParams);
|
||||||
|
this.setStatusInternal(toolCall.request.callId, 'awaiting_approval', {
|
||||||
|
...toolCall.confirmationDetails,
|
||||||
|
fileDiff: updatedDiff,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
private attemptExecutionOfScheduledCalls(signal: AbortSignal): void {
|
private attemptExecutionOfScheduledCalls(signal: AbortSignal): void {
|
||||||
const allCallsFinalOrScheduled = this.toolCalls.every(
|
const allCallsFinalOrScheduled = this.toolCalls.every(
|
||||||
(call) =>
|
(call) =>
|
||||||
|
|
|
@ -199,12 +199,21 @@ export interface FileDiff {
|
||||||
export interface ToolEditConfirmationDetails {
|
export interface ToolEditConfirmationDetails {
|
||||||
type: 'edit';
|
type: 'edit';
|
||||||
title: string;
|
title: string;
|
||||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
onConfirm: (
|
||||||
|
outcome: ToolConfirmationOutcome,
|
||||||
|
payload?: ToolConfirmationPayload,
|
||||||
|
) => Promise<void>;
|
||||||
fileName: string;
|
fileName: string;
|
||||||
fileDiff: string;
|
fileDiff: string;
|
||||||
isModifying?: boolean;
|
isModifying?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ToolConfirmationPayload {
|
||||||
|
// used to override `modifiedProposedContent` for modifiable tools in the
|
||||||
|
// inline modify flow
|
||||||
|
newContent: string;
|
||||||
|
}
|
||||||
|
|
||||||
export interface ToolExecuteConfirmationDetails {
|
export interface ToolExecuteConfirmationDetails {
|
||||||
type: 'exec';
|
type: 'exec';
|
||||||
title: string;
|
title: string;
|
||||||
|
|
Loading…
Reference in New Issue