feat(core): Continue declarative tool migration. (#6114)

This commit is contained in:
joshualitt 2025-08-13 11:57:37 -07:00 committed by GitHub
parent 22109db320
commit 904f4623b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 623 additions and 697 deletions

View File

@ -74,9 +74,11 @@ describe('LSTool', () => {
const params = {
path: '/home/user/project/src',
};
const error = lsTool.validateToolParams(params);
expect(error).toBeNull();
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
const invocation = lsTool.build(params);
expect(invocation).toBeDefined();
});
it('should reject relative paths', () => {
@ -84,8 +86,9 @@ describe('LSTool', () => {
path: './src',
};
const error = lsTool.validateToolParams(params);
expect(error).toBe('Path must be absolute: ./src');
expect(() => lsTool.build(params)).toThrow(
'Path must be absolute: ./src',
);
});
it('should reject paths outside workspace with clear error message', () => {
@ -93,8 +96,7 @@ describe('LSTool', () => {
path: '/etc/passwd',
};
const error = lsTool.validateToolParams(params);
expect(error).toBe(
expect(() => lsTool.build(params)).toThrow(
'Path must be within one of the workspace directories: /home/user/project, /home/user/other-project',
);
});
@ -103,9 +105,11 @@ describe('LSTool', () => {
const params = {
path: '/home/user/other-project/lib',
};
const error = lsTool.validateToolParams(params);
expect(error).toBeNull();
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
const invocation = lsTool.build(params);
expect(invocation).toBeDefined();
});
});
@ -133,10 +137,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('[DIR] subdir');
expect(result.llmContent).toContain('file1.ts');
@ -161,10 +163,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('module1.js');
expect(result.llmContent).toContain('module2.js');
@ -179,10 +179,8 @@ describe('LSTool', () => {
} as fs.Stats);
vi.mocked(fs.readdirSync).mockReturnValue([]);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toBe(
'Directory /home/user/project/empty is empty.',
@ -207,10 +205,11 @@ describe('LSTool', () => {
});
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute(
{ path: testPath, ignore: ['*.spec.js'] },
new AbortController().signal,
);
const invocation = lsTool.build({
path: testPath,
ignore: ['*.spec.js'],
});
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('test.js');
expect(result.llmContent).toContain('index.js');
@ -238,10 +237,8 @@ describe('LSTool', () => {
(path: string) => path.includes('ignored.js'),
);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('file1.js');
expect(result.llmContent).toContain('file2.js');
@ -269,10 +266,8 @@ describe('LSTool', () => {
(path: string) => path.includes('private.js'),
);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('file1.js');
expect(result.llmContent).toContain('file2.js');
@ -287,10 +282,8 @@ describe('LSTool', () => {
isDirectory: () => false,
} as fs.Stats);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('Path is not a directory');
expect(result.returnDisplay).toBe('Error: Path is not a directory.');
@ -303,10 +296,8 @@ describe('LSTool', () => {
throw new Error('ENOENT: no such file or directory');
});
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('Error listing directory');
expect(result.returnDisplay).toBe('Error: Failed to list directory.');
@ -336,10 +327,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
const lines = (
typeof result.llmContent === 'string' ? result.llmContent : ''
@ -361,24 +350,18 @@ describe('LSTool', () => {
throw new Error('EACCES: permission denied');
});
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('Error listing directory');
expect(result.llmContent).toContain('permission denied');
expect(result.returnDisplay).toBe('Error: Failed to list directory.');
});
it('should validate parameters and return error for invalid params', async () => {
const result = await lsTool.execute(
{ path: '../outside' },
new AbortController().signal,
it('should throw for invalid params at build time', async () => {
expect(() => lsTool.build({ path: '../outside' })).toThrow(
'Path must be absolute: ../outside',
);
expect(result.llmContent).toContain('Invalid parameters provided');
expect(result.returnDisplay).toBe('Error: Failed to execute tool.');
});
it('should handle errors accessing individual files during listing', async () => {
@ -406,10 +389,8 @@ describe('LSTool', () => {
.spyOn(console, 'error')
.mockImplementation(() => {});
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
// Should still list the accessible file
expect(result.llmContent).toContain('accessible.ts');
@ -428,19 +409,25 @@ describe('LSTool', () => {
describe('getDescription', () => {
it('should return shortened relative path', () => {
const params = {
path: path.join(mockPrimaryDir, 'deeply', 'nested', 'directory'),
path: `${mockPrimaryDir}/deeply/nested/directory`,
};
const description = lsTool.getDescription(params);
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
const invocation = lsTool.build(params);
const description = invocation.getDescription();
expect(description).toBe(path.join('deeply', 'nested', 'directory'));
});
it('should handle paths in secondary workspace', () => {
const params = {
path: path.join(mockSecondaryDir, 'lib'),
path: `${mockSecondaryDir}/lib`,
};
const description = lsTool.getDescription(params);
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
const invocation = lsTool.build(params);
const description = invocation.getDescription();
expect(description).toBe(path.join('..', 'other-project', 'lib'));
});
});
@ -448,22 +435,25 @@ describe('LSTool', () => {
describe('workspace boundary validation', () => {
it('should accept paths in primary workspace directory', () => {
const params = { path: `${mockPrimaryDir}/src` };
expect(lsTool.validateToolParams(params)).toBeNull();
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
expect(lsTool.build(params)).toBeDefined();
});
it('should accept paths in secondary workspace directory', () => {
const params = { path: `${mockSecondaryDir}/lib` };
expect(lsTool.validateToolParams(params)).toBeNull();
vi.mocked(fs.statSync).mockReturnValue({
isDirectory: () => true,
} as fs.Stats);
expect(lsTool.build(params)).toBeDefined();
});
it('should reject paths outside all workspace directories', () => {
const params = { path: '/etc/passwd' };
const error = lsTool.validateToolParams(params);
expect(error).toContain(
expect(() => lsTool.build(params)).toThrow(
'Path must be within one of the workspace directories',
);
expect(error).toContain(mockPrimaryDir);
expect(error).toContain(mockSecondaryDir);
});
it('should list files from secondary workspace directory', async () => {
@ -483,10 +473,8 @@ describe('LSTool', () => {
vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any);
const result = await lsTool.execute(
{ path: testPath },
new AbortController().signal,
);
const invocation = lsTool.build({ path: testPath });
const result = await invocation.execute(new AbortController().signal);
expect(result.llmContent).toContain('test1.spec.ts');
expect(result.llmContent).toContain('test2.spec.ts');

View File

@ -6,7 +6,13 @@
import fs from 'fs';
import path from 'path';
import { BaseTool, Kind, ToolResult } from './tools.js';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolInvocation,
ToolResult,
} from './tools.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js';
@ -64,10 +70,199 @@ export interface FileEntry {
modifiedTime: Date;
}
class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
constructor(
private readonly config: Config,
params: LSToolParams,
) {
super(params);
}
/**
* Checks if a filename matches any of the ignore patterns
* @param filename Filename to check
* @param patterns Array of glob patterns to check against
* @returns True if the filename should be ignored
*/
private shouldIgnore(filename: string, patterns?: string[]): boolean {
if (!patterns || patterns.length === 0) {
return false;
}
for (const pattern of patterns) {
// Convert glob pattern to RegExp
const regexPattern = pattern
.replace(/[.+^${}()|[\]\\]/g, '\\$&')
.replace(/\*/g, '.*')
.replace(/\?/g, '.');
const regex = new RegExp(`^${regexPattern}$`);
if (regex.test(filename)) {
return true;
}
}
return false;
}
/**
* Gets a description of the file reading operation
* @returns A string describing the file being read
*/
getDescription(): string {
const relativePath = makeRelative(
this.params.path,
this.config.getTargetDir(),
);
return shortenPath(relativePath);
}
// Helper for consistent error formatting
private errorResult(llmContent: string, returnDisplay: string): ToolResult {
return {
llmContent,
// Keep returnDisplay simpler in core logic
returnDisplay: `Error: ${returnDisplay}`,
};
}
/**
* Executes the LS operation with the given parameters
* @returns Result of the LS operation
*/
async execute(_signal: AbortSignal): Promise<ToolResult> {
try {
const stats = fs.statSync(this.params.path);
if (!stats) {
// fs.statSync throws on non-existence, so this check might be redundant
// but keeping for clarity. Error message adjusted.
return this.errorResult(
`Error: Directory not found or inaccessible: ${this.params.path}`,
`Directory not found or inaccessible.`,
);
}
if (!stats.isDirectory()) {
return this.errorResult(
`Error: Path is not a directory: ${this.params.path}`,
`Path is not a directory.`,
);
}
const files = fs.readdirSync(this.params.path);
const defaultFileIgnores =
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
const fileFilteringOptions = {
respectGitIgnore:
this.params.file_filtering_options?.respect_git_ignore ??
defaultFileIgnores.respectGitIgnore,
respectGeminiIgnore:
this.params.file_filtering_options?.respect_gemini_ignore ??
defaultFileIgnores.respectGeminiIgnore,
};
// Get centralized file discovery service
const fileDiscovery = this.config.getFileService();
const entries: FileEntry[] = [];
let gitIgnoredCount = 0;
let geminiIgnoredCount = 0;
if (files.length === 0) {
// Changed error message to be more neutral for LLM
return {
llmContent: `Directory ${this.params.path} is empty.`,
returnDisplay: `Directory is empty.`,
};
}
for (const file of files) {
if (this.shouldIgnore(file, this.params.ignore)) {
continue;
}
const fullPath = path.join(this.params.path, file);
const relativePath = path.relative(
this.config.getTargetDir(),
fullPath,
);
// Check if this file should be ignored based on git or gemini ignore rules
if (
fileFilteringOptions.respectGitIgnore &&
fileDiscovery.shouldGitIgnoreFile(relativePath)
) {
gitIgnoredCount++;
continue;
}
if (
fileFilteringOptions.respectGeminiIgnore &&
fileDiscovery.shouldGeminiIgnoreFile(relativePath)
) {
geminiIgnoredCount++;
continue;
}
try {
const stats = fs.statSync(fullPath);
const isDir = stats.isDirectory();
entries.push({
name: file,
path: fullPath,
isDirectory: isDir,
size: isDir ? 0 : stats.size,
modifiedTime: stats.mtime,
});
} catch (error) {
// Log error internally but don't fail the whole listing
console.error(`Error accessing ${fullPath}: ${error}`);
}
}
// Sort entries (directories first, then alphabetically)
entries.sort((a, b) => {
if (a.isDirectory && !b.isDirectory) return -1;
if (!a.isDirectory && b.isDirectory) return 1;
return a.name.localeCompare(b.name);
});
// Create formatted content for LLM
const directoryContent = entries
.map((entry) => `${entry.isDirectory ? '[DIR] ' : ''}${entry.name}`)
.join('\n');
let resultMessage = `Directory listing for ${this.params.path}:\n${directoryContent}`;
const ignoredMessages = [];
if (gitIgnoredCount > 0) {
ignoredMessages.push(`${gitIgnoredCount} git-ignored`);
}
if (geminiIgnoredCount > 0) {
ignoredMessages.push(`${geminiIgnoredCount} gemini-ignored`);
}
if (ignoredMessages.length > 0) {
resultMessage += `\n\n(${ignoredMessages.join(', ')})`;
}
let displayMessage = `Listed ${entries.length} item(s).`;
if (ignoredMessages.length > 0) {
displayMessage += ` (${ignoredMessages.join(', ')})`;
}
return {
llmContent: resultMessage,
returnDisplay: displayMessage,
};
} catch (error) {
const errorMsg = `Error listing directory: ${error instanceof Error ? error.message : String(error)}`;
return this.errorResult(errorMsg, 'Failed to list directory.');
}
}
}
/**
* Implementation of the LS tool logic
*/
export class LSTool extends BaseTool<LSToolParams, ToolResult> {
export class LSTool extends BaseDeclarativeTool<LSToolParams, ToolResult> {
static readonly Name = 'list_directory';
constructor(private config: Config) {
@ -134,198 +329,16 @@ export class LSTool extends BaseTool<LSToolParams, ToolResult> {
const workspaceContext = this.config.getWorkspaceContext();
if (!workspaceContext.isPathWithinWorkspace(params.path)) {
const directories = workspaceContext.getDirectories();
return `Path must be within one of the workspace directories: ${directories.join(', ')}`;
return `Path must be within one of the workspace directories: ${directories.join(
', ',
)}`;
}
return null;
}
/**
* Checks if a filename matches any of the ignore patterns
* @param filename Filename to check
* @param patterns Array of glob patterns to check against
* @returns True if the filename should be ignored
*/
private shouldIgnore(filename: string, patterns?: string[]): boolean {
if (!patterns || patterns.length === 0) {
return false;
}
for (const pattern of patterns) {
// Convert glob pattern to RegExp
const regexPattern = pattern
.replace(/[.+^${}()|[\]\\]/g, '\\$&')
.replace(/\*/g, '.*')
.replace(/\?/g, '.');
const regex = new RegExp(`^${regexPattern}$`);
if (regex.test(filename)) {
return true;
}
}
return false;
}
/**
* Gets a description of the file reading operation
* @param params Parameters for the file reading
* @returns A string describing the file being read
*/
getDescription(params: LSToolParams): string {
const relativePath = makeRelative(params.path, this.config.getTargetDir());
return shortenPath(relativePath);
}
// Helper for consistent error formatting
private errorResult(llmContent: string, returnDisplay: string): ToolResult {
return {
llmContent,
// Keep returnDisplay simpler in core logic
returnDisplay: `Error: ${returnDisplay}`,
};
}
/**
* Executes the LS operation with the given parameters
* @param params Parameters for the LS operation
* @returns Result of the LS operation
*/
async execute(
protected createInvocation(
params: LSToolParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const validationError = this.validateToolParams(params);
if (validationError) {
return this.errorResult(
`Error: Invalid parameters provided. Reason: ${validationError}`,
`Failed to execute tool.`,
);
}
try {
const stats = fs.statSync(params.path);
if (!stats) {
// fs.statSync throws on non-existence, so this check might be redundant
// but keeping for clarity. Error message adjusted.
return this.errorResult(
`Error: Directory not found or inaccessible: ${params.path}`,
`Directory not found or inaccessible.`,
);
}
if (!stats.isDirectory()) {
return this.errorResult(
`Error: Path is not a directory: ${params.path}`,
`Path is not a directory.`,
);
}
const files = fs.readdirSync(params.path);
const defaultFileIgnores =
this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS;
const fileFilteringOptions = {
respectGitIgnore:
params.file_filtering_options?.respect_git_ignore ??
defaultFileIgnores.respectGitIgnore,
respectGeminiIgnore:
params.file_filtering_options?.respect_gemini_ignore ??
defaultFileIgnores.respectGeminiIgnore,
};
// Get centralized file discovery service
const fileDiscovery = this.config.getFileService();
const entries: FileEntry[] = [];
let gitIgnoredCount = 0;
let geminiIgnoredCount = 0;
if (files.length === 0) {
// Changed error message to be more neutral for LLM
return {
llmContent: `Directory ${params.path} is empty.`,
returnDisplay: `Directory is empty.`,
};
}
for (const file of files) {
if (this.shouldIgnore(file, params.ignore)) {
continue;
}
const fullPath = path.join(params.path, file);
const relativePath = path.relative(
this.config.getTargetDir(),
fullPath,
);
// Check if this file should be ignored based on git or gemini ignore rules
if (
fileFilteringOptions.respectGitIgnore &&
fileDiscovery.shouldGitIgnoreFile(relativePath)
) {
gitIgnoredCount++;
continue;
}
if (
fileFilteringOptions.respectGeminiIgnore &&
fileDiscovery.shouldGeminiIgnoreFile(relativePath)
) {
geminiIgnoredCount++;
continue;
}
try {
const stats = fs.statSync(fullPath);
const isDir = stats.isDirectory();
entries.push({
name: file,
path: fullPath,
isDirectory: isDir,
size: isDir ? 0 : stats.size,
modifiedTime: stats.mtime,
});
} catch (error) {
// Log error internally but don't fail the whole listing
console.error(`Error accessing ${fullPath}: ${error}`);
}
}
// Sort entries (directories first, then alphabetically)
entries.sort((a, b) => {
if (a.isDirectory && !b.isDirectory) return -1;
if (!a.isDirectory && b.isDirectory) return 1;
return a.name.localeCompare(b.name);
});
// Create formatted content for LLM
const directoryContent = entries
.map((entry) => `${entry.isDirectory ? '[DIR] ' : ''}${entry.name}`)
.join('\n');
let resultMessage = `Directory listing for ${params.path}:\n${directoryContent}`;
const ignoredMessages = [];
if (gitIgnoredCount > 0) {
ignoredMessages.push(`${gitIgnoredCount} git-ignored`);
}
if (geminiIgnoredCount > 0) {
ignoredMessages.push(`${geminiIgnoredCount} gemini-ignored`);
}
if (ignoredMessages.length > 0) {
resultMessage += `\n\n(${ignoredMessages.join(', ')})`;
}
let displayMessage = `Listed ${entries.length} item(s).`;
if (ignoredMessages.length > 0) {
displayMessage += ` (${ignoredMessages.join(', ')})`;
}
return {
llmContent: resultMessage,
returnDisplay: displayMessage,
};
} catch (error) {
const errorMsg = `Error listing directory: ${error instanceof Error ? error.message : String(error)}`;
return this.errorResult(errorMsg, 'Failed to list directory.');
}
): ToolInvocation<LSToolParams, ToolResult> {
return new LSToolInvocation(this.config, params);
}
}

View File

@ -73,11 +73,21 @@ describe('DiscoveredMCPTool', () => {
required: ['param'],
};
let tool: DiscoveredMCPTool;
beforeEach(() => {
mockCallTool.mockClear();
mockToolMethod.mockClear();
tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
// Clear allowlist before each relevant test, especially for shouldConfirmExecute
(DiscoveredMCPTool as any).allowlist.clear();
const invocation = tool.build({}) as any;
invocation.constructor.allowlist.clear();
});
afterEach(() => {
@ -86,14 +96,6 @@ describe('DiscoveredMCPTool', () => {
describe('constructor', () => {
it('should set properties correctly', () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
expect(tool.name).toBe(serverToolName);
expect(tool.schema.name).toBe(serverToolName);
expect(tool.schema.description).toBe(baseDescription);
@ -105,7 +107,7 @@ describe('DiscoveredMCPTool', () => {
it('should accept and store a custom timeout', () => {
const customTimeout = 5000;
const tool = new DiscoveredMCPTool(
const toolWithTimeout = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
@ -113,19 +115,12 @@ describe('DiscoveredMCPTool', () => {
inputSchema,
customTimeout,
);
expect(tool.timeout).toBe(customTimeout);
expect(toolWithTimeout.timeout).toBe(customTimeout);
});
});
describe('execute', () => {
it('should call mcpTool.callTool with correct parameters and format display output', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'testValue' };
const mockToolSuccessResultObject = {
success: true,
@ -147,7 +142,10 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(mockMcpToolResponseParts);
const toolResult: ToolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult: ToolResult = await invocation.execute(
new AbortController().signal,
);
expect(mockCallTool).toHaveBeenCalledWith([
{ name: serverToolName, args: params },
@ -163,17 +161,13 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle empty result from getStringifiedResultForDisplay', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'testValue' };
const mockMcpToolResponsePartsEmpty: Part[] = [];
mockCallTool.mockResolvedValue(mockMcpToolResponsePartsEmpty);
const toolResult: ToolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult: ToolResult = await invocation.execute(
new AbortController().signal,
);
expect(toolResult.returnDisplay).toBe('```json\n[]\n```');
expect(toolResult.llmContent).toEqual([
{ text: '[Error: Could not parse tool response]' },
@ -181,28 +175,17 @@ describe('DiscoveredMCPTool', () => {
});
it('should propagate rejection if mcpTool.callTool rejects', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { param: 'failCase' };
const expectedError = new Error('MCP call failed');
mockCallTool.mockRejectedValue(expectedError);
await expect(tool.execute(params)).rejects.toThrow(expectedError);
const invocation = tool.build(params);
await expect(
invocation.execute(new AbortController().signal),
).rejects.toThrow(expectedError);
});
it('should handle a simple text response correctly', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { query: 'test' };
const successMessage = 'This is a success message.';
@ -221,7 +204,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
// 1. Assert that the llmContent sent to the scheduler is a clean Part array.
expect(toolResult.llmContent).toEqual([{ text: successMessage }]);
@ -236,13 +220,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an AudioBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'play' };
const sdkResponse: Part[] = [
{
@ -262,7 +239,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{
@ -279,13 +257,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a ResourceLinkBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
@ -306,7 +277,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{
@ -319,13 +291,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an embedded text ResourceBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
@ -348,7 +313,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{ text: 'This is the text content.' },
@ -357,13 +323,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle an embedded binary ResourceBlock response', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { resource: 'get' };
const sdkResponse: Part[] = [
{
@ -386,7 +345,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{
@ -405,13 +365,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a mix of content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'complex' };
const sdkResponse: Part[] = [
{
@ -433,7 +386,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{ text: 'First part.' },
@ -454,13 +408,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should ignore unknown content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'test' };
const sdkResponse: Part[] = [
{
@ -477,7 +424,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([{ text: 'Valid part.' }]);
expect(toolResult.returnDisplay).toBe(
@ -486,13 +434,6 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle a complex mix of content block types', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const params = { action: 'super-complex' };
const sdkResponse: Part[] = [
{
@ -527,7 +468,8 @@ describe('DiscoveredMCPTool', () => {
];
mockCallTool.mockResolvedValue(sdkResponse);
const toolResult = await tool.execute(params);
const invocation = tool.build(params);
const toolResult = await invocation.execute(new AbortController().signal);
expect(toolResult.llmContent).toEqual([
{ text: 'Here is a resource.' },
@ -552,10 +494,8 @@ describe('DiscoveredMCPTool', () => {
});
describe('shouldConfirmExecute', () => {
// beforeEach is already clearing allowlist
it('should return false if trust is true', async () => {
const tool = new DiscoveredMCPTool(
const trustedTool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
@ -564,50 +504,32 @@ describe('DiscoveredMCPTool', () => {
undefined,
true,
);
const invocation = trustedTool.build({});
expect(
await tool.shouldConfirmExecute({}, new AbortController().signal),
await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return false if server is allowlisted', async () => {
(DiscoveredMCPTool as any).allowlist.add(serverName);
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const invocation = tool.build({}) as any;
invocation.constructor.allowlist.add(serverName);
expect(
await tool.shouldConfirmExecute({}, new AbortController().signal),
await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return false if tool is allowlisted', async () => {
const toolAllowlistKey = `${serverName}.${serverToolName}`;
(DiscoveredMCPTool as any).allowlist.add(toolAllowlistKey);
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const invocation = tool.build({}) as any;
invocation.constructor.allowlist.add(toolAllowlistKey);
expect(
await tool.shouldConfirmExecute({}, new AbortController().signal),
await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return confirmation details if not trusted and not allowlisted', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
const invocation = tool.build({});
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@ -629,15 +551,8 @@ describe('DiscoveredMCPTool', () => {
});
it('should add server to allowlist on ProceedAlwaysServer', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
const invocation = tool.build({}) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@ -650,7 +565,7 @@ describe('DiscoveredMCPTool', () => {
await confirmation.onConfirm(
ToolConfirmationOutcome.ProceedAlwaysServer,
);
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(true);
expect(invocation.constructor.allowlist.has(serverName)).toBe(true);
} else {
throw new Error(
'Confirmation details or onConfirm not in expected format',
@ -659,16 +574,9 @@ describe('DiscoveredMCPTool', () => {
});
it('should add tool to allowlist on ProceedAlwaysTool', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const toolAllowlistKey = `${serverName}.${serverToolName}`;
const confirmation = await tool.shouldConfirmExecute(
{},
const invocation = tool.build({}) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@ -679,7 +587,7 @@ describe('DiscoveredMCPTool', () => {
typeof confirmation.onConfirm === 'function'
) {
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedAlwaysTool);
expect((DiscoveredMCPTool as any).allowlist.has(toolAllowlistKey)).toBe(
expect(invocation.constructor.allowlist.has(toolAllowlistKey)).toBe(
true,
);
} else {
@ -690,15 +598,8 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle Cancel confirmation outcome', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
const invocation = tool.build({}) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@ -710,11 +611,9 @@ describe('DiscoveredMCPTool', () => {
) {
// Cancel should not add anything to allowlist
await confirmation.onConfirm(ToolConfirmationOutcome.Cancel);
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
false,
);
expect(invocation.constructor.allowlist.has(serverName)).toBe(false);
expect(
(DiscoveredMCPTool as any).allowlist.has(
invocation.constructor.allowlist.has(
`${serverName}.${serverToolName}`,
),
).toBe(false);
@ -726,15 +625,8 @@ describe('DiscoveredMCPTool', () => {
});
it('should handle ProceedOnce confirmation outcome', async () => {
const tool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
);
const confirmation = await tool.shouldConfirmExecute(
{},
const invocation = tool.build({}) as any;
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
@ -746,11 +638,9 @@ describe('DiscoveredMCPTool', () => {
) {
// ProceedOnce should not add anything to allowlist
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce);
expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(
false,
);
expect(invocation.constructor.allowlist.has(serverName)).toBe(false);
expect(
(DiscoveredMCPTool as any).allowlist.has(
invocation.constructor.allowlist.has(
`${serverName}.${serverToolName}`,
),
).toBe(false);

View File

@ -5,14 +5,16 @@
*/
import {
BaseTool,
ToolResult,
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolCallConfirmationDetails,
ToolConfirmationOutcome,
ToolInvocation,
ToolMcpConfirmationDetails,
Kind,
ToolResult,
} from './tools.js';
import { CallableTool, Part, FunctionCall } from '@google/genai';
import { CallableTool, FunctionCall, Part } from '@google/genai';
type ToolParams = Record<string, unknown>;
@ -50,9 +52,84 @@ type McpContentBlock =
| McpResourceBlock
| McpResourceLinkBlock;
export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
class DiscoveredMCPToolInvocation extends BaseToolInvocation<
ToolParams,
ToolResult
> {
private static readonly allowlist: Set<string> = new Set();
constructor(
private readonly mcpTool: CallableTool,
readonly serverName: string,
readonly serverToolName: string,
readonly displayName: string,
readonly timeout?: number,
readonly trust?: boolean,
params: ToolParams = {},
) {
super(params);
}
async shouldConfirmExecute(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const serverAllowListKey = this.serverName;
const toolAllowListKey = `${this.serverName}.${this.serverToolName}`;
if (this.trust) {
return false; // server is trusted, no confirmation needed
}
if (
DiscoveredMCPToolInvocation.allowlist.has(serverAllowListKey) ||
DiscoveredMCPToolInvocation.allowlist.has(toolAllowListKey)
) {
return false; // server and/or tool already allowlisted
}
const confirmationDetails: ToolMcpConfirmationDetails = {
type: 'mcp',
title: 'Confirm MCP Tool Execution',
serverName: this.serverName,
toolName: this.serverToolName, // Display original tool name in confirmation
toolDisplayName: this.displayName, // Display global registry name exposed to model and user
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey);
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey);
}
},
};
return confirmationDetails;
}
async execute(): Promise<ToolResult> {
const functionCalls: FunctionCall[] = [
{
name: this.serverToolName,
args: this.params,
},
];
const rawResponseParts = await this.mcpTool.callTool(functionCalls);
const transformedParts = transformMcpContentToParts(rawResponseParts);
return {
llmContent: transformedParts,
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
};
}
getDescription(): string {
return this.displayName;
}
}
export class DiscoveredMCPTool extends BaseDeclarativeTool<
ToolParams,
ToolResult
> {
constructor(
private readonly mcpTool: CallableTool,
readonly serverName: string,
@ -87,56 +164,18 @@ export class DiscoveredMCPTool extends BaseTool<ToolParams, ToolResult> {
);
}
async shouldConfirmExecute(
_params: ToolParams,
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const serverAllowListKey = this.serverName;
const toolAllowListKey = `${this.serverName}.${this.serverToolName}`;
if (this.trust) {
return false; // server is trusted, no confirmation needed
}
if (
DiscoveredMCPTool.allowlist.has(serverAllowListKey) ||
DiscoveredMCPTool.allowlist.has(toolAllowListKey)
) {
return false; // server and/or tool already allowlisted
}
const confirmationDetails: ToolMcpConfirmationDetails = {
type: 'mcp',
title: 'Confirm MCP Tool Execution',
serverName: this.serverName,
toolName: this.serverToolName, // Display original tool name in confirmation
toolDisplayName: this.name, // Display global registry name exposed to model and user
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) {
DiscoveredMCPTool.allowlist.add(serverAllowListKey);
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
DiscoveredMCPTool.allowlist.add(toolAllowListKey);
}
},
};
return confirmationDetails;
}
async execute(params: ToolParams): Promise<ToolResult> {
const functionCalls: FunctionCall[] = [
{
name: this.serverToolName,
args: params,
},
];
const rawResponseParts = await this.mcpTool.callTool(functionCalls);
const transformedParts = transformMcpContentToParts(rawResponseParts);
return {
llmContent: transformedParts,
returnDisplay: getStringifiedResultForDisplay(rawResponseParts),
};
protected createInvocation(
params: ToolParams,
): ToolInvocation<ToolParams, ToolResult> {
return new DiscoveredMCPToolInvocation(
this.mcpTool,
this.serverName,
this.serverToolName,
this.displayName,
this.timeout,
this.trust,
params,
);
}
}

View File

@ -218,7 +218,8 @@ describe('MemoryTool', () => {
it('should call performAddMemoryEntry with correct parameters and return success', async () => {
const params = { fact: 'The sky is blue' };
const result = await memoryTool.execute(params, mockAbortSignal);
const invocation = memoryTool.build(params);
const result = await invocation.execute(mockAbortSignal);
// Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test
const expectedFilePath = path.join(
os.homedir(),
@ -247,14 +248,12 @@ describe('MemoryTool', () => {
it('should return an error if fact is empty', async () => {
const params = { fact: ' ' }; // Empty fact
const result = await memoryTool.execute(params, mockAbortSignal);
const errorMessage = 'Parameter "fact" must be a non-empty string.';
expect(performAddMemoryEntrySpy).not.toHaveBeenCalled();
expect(result.llmContent).toBe(
JSON.stringify({ success: false, error: errorMessage }),
expect(memoryTool.validateToolParams(params)).toBe(
'Parameter "fact" must be a non-empty string.',
);
expect(() => memoryTool.build(params)).toThrow(
'Parameter "fact" must be a non-empty string.',
);
expect(result.returnDisplay).toBe(`Error: ${errorMessage}`);
});
it('should handle errors from performAddMemoryEntry', async () => {
@ -264,7 +263,8 @@ describe('MemoryTool', () => {
);
performAddMemoryEntrySpy.mockRejectedValue(underlyingError);
const result = await memoryTool.execute(params, mockAbortSignal);
const invocation = memoryTool.build(params);
const result = await invocation.execute(mockAbortSignal);
expect(result.llmContent).toBe(
JSON.stringify({
@ -284,17 +284,17 @@ describe('MemoryTool', () => {
beforeEach(() => {
memoryTool = new MemoryTool();
// Clear the allowlist before each test
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.clear();
const invocation = memoryTool.build({ fact: 'mock-fact' });
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(invocation.constructor as any).allowlist.clear();
// Mock fs.readFile to return empty string (file doesn't exist)
vi.mocked(fs.readFile).mockResolvedValue('');
});
it('should return confirmation details when memory file is not allowlisted', async () => {
const params = { fact: 'Test fact' };
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@ -321,15 +321,12 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(),
);
const invocation = memoryTool.build(params);
// Add the memory file to the allowlist
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.add(
memoryFilePath,
);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(invocation.constructor as any).allowlist.add(memoryFilePath);
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBe(false);
});
@ -342,10 +339,8 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(),
);
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@ -356,9 +351,8 @@ describe('MemoryTool', () => {
// Check that the memory file was added to the allowlist
expect(
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
memoryFilePath,
),
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(invocation.constructor as any).allowlist.has(memoryFilePath),
).toBe(true);
}
});
@ -371,10 +365,8 @@ describe('MemoryTool', () => {
getCurrentGeminiMdFilename(),
);
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);
@ -382,18 +374,12 @@ describe('MemoryTool', () => {
if (result && result.type === 'edit') {
// Simulate the onConfirm callback with different outcomes
await result.onConfirm(ToolConfirmationOutcome.ProceedOnce);
expect(
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
memoryFilePath,
),
).toBe(false);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const allowlist = (invocation.constructor as any).allowlist;
expect(allowlist.has(memoryFilePath)).toBe(false);
await result.onConfirm(ToolConfirmationOutcome.Cancel);
expect(
(MemoryTool as unknown as { allowlist: Set<string> }).allowlist.has(
memoryFilePath,
),
).toBe(false);
expect(allowlist.has(memoryFilePath)).toBe(false);
}
});
@ -405,10 +391,8 @@ describe('MemoryTool', () => {
// Mock fs.readFile to return existing content
vi.mocked(fs.readFile).mockResolvedValue(existingContent);
const result = await memoryTool.shouldConfirmExecute(
params,
mockAbortSignal,
);
const invocation = memoryTool.build(params);
const result = await invocation.shouldConfirmExecute(mockAbortSignal);
expect(result).toBeDefined();
expect(result).not.toBe(false);

View File

@ -5,11 +5,12 @@
*/
import {
BaseTool,
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolResult,
ToolEditConfirmationDetails,
ToolConfirmationOutcome,
ToolResult,
} from './tools.js';
import { FunctionDeclaration } from '@google/genai';
import * as fs from 'fs/promises';
@ -19,6 +20,7 @@ import * as Diff from 'diff';
import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js';
import { tildeifyPath } from '../utils/paths.js';
import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
const memoryToolSchemaData: FunctionDeclaration = {
name: 'save_memory',
@ -110,32 +112,10 @@ function ensureNewlineSeparation(currentContent: string): string {
return '\n\n';
}
export class MemoryTool
extends BaseTool<SaveMemoryParams, ToolResult>
implements ModifiableDeclarativeTool<SaveMemoryParams>
{
private static readonly allowlist: Set<string> = new Set();
static readonly Name: string = memoryToolSchemaData.name!;
constructor() {
super(
MemoryTool.Name,
'Save Memory',
memoryToolDescription,
Kind.Think,
memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
);
}
getDescription(_params: SaveMemoryParams): string {
const memoryFilePath = getGlobalMemoryFilePath();
return `in ${tildeifyPath(memoryFilePath)}`;
}
/**
/**
* Reads the current content of the memory file
*/
private async readMemoryFileContent(): Promise<string> {
async function readMemoryFileContent(): Promise<string> {
try {
return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8');
} catch (err) {
@ -143,12 +123,12 @@ export class MemoryTool
if (!(error instanceof Error) || error.code !== 'ENOENT') throw err;
return '';
}
}
}
/**
/**
* Computes the new content that would result from adding a memory entry
*/
private computeNewContent(currentContent: string, fact: string): string {
function computeNewContent(currentContent: string, fact: string): string {
let processedText = fact.trim();
processedText = processedText.replace(/^(-+\s*)+/, '').trim();
const newMemoryItem = `- ${processedText}`;
@ -187,24 +167,31 @@ export class MemoryTool
'\n'
);
}
}
class MemoryToolInvocation extends BaseToolInvocation<
SaveMemoryParams,
ToolResult
> {
private static readonly allowlist: Set<string> = new Set();
getDescription(): string {
const memoryFilePath = getGlobalMemoryFilePath();
return `in ${tildeifyPath(memoryFilePath)}`;
}
async shouldConfirmExecute(
params: SaveMemoryParams,
_abortSignal: AbortSignal,
): Promise<ToolEditConfirmationDetails | false> {
const memoryFilePath = getGlobalMemoryFilePath();
const allowlistKey = memoryFilePath;
if (MemoryTool.allowlist.has(allowlistKey)) {
if (MemoryToolInvocation.allowlist.has(allowlistKey)) {
return false;
}
// Read current content of the memory file
const currentContent = await this.readMemoryFileContent();
// Calculate the new content that will be written to the memory file
const newContent = this.computeNewContent(currentContent, params.fact);
const currentContent = await readMemoryFileContent();
const newContent = computeNewContent(currentContent, this.params.fact);
const fileName = path.basename(memoryFilePath);
const fileDiff = Diff.createPatch(
@ -226,13 +213,107 @@ export class MemoryTool
newContent,
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
MemoryTool.allowlist.add(allowlistKey);
MemoryToolInvocation.allowlist.add(allowlistKey);
}
},
};
return confirmationDetails;
}
async execute(_signal: AbortSignal): Promise<ToolResult> {
const { fact, modified_by_user, modified_content } = this.params;
try {
if (modified_by_user && modified_content !== undefined) {
// User modified the content in external editor, write it directly
await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), {
recursive: true,
});
await fs.writeFile(
getGlobalMemoryFilePath(),
modified_content,
'utf-8',
);
const successMessage = `Okay, I've updated the memory file with your modifications.`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
} else {
// Use the normal memory entry logic
await MemoryTool.performAddMemoryEntry(
fact,
getGlobalMemoryFilePath(),
{
readFile: fs.readFile,
writeFile: fs.writeFile,
mkdir: fs.mkdir,
},
);
const successMessage = `Okay, I've remembered that: "${fact}"`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
}
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
console.error(
`[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`,
);
return {
llmContent: JSON.stringify({
success: false,
error: `Failed to save memory. Detail: ${errorMessage}`,
}),
returnDisplay: `Error saving memory: ${errorMessage}`,
};
}
}
}
export class MemoryTool
extends BaseDeclarativeTool<SaveMemoryParams, ToolResult>
implements ModifiableDeclarativeTool<SaveMemoryParams>
{
static readonly Name: string = memoryToolSchemaData.name!;
constructor() {
super(
MemoryTool.Name,
'Save Memory',
memoryToolDescription,
Kind.Think,
memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
);
}
validateToolParams(params: SaveMemoryParams): string | null {
const errors = SchemaValidator.validate(
this.schema.parametersJsonSchema,
params,
);
if (errors) {
return errors;
}
if (params.fact.trim() === '') {
return 'Parameter "fact" must be a non-empty string.';
}
return null;
}
protected createInvocation(params: SaveMemoryParams) {
return new MemoryToolInvocation(params);
}
static async performAddMemoryEntry(
text: string,
memoryFilePath: string,
@ -303,83 +384,14 @@ export class MemoryTool
}
}
async execute(
params: SaveMemoryParams,
_signal: AbortSignal,
): Promise<ToolResult> {
const { fact, modified_by_user, modified_content } = params;
if (!fact || typeof fact !== 'string' || fact.trim() === '') {
const errorMessage = 'Parameter "fact" must be a non-empty string.';
return {
llmContent: JSON.stringify({ success: false, error: errorMessage }),
returnDisplay: `Error: ${errorMessage}`,
};
}
try {
if (modified_by_user && modified_content !== undefined) {
// User modified the content in external editor, write it directly
await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), {
recursive: true,
});
await fs.writeFile(
getGlobalMemoryFilePath(),
modified_content,
'utf-8',
);
const successMessage = `Okay, I've updated the memory file with your modifications.`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
} else {
// Use the normal memory entry logic
await MemoryTool.performAddMemoryEntry(
fact,
getGlobalMemoryFilePath(),
{
readFile: fs.readFile,
writeFile: fs.writeFile,
mkdir: fs.mkdir,
},
);
const successMessage = `Okay, I've remembered that: "${fact}"`;
return {
llmContent: JSON.stringify({
success: true,
message: successMessage,
}),
returnDisplay: successMessage,
};
}
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : String(error);
console.error(
`[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`,
);
return {
llmContent: JSON.stringify({
success: false,
error: `Failed to save memory. Detail: ${errorMessage}`,
}),
returnDisplay: `Error saving memory: ${errorMessage}`,
};
}
}
getModifyContext(_abortSignal: AbortSignal): ModifyContext<SaveMemoryParams> {
return {
getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(),
getCurrentContent: async (_params: SaveMemoryParams): Promise<string> =>
this.readMemoryFileContent(),
readMemoryFileContent(),
getProposedContent: async (params: SaveMemoryParams): Promise<string> => {
const currentContent = await this.readMemoryFileContent();
return this.computeNewContent(currentContent, params.fact);
const currentContent = await readMemoryFileContent();
return computeNewContent(currentContent, params.fact);
},
createUpdatedParams: (
_oldContent: string,