feat: Handle inline content modification in tool scheduler (#2883)

This commit is contained in:
Adam Weidman 2025-07-05 23:19:41 +00:00 committed by GitHub
parent 2b8a565f89
commit 9211905ff1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 183 additions and 2 deletions

View File

@ -16,11 +16,14 @@ import {
BaseTool, BaseTool,
ToolCallConfirmationDetails, ToolCallConfirmationDetails,
ToolConfirmationOutcome, ToolConfirmationOutcome,
ToolConfirmationPayload,
ToolResult, ToolResult,
Config, Config,
} from '../index.js'; } from '../index.js';
import { Part, PartListUnion } from '@google/genai'; import { Part, PartListUnion } from '@google/genai';
import { ModifiableTool, ModifyContext } from '../tools/modifiable-tool.js';
class MockTool extends BaseTool<Record<string, unknown>, ToolResult> { class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
shouldConfirm = false; shouldConfirm = false;
executeFn = vi.fn(); executeFn = vi.fn();
@ -54,6 +57,47 @@ class MockTool extends BaseTool<Record<string, unknown>, ToolResult> {
} }
} }
class MockModifiableTool
extends MockTool
implements ModifiableTool<Record<string, unknown>>
{
constructor(name = 'mockModifiableTool') {
super(name);
this.shouldConfirm = true;
}
getModifyContext(
_abortSignal: AbortSignal,
): ModifyContext<Record<string, unknown>> {
return {
getFilePath: () => 'test.txt',
getCurrentContent: async () => 'old content',
getProposedContent: async () => 'new content',
createUpdatedParams: (
_oldContent: string,
modifiedProposedContent: string,
_originalParams: Record<string, unknown>,
) => ({ newContent: modifiedProposedContent }),
};
}
async shouldConfirmExecute(
_params: Record<string, unknown>,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.shouldConfirm) {
return {
type: 'edit',
title: 'Confirm Mock Tool',
fileName: 'test.txt',
fileDiff: 'diff',
onConfirm: async () => {},
};
}
return false;
}
}
describe('CoreToolScheduler', () => { describe('CoreToolScheduler', () => {
it('should cancel a tool call if the signal is aborted before confirmation', async () => { it('should cancel a tool call if the signal is aborted before confirmation', async () => {
const mockTool = new MockTool(); const mockTool = new MockTool();
@ -122,6 +166,76 @@ describe('CoreToolScheduler', () => {
}); });
}); });
describe('CoreToolScheduler with payload', () => {
it('should update args and diff and execute tool when payload is provided', async () => {
const mockTool = new MockModifiableTool();
const toolRegistry = {
getTool: () => mockTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {} as any,
registerTool: () => {},
getToolByName: () => mockTool,
getToolByDisplayName: () => mockTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
};
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
const mockConfig = {
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
toolRegistry: Promise.resolve(toolRegistry as any),
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
});
const abortController = new AbortController();
const request = {
callId: '1',
name: 'mockModifiableTool',
args: {},
isClientInitiated: false,
};
await scheduler.schedule([request], abortController.signal);
const confirmationDetails = await mockTool.shouldConfirmExecute(
{},
abortController.signal,
);
if (confirmationDetails) {
const payload: ToolConfirmationPayload = { newContent: 'final version' };
await scheduler.handleConfirmationResponse(
'1',
confirmationDetails.onConfirm,
ToolConfirmationOutcome.ProceedOnce,
abortController.signal,
payload,
);
}
expect(onAllToolCallsComplete).toHaveBeenCalled();
const completedCalls = onAllToolCallsComplete.mock
.calls[0][0] as ToolCall[];
expect(completedCalls[0].status).toBe('success');
expect(mockTool.executeFn).toHaveBeenCalledWith({
newContent: 'final version',
});
});
});
describe('convertToFunctionResponse', () => { describe('convertToFunctionResponse', () => {
const toolName = 'testTool'; const toolName = 'testTool';
const callId = 'call1'; const callId = 'call1';

View File

@ -17,6 +17,7 @@ import {
Config, Config,
logToolCall, logToolCall,
ToolCallEvent, ToolCallEvent,
ToolConfirmationPayload,
} from '../index.js'; } from '../index.js';
import { Part, PartListUnion } from '@google/genai'; import { Part, PartListUnion } from '@google/genai';
import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js'; import { getResponseTextFromParts } from '../utils/generateContentResponseUtilities.js';
@ -25,6 +26,7 @@ import {
ModifyContext, ModifyContext,
modifyWithEditor, modifyWithEditor,
} from '../tools/modifiable-tool.js'; } from '../tools/modifiable-tool.js';
import * as Diff from 'diff';
export type ValidatingToolCall = { export type ValidatingToolCall = {
status: 'validating'; status: 'validating';
@ -455,12 +457,16 @@ export class CoreToolScheduler {
const originalOnConfirm = confirmationDetails.onConfirm; const originalOnConfirm = confirmationDetails.onConfirm;
const wrappedConfirmationDetails: ToolCallConfirmationDetails = { const wrappedConfirmationDetails: ToolCallConfirmationDetails = {
...confirmationDetails, ...confirmationDetails,
onConfirm: (outcome: ToolConfirmationOutcome) => onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) =>
this.handleConfirmationResponse( this.handleConfirmationResponse(
reqInfo.callId, reqInfo.callId,
originalOnConfirm, originalOnConfirm,
outcome, outcome,
signal, signal,
payload,
), ),
}; };
this.setStatusInternal( this.setStatusInternal(
@ -492,6 +498,7 @@ export class CoreToolScheduler {
originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>, originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>,
outcome: ToolConfirmationOutcome, outcome: ToolConfirmationOutcome,
signal: AbortSignal, signal: AbortSignal,
payload?: ToolConfirmationPayload,
): Promise<void> { ): Promise<void> {
const toolCall = this.toolCalls.find( const toolCall = this.toolCalls.find(
(c) => c.request.callId === callId && c.status === 'awaiting_approval', (c) => c.request.callId === callId && c.status === 'awaiting_approval',
@ -545,11 +552,62 @@ export class CoreToolScheduler {
} as ToolCallConfirmationDetails); } as ToolCallConfirmationDetails);
} }
} else { } else {
// If the client provided new content, apply it before scheduling.
if (payload?.newContent && toolCall) {
await this._applyInlineModify(
toolCall as WaitingToolCall,
payload,
signal,
);
}
this.setStatusInternal(callId, 'scheduled'); this.setStatusInternal(callId, 'scheduled');
} }
this.attemptExecutionOfScheduledCalls(signal); this.attemptExecutionOfScheduledCalls(signal);
} }
/**
* Applies user-provided content changes to a tool call that is awaiting confirmation.
* This method updates the tool's arguments and refreshes the confirmation prompt with a new diff
* before the tool is scheduled for execution.
* @private
*/
private async _applyInlineModify(
toolCall: WaitingToolCall,
payload: ToolConfirmationPayload,
signal: AbortSignal,
): Promise<void> {
if (
toolCall.confirmationDetails.type !== 'edit' ||
!isModifiableTool(toolCall.tool)
) {
return;
}
const modifyContext = toolCall.tool.getModifyContext(signal);
const currentContent = await modifyContext.getCurrentContent(
toolCall.request.args,
);
const updatedParams = modifyContext.createUpdatedParams(
currentContent,
payload.newContent,
toolCall.request.args,
);
const updatedDiff = Diff.createPatch(
modifyContext.getFilePath(toolCall.request.args),
currentContent,
payload.newContent,
'Current',
'Proposed',
);
this.setArgsInternal(toolCall.request.callId, updatedParams);
this.setStatusInternal(toolCall.request.callId, 'awaiting_approval', {
...toolCall.confirmationDetails,
fileDiff: updatedDiff,
});
}
private attemptExecutionOfScheduledCalls(signal: AbortSignal): void { private attemptExecutionOfScheduledCalls(signal: AbortSignal): void {
const allCallsFinalOrScheduled = this.toolCalls.every( const allCallsFinalOrScheduled = this.toolCalls.every(
(call) => (call) =>

View File

@ -199,12 +199,21 @@ export interface FileDiff {
export interface ToolEditConfirmationDetails { export interface ToolEditConfirmationDetails {
type: 'edit'; type: 'edit';
title: string; title: string;
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>; onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
) => Promise<void>;
fileName: string; fileName: string;
fileDiff: string; fileDiff: string;
isModifying?: boolean; isModifying?: boolean;
} }
export interface ToolConfirmationPayload {
// used to override `modifiedProposedContent` for modifiable tools in the
// inline modify flow
newContent: string;
}
export interface ToolExecuteConfirmationDetails { export interface ToolExecuteConfirmationDetails {
type: 'exec'; type: 'exec';
title: string; title: string;