feat: Handle inline content modification in tool scheduler (#2883)
This commit is contained in:
parent
2b8a565f89
commit
9211905ff1
|
@ -16,11 +16,14 @@ import {
|
|||
BaseTool,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
ToolConfirmationPayload,
|
||||
ToolResult,
|
||||
Config,
|
||||
} from '../index.js';
|
||||
import { Part, PartListUnion } from '@google/genai';
|
||||
|
||||
import { ModifiableTool, ModifyContext } from '../tools/modifiable-tool.js';
|
||||
|
||||
class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
|
||||
shouldConfirm = false;
|
||||
executeFn = vi.fn();
|
||||
|
@ -54,6 +57,47 @@ class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
|
|||
}
|
||||
}
|
||||
|
||||
class MockModifiableTool
|
||||
extends MockTool
|
||||
implements ModifiableTool<Record<string, unknown>>
|
||||
{
|
||||
constructor(name = 'mockModifiableTool') {
|
||||
super(name);
|
||||
this.shouldConfirm = true;
|
||||
}
|
||||
|
||||
getModifyContext(
|
||||
_abortSignal: AbortSignal,
|
||||
): ModifyContext<Record<string, unknown>> {
|
||||
return {
|
||||
getFilePath: () => 'test.txt',
|
||||
getCurrentContent: async () => 'old content',
|
||||
getProposedContent: async () => 'new content',
|
||||
createUpdatedParams: (
|
||||
_oldContent: string,
|
||||
modifiedProposedContent: string,
|
||||
_originalParams: Record<string, unknown>,
|
||||
) => ({ newContent: modifiedProposedContent }),
|
||||
};
|
||||
}
|
||||
|
||||
async shouldConfirmExecute(
|
||||
_params: Record<string, unknown>,
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.shouldConfirm) {
|
||||
return {
|
||||
type: 'edit',
|
||||
title: 'Confirm Mock Tool',
|
||||
fileName: 'test.txt',
|
||||
fileDiff: 'diff',
|
||||
onConfirm: async () => {},
|
||||
};
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
describe('CoreToolScheduler', () => {
|
||||
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
|
||||
const mockTool = new MockTool();
|
||||
|
@ -122,6 +166,76 @@ describe('CoreToolScheduler', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('CoreToolScheduler with payload', () => {
|
||||
it('should update args and diff and execute tool when payload is provided', async () => {
|
||||
const mockTool = new MockModifiableTool();
|
||||
const toolRegistry = {
|
||||
getTool: () => mockTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {} as any,
|
||||
registerTool: () => {},
|
||||
getToolByName: () => mockTool,
|
||||
getToolByDisplayName: () => mockTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
};
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
const mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
toolRegistry: Promise.resolve(toolRegistry as any),
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
const request = {
|
||||
callId: '1',
|
||||
name: 'mockModifiableTool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
};
|
||||
|
||||
await scheduler.schedule([request], abortController.signal);
|
||||
|
||||
const confirmationDetails = await mockTool.shouldConfirmExecute(
|
||||
{},
|
||||
abortController.signal,
|
||||
);
|
||||
|
||||
if (confirmationDetails) {
|
||||
const payload: ToolConfirmationPayload = { newContent: 'final version' };
|
||||
await scheduler.handleConfirmationResponse(
|
||||
'1',
|
||||
confirmationDetails.onConfirm,
|
||||
ToolConfirmationOutcome.ProceedOnce,
|
||||
abortController.signal,
|
||||
payload,
|
||||
);
|
||||
}
|
||||
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
expect(completedCalls[0].status).toBe('success');
|
||||
expect(mockTool.executeFn).toHaveBeenCalledWith({
|
||||
newContent: 'final version',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('convertToFunctionResponse', () => {
|
||||
const toolName = 'testTool';
|
||||
const callId = 'call1';
|
||||
|
|
|
@ -17,6 +17,7 @@ import {
|
|||
Config,
|
||||
logToolCall,
|
||||
ToolCallEvent,
|
||||
ToolConfirmationPayload,
|
||||
} from '../index.js';
|
||||
import { Part, PartListUnion } from '@google/genai';
|
||||
import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js';
|
||||
|
@ -25,6 +26,7 @@ import {
|
|||
ModifyContext,
|
||||
modifyWithEditor,
|
||||
} from '../tools/modifiable-tool.js';
|
||||
import * as Diff from 'diff';
|
||||
|
||||
export type ValidatingToolCall = {
|
||||
status: 'validating';
|
||||
|
@ -455,12 +457,16 @@ export class CoreToolScheduler {
|
|||
const originalOnConfirm = confirmationDetails.onConfirm;
|
||||
const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
|
||||
...confirmationDetails,
|
||||
onConfirm: (outcome: ToolConfirmationOutcome) =>
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
payload?: ToolConfirmationPayload,
|
||||
) =>
|
||||
this.handleConfirmationResponse(
|
||||
reqInfo.callId,
|
||||
originalOnConfirm,
|
||||
outcome,
|
||||
signal,
|
||||
payload,
|
||||
),
|
||||
};
|
||||
this.setStatusInternal(
|
||||
|
@ -492,6 +498,7 @@ export class CoreToolScheduler {
|
|||
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
|
||||
outcome: ToolConfirmationOutcome,
|
||||
signal: AbortSignal,
|
||||
payload?: ToolConfirmationPayload,
|
||||
): Promise<void> {
|
||||
const toolCall = this.toolCalls.find(
|
||||
(c) => c.request.callId === callId && c.status === 'awaiting_approval',
|
||||
|
@ -545,11 +552,62 @@ export class CoreToolScheduler {
|
|||
} as ToolCallConfirmationDetails);
|
||||
}
|
||||
} else {
|
||||
// If the client provided new content, apply it before scheduling.
|
||||
if (payload?.newContent && toolCall) {
|
||||
await this._applyInlineModify(
|
||||
toolCall as WaitingToolCall,
|
||||
payload,
|
||||
signal,
|
||||
);
|
||||
}
|
||||
this.setStatusInternal(callId, 'scheduled');
|
||||
}
|
||||
this.attemptExecutionOfScheduledCalls(signal);
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies user-provided content changes to a tool call that is awaiting confirmation.
|
||||
* This method updates the tool's arguments and refreshes the confirmation prompt with a new diff
|
||||
* before the tool is scheduled for execution.
|
||||
* @private
|
||||
*/
|
||||
private async _applyInlineModify(
|
||||
toolCall: WaitingToolCall,
|
||||
payload: ToolConfirmationPayload,
|
||||
signal: AbortSignal,
|
||||
): Promise<void> {
|
||||
if (
|
||||
toolCall.confirmationDetails.type !== 'edit' ||
|
||||
!isModifiableTool(toolCall.tool)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
const modifyContext = toolCall.tool.getModifyContext(signal);
|
||||
const currentContent = await modifyContext.getCurrentContent(
|
||||
toolCall.request.args,
|
||||
);
|
||||
|
||||
const updatedParams = modifyContext.createUpdatedParams(
|
||||
currentContent,
|
||||
payload.newContent,
|
||||
toolCall.request.args,
|
||||
);
|
||||
const updatedDiff = Diff.createPatch(
|
||||
modifyContext.getFilePath(toolCall.request.args),
|
||||
currentContent,
|
||||
payload.newContent,
|
||||
'Current',
|
||||
'Proposed',
|
||||
);
|
||||
|
||||
this.setArgsInternal(toolCall.request.callId, updatedParams);
|
||||
this.setStatusInternal(toolCall.request.callId, 'awaiting_approval', {
|
||||
...toolCall.confirmationDetails,
|
||||
fileDiff: updatedDiff,
|
||||
});
|
||||
}
|
||||
|
||||
private attemptExecutionOfScheduledCalls(signal: AbortSignal): void {
|
||||
const allCallsFinalOrScheduled = this.toolCalls.every(
|
||||
(call) =>
|
||||
|
|
|
@ -199,12 +199,21 @@ export interface FileDiff {
|
|||
export interface ToolEditConfirmationDetails {
|
||||
type: 'edit';
|
||||
title: string;
|
||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
payload?: ToolConfirmationPayload,
|
||||
) => Promise<void>;
|
||||
fileName: string;
|
||||
fileDiff: string;
|
||||
isModifying?: boolean;
|
||||
}
|
||||
|
||||
export interface ToolConfirmationPayload {
|
||||
// used to override `modifiedProposedContent` for modifiable tools in the
|
||||
// inline modify flow
|
||||
newContent: string;
|
||||
}
|
||||
|
||||
export interface ToolExecuteConfirmationDetails {
|
||||
type: 'exec';
|
||||
title: string;
|
||||
|
|
Loading…
Reference in New Issue