From 904f4623b6945345d5845649e98f554671b1edfb Mon Sep 17 00:00:00 2001 From: joshualitt Date: Wed, 13 Aug 2025 11:57:37 -0700 Subject: [PATCH] feat(core): Continue declarative tool migration. (#6114) --- packages/core/src/tools/ls.test.ts | 142 ++++---- packages/core/src/tools/ls.ts | 393 +++++++++++---------- packages/core/src/tools/mcp-tool.test.ts | 242 ++++--------- packages/core/src/tools/mcp-tool.ts | 149 +++++--- packages/core/src/tools/memoryTool.test.ts | 76 ++-- packages/core/src/tools/memoryTool.ts | 318 +++++++++-------- 6 files changed, 623 insertions(+), 697 deletions(-) diff --git a/packages/core/src/tools/ls.test.ts b/packages/core/src/tools/ls.test.ts index fb99d829..2fbeb37a 100644 --- a/packages/core/src/tools/ls.test.ts +++ b/packages/core/src/tools/ls.test.ts @@ -74,9 +74,11 @@ describe('LSTool', () => { const params = { path: '/home/user/project/src', }; - - const error = lsTool.validateToolParams(params); - expect(error).toBeNull(); + vi.mocked(fs.statSync).mockReturnValue({ + isDirectory: () => true, + } as fs.Stats); + const invocation = lsTool.build(params); + expect(invocation).toBeDefined(); }); it('should reject relative paths', () => { @@ -84,8 +86,9 @@ describe('LSTool', () => { path: './src', }; - const error = lsTool.validateToolParams(params); - expect(error).toBe('Path must be absolute: ./src'); + expect(() => lsTool.build(params)).toThrow( + 'Path must be absolute: ./src', + ); }); it('should reject paths outside workspace with clear error message', () => { @@ -93,8 +96,7 @@ describe('LSTool', () => { path: '/etc/passwd', }; - const error = lsTool.validateToolParams(params); - expect(error).toBe( + expect(() => lsTool.build(params)).toThrow( 'Path must be within one of the workspace directories: /home/user/project, /home/user/other-project', ); }); @@ -103,9 +105,11 @@ describe('LSTool', () => { const params = { path: '/home/user/other-project/lib', }; - - const error = lsTool.validateToolParams(params); - expect(error).toBeNull(); + vi.mocked(fs.statSync).mockReturnValue({ + isDirectory: () => true, + } as fs.Stats); + const invocation = lsTool.build(params); + expect(invocation).toBeDefined(); }); }); @@ -133,10 +137,8 @@ describe('LSTool', () => { vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toContain('[DIR] subdir'); expect(result.llmContent).toContain('file1.ts'); @@ -161,10 +163,8 @@ describe('LSTool', () => { vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toContain('module1.js'); expect(result.llmContent).toContain('module2.js'); @@ -179,10 +179,8 @@ describe('LSTool', () => { } as fs.Stats); vi.mocked(fs.readdirSync).mockReturnValue([]); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toBe( 'Directory /home/user/project/empty is empty.', @@ -207,10 +205,11 @@ describe('LSTool', () => { }); vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); - const result = await lsTool.execute( - { path: testPath, ignore: ['*.spec.js'] }, - new AbortController().signal, - ); + const invocation = lsTool.build({ + path: testPath, + ignore: ['*.spec.js'], + }); + const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toContain('test.js'); expect(result.llmContent).toContain('index.js'); @@ -238,10 +237,8 @@ describe('LSTool', () => { (path: string) => path.includes('ignored.js'), ); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toContain('file1.js'); expect(result.llmContent).toContain('file2.js'); @@ -269,10 +266,8 @@ describe('LSTool', () => { (path: string) => path.includes('private.js'), ); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toContain('file1.js'); expect(result.llmContent).toContain('file2.js'); @@ -287,10 +282,8 @@ describe('LSTool', () => { isDirectory: () => false, } as fs.Stats); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toContain('Path is not a directory'); expect(result.returnDisplay).toBe('Error: Path is not a directory.'); @@ -303,10 +296,8 @@ describe('LSTool', () => { throw new Error('ENOENT: no such file or directory'); }); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toContain('Error listing directory'); expect(result.returnDisplay).toBe('Error: Failed to list directory.'); @@ -336,10 +327,8 @@ describe('LSTool', () => { vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); const lines = ( typeof result.llmContent === 'string' ? result.llmContent : '' @@ -361,24 +350,18 @@ describe('LSTool', () => { throw new Error('EACCES: permission denied'); }); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toContain('Error listing directory'); expect(result.llmContent).toContain('permission denied'); expect(result.returnDisplay).toBe('Error: Failed to list directory.'); }); - it('should validate parameters and return error for invalid params', async () => { - const result = await lsTool.execute( - { path: '../outside' }, - new AbortController().signal, + it('should throw for invalid params at build time', async () => { + expect(() => lsTool.build({ path: '../outside' })).toThrow( + 'Path must be absolute: ../outside', ); - - expect(result.llmContent).toContain('Invalid parameters provided'); - expect(result.returnDisplay).toBe('Error: Failed to execute tool.'); }); it('should handle errors accessing individual files during listing', async () => { @@ -406,10 +389,8 @@ describe('LSTool', () => { .spyOn(console, 'error') .mockImplementation(() => {}); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); // Should still list the accessible file expect(result.llmContent).toContain('accessible.ts'); @@ -428,19 +409,25 @@ describe('LSTool', () => { describe('getDescription', () => { it('should return shortened relative path', () => { const params = { - path: path.join(mockPrimaryDir, 'deeply', 'nested', 'directory'), + path: `${mockPrimaryDir}/deeply/nested/directory`, }; - - const description = lsTool.getDescription(params); + vi.mocked(fs.statSync).mockReturnValue({ + isDirectory: () => true, + } as fs.Stats); + const invocation = lsTool.build(params); + const description = invocation.getDescription(); expect(description).toBe(path.join('deeply', 'nested', 'directory')); }); it('should handle paths in secondary workspace', () => { const params = { - path: path.join(mockSecondaryDir, 'lib'), + path: `${mockSecondaryDir}/lib`, }; - - const description = lsTool.getDescription(params); + vi.mocked(fs.statSync).mockReturnValue({ + isDirectory: () => true, + } as fs.Stats); + const invocation = lsTool.build(params); + const description = invocation.getDescription(); expect(description).toBe(path.join('..', 'other-project', 'lib')); }); }); @@ -448,22 +435,25 @@ describe('LSTool', () => { describe('workspace boundary validation', () => { it('should accept paths in primary workspace directory', () => { const params = { path: `${mockPrimaryDir}/src` }; - expect(lsTool.validateToolParams(params)).toBeNull(); + vi.mocked(fs.statSync).mockReturnValue({ + isDirectory: () => true, + } as fs.Stats); + expect(lsTool.build(params)).toBeDefined(); }); it('should accept paths in secondary workspace directory', () => { const params = { path: `${mockSecondaryDir}/lib` }; - expect(lsTool.validateToolParams(params)).toBeNull(); + vi.mocked(fs.statSync).mockReturnValue({ + isDirectory: () => true, + } as fs.Stats); + expect(lsTool.build(params)).toBeDefined(); }); it('should reject paths outside all workspace directories', () => { const params = { path: '/etc/passwd' }; - const error = lsTool.validateToolParams(params); - expect(error).toContain( + expect(() => lsTool.build(params)).toThrow( 'Path must be within one of the workspace directories', ); - expect(error).toContain(mockPrimaryDir); - expect(error).toContain(mockSecondaryDir); }); it('should list files from secondary workspace directory', async () => { @@ -483,10 +473,8 @@ describe('LSTool', () => { vi.mocked(fs.readdirSync).mockReturnValue(mockFiles as any); - const result = await lsTool.execute( - { path: testPath }, - new AbortController().signal, - ); + const invocation = lsTool.build({ path: testPath }); + const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toContain('test1.spec.ts'); expect(result.llmContent).toContain('test2.spec.ts'); diff --git a/packages/core/src/tools/ls.ts b/packages/core/src/tools/ls.ts index 7a4445a5..2618136a 100644 --- a/packages/core/src/tools/ls.ts +++ b/packages/core/src/tools/ls.ts @@ -6,7 +6,13 @@ import fs from 'fs'; import path from 'path'; -import { BaseTool, Kind, ToolResult } from './tools.js'; +import { + BaseDeclarativeTool, + BaseToolInvocation, + Kind, + ToolInvocation, + ToolResult, +} from './tools.js'; import { SchemaValidator } from '../utils/schemaValidator.js'; import { makeRelative, shortenPath } from '../utils/paths.js'; import { Config, DEFAULT_FILE_FILTERING_OPTIONS } from '../config/config.js'; @@ -64,10 +70,199 @@ export interface FileEntry { modifiedTime: Date; } +class LSToolInvocation extends BaseToolInvocation { + constructor( + private readonly config: Config, + params: LSToolParams, + ) { + super(params); + } + + /** + * Checks if a filename matches any of the ignore patterns + * @param filename Filename to check + * @param patterns Array of glob patterns to check against + * @returns True if the filename should be ignored + */ + private shouldIgnore(filename: string, patterns?: string[]): boolean { + if (!patterns || patterns.length === 0) { + return false; + } + for (const pattern of patterns) { + // Convert glob pattern to RegExp + const regexPattern = pattern + .replace(/[.+^${}()|[\]\\]/g, '\\$&') + .replace(/\*/g, '.*') + .replace(/\?/g, '.'); + const regex = new RegExp(`^${regexPattern}$`); + if (regex.test(filename)) { + return true; + } + } + return false; + } + + /** + * Gets a description of the file reading operation + * @returns A string describing the file being read + */ + getDescription(): string { + const relativePath = makeRelative( + this.params.path, + this.config.getTargetDir(), + ); + return shortenPath(relativePath); + } + + // Helper for consistent error formatting + private errorResult(llmContent: string, returnDisplay: string): ToolResult { + return { + llmContent, + // Keep returnDisplay simpler in core logic + returnDisplay: `Error: ${returnDisplay}`, + }; + } + + /** + * Executes the LS operation with the given parameters + * @returns Result of the LS operation + */ + async execute(_signal: AbortSignal): Promise { + try { + const stats = fs.statSync(this.params.path); + if (!stats) { + // fs.statSync throws on non-existence, so this check might be redundant + // but keeping for clarity. Error message adjusted. + return this.errorResult( + `Error: Directory not found or inaccessible: ${this.params.path}`, + `Directory not found or inaccessible.`, + ); + } + if (!stats.isDirectory()) { + return this.errorResult( + `Error: Path is not a directory: ${this.params.path}`, + `Path is not a directory.`, + ); + } + + const files = fs.readdirSync(this.params.path); + + const defaultFileIgnores = + this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS; + + const fileFilteringOptions = { + respectGitIgnore: + this.params.file_filtering_options?.respect_git_ignore ?? + defaultFileIgnores.respectGitIgnore, + respectGeminiIgnore: + this.params.file_filtering_options?.respect_gemini_ignore ?? + defaultFileIgnores.respectGeminiIgnore, + }; + + // Get centralized file discovery service + + const fileDiscovery = this.config.getFileService(); + + const entries: FileEntry[] = []; + let gitIgnoredCount = 0; + let geminiIgnoredCount = 0; + + if (files.length === 0) { + // Changed error message to be more neutral for LLM + return { + llmContent: `Directory ${this.params.path} is empty.`, + returnDisplay: `Directory is empty.`, + }; + } + + for (const file of files) { + if (this.shouldIgnore(file, this.params.ignore)) { + continue; + } + + const fullPath = path.join(this.params.path, file); + const relativePath = path.relative( + this.config.getTargetDir(), + fullPath, + ); + + // Check if this file should be ignored based on git or gemini ignore rules + if ( + fileFilteringOptions.respectGitIgnore && + fileDiscovery.shouldGitIgnoreFile(relativePath) + ) { + gitIgnoredCount++; + continue; + } + if ( + fileFilteringOptions.respectGeminiIgnore && + fileDiscovery.shouldGeminiIgnoreFile(relativePath) + ) { + geminiIgnoredCount++; + continue; + } + + try { + const stats = fs.statSync(fullPath); + const isDir = stats.isDirectory(); + entries.push({ + name: file, + path: fullPath, + isDirectory: isDir, + size: isDir ? 0 : stats.size, + modifiedTime: stats.mtime, + }); + } catch (error) { + // Log error internally but don't fail the whole listing + console.error(`Error accessing ${fullPath}: ${error}`); + } + } + + // Sort entries (directories first, then alphabetically) + entries.sort((a, b) => { + if (a.isDirectory && !b.isDirectory) return -1; + if (!a.isDirectory && b.isDirectory) return 1; + return a.name.localeCompare(b.name); + }); + + // Create formatted content for LLM + const directoryContent = entries + .map((entry) => `${entry.isDirectory ? '[DIR] ' : ''}${entry.name}`) + .join('\n'); + + let resultMessage = `Directory listing for ${this.params.path}:\n${directoryContent}`; + const ignoredMessages = []; + if (gitIgnoredCount > 0) { + ignoredMessages.push(`${gitIgnoredCount} git-ignored`); + } + if (geminiIgnoredCount > 0) { + ignoredMessages.push(`${geminiIgnoredCount} gemini-ignored`); + } + + if (ignoredMessages.length > 0) { + resultMessage += `\n\n(${ignoredMessages.join(', ')})`; + } + + let displayMessage = `Listed ${entries.length} item(s).`; + if (ignoredMessages.length > 0) { + displayMessage += ` (${ignoredMessages.join(', ')})`; + } + + return { + llmContent: resultMessage, + returnDisplay: displayMessage, + }; + } catch (error) { + const errorMsg = `Error listing directory: ${error instanceof Error ? error.message : String(error)}`; + return this.errorResult(errorMsg, 'Failed to list directory.'); + } + } +} + /** * Implementation of the LS tool logic */ -export class LSTool extends BaseTool { +export class LSTool extends BaseDeclarativeTool { static readonly Name = 'list_directory'; constructor(private config: Config) { @@ -134,198 +329,16 @@ export class LSTool extends BaseTool { const workspaceContext = this.config.getWorkspaceContext(); if (!workspaceContext.isPathWithinWorkspace(params.path)) { const directories = workspaceContext.getDirectories(); - return `Path must be within one of the workspace directories: ${directories.join(', ')}`; + return `Path must be within one of the workspace directories: ${directories.join( + ', ', + )}`; } return null; } - /** - * Checks if a filename matches any of the ignore patterns - * @param filename Filename to check - * @param patterns Array of glob patterns to check against - * @returns True if the filename should be ignored - */ - private shouldIgnore(filename: string, patterns?: string[]): boolean { - if (!patterns || patterns.length === 0) { - return false; - } - for (const pattern of patterns) { - // Convert glob pattern to RegExp - const regexPattern = pattern - .replace(/[.+^${}()|[\]\\]/g, '\\$&') - .replace(/\*/g, '.*') - .replace(/\?/g, '.'); - const regex = new RegExp(`^${regexPattern}$`); - if (regex.test(filename)) { - return true; - } - } - return false; - } - - /** - * Gets a description of the file reading operation - * @param params Parameters for the file reading - * @returns A string describing the file being read - */ - getDescription(params: LSToolParams): string { - const relativePath = makeRelative(params.path, this.config.getTargetDir()); - return shortenPath(relativePath); - } - - // Helper for consistent error formatting - private errorResult(llmContent: string, returnDisplay: string): ToolResult { - return { - llmContent, - // Keep returnDisplay simpler in core logic - returnDisplay: `Error: ${returnDisplay}`, - }; - } - - /** - * Executes the LS operation with the given parameters - * @param params Parameters for the LS operation - * @returns Result of the LS operation - */ - async execute( + protected createInvocation( params: LSToolParams, - _signal: AbortSignal, - ): Promise { - const validationError = this.validateToolParams(params); - if (validationError) { - return this.errorResult( - `Error: Invalid parameters provided. Reason: ${validationError}`, - `Failed to execute tool.`, - ); - } - - try { - const stats = fs.statSync(params.path); - if (!stats) { - // fs.statSync throws on non-existence, so this check might be redundant - // but keeping for clarity. Error message adjusted. - return this.errorResult( - `Error: Directory not found or inaccessible: ${params.path}`, - `Directory not found or inaccessible.`, - ); - } - if (!stats.isDirectory()) { - return this.errorResult( - `Error: Path is not a directory: ${params.path}`, - `Path is not a directory.`, - ); - } - - const files = fs.readdirSync(params.path); - - const defaultFileIgnores = - this.config.getFileFilteringOptions() ?? DEFAULT_FILE_FILTERING_OPTIONS; - - const fileFilteringOptions = { - respectGitIgnore: - params.file_filtering_options?.respect_git_ignore ?? - defaultFileIgnores.respectGitIgnore, - respectGeminiIgnore: - params.file_filtering_options?.respect_gemini_ignore ?? - defaultFileIgnores.respectGeminiIgnore, - }; - - // Get centralized file discovery service - - const fileDiscovery = this.config.getFileService(); - - const entries: FileEntry[] = []; - let gitIgnoredCount = 0; - let geminiIgnoredCount = 0; - - if (files.length === 0) { - // Changed error message to be more neutral for LLM - return { - llmContent: `Directory ${params.path} is empty.`, - returnDisplay: `Directory is empty.`, - }; - } - - for (const file of files) { - if (this.shouldIgnore(file, params.ignore)) { - continue; - } - - const fullPath = path.join(params.path, file); - const relativePath = path.relative( - this.config.getTargetDir(), - fullPath, - ); - - // Check if this file should be ignored based on git or gemini ignore rules - if ( - fileFilteringOptions.respectGitIgnore && - fileDiscovery.shouldGitIgnoreFile(relativePath) - ) { - gitIgnoredCount++; - continue; - } - if ( - fileFilteringOptions.respectGeminiIgnore && - fileDiscovery.shouldGeminiIgnoreFile(relativePath) - ) { - geminiIgnoredCount++; - continue; - } - - try { - const stats = fs.statSync(fullPath); - const isDir = stats.isDirectory(); - entries.push({ - name: file, - path: fullPath, - isDirectory: isDir, - size: isDir ? 0 : stats.size, - modifiedTime: stats.mtime, - }); - } catch (error) { - // Log error internally but don't fail the whole listing - console.error(`Error accessing ${fullPath}: ${error}`); - } - } - - // Sort entries (directories first, then alphabetically) - entries.sort((a, b) => { - if (a.isDirectory && !b.isDirectory) return -1; - if (!a.isDirectory && b.isDirectory) return 1; - return a.name.localeCompare(b.name); - }); - - // Create formatted content for LLM - const directoryContent = entries - .map((entry) => `${entry.isDirectory ? '[DIR] ' : ''}${entry.name}`) - .join('\n'); - - let resultMessage = `Directory listing for ${params.path}:\n${directoryContent}`; - const ignoredMessages = []; - if (gitIgnoredCount > 0) { - ignoredMessages.push(`${gitIgnoredCount} git-ignored`); - } - if (geminiIgnoredCount > 0) { - ignoredMessages.push(`${geminiIgnoredCount} gemini-ignored`); - } - - if (ignoredMessages.length > 0) { - resultMessage += `\n\n(${ignoredMessages.join(', ')})`; - } - - let displayMessage = `Listed ${entries.length} item(s).`; - if (ignoredMessages.length > 0) { - displayMessage += ` (${ignoredMessages.join(', ')})`; - } - - return { - llmContent: resultMessage, - returnDisplay: displayMessage, - }; - } catch (error) { - const errorMsg = `Error listing directory: ${error instanceof Error ? error.message : String(error)}`; - return this.errorResult(errorMsg, 'Failed to list directory.'); - } + ): ToolInvocation { + return new LSToolInvocation(this.config, params); } } diff --git a/packages/core/src/tools/mcp-tool.test.ts b/packages/core/src/tools/mcp-tool.test.ts index f8a9a8ba..36602d49 100644 --- a/packages/core/src/tools/mcp-tool.test.ts +++ b/packages/core/src/tools/mcp-tool.test.ts @@ -73,11 +73,21 @@ describe('DiscoveredMCPTool', () => { required: ['param'], }; + let tool: DiscoveredMCPTool; + beforeEach(() => { mockCallTool.mockClear(); mockToolMethod.mockClear(); + tool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + ); // Clear allowlist before each relevant test, especially for shouldConfirmExecute - (DiscoveredMCPTool as any).allowlist.clear(); + const invocation = tool.build({}) as any; + invocation.constructor.allowlist.clear(); }); afterEach(() => { @@ -86,14 +96,6 @@ describe('DiscoveredMCPTool', () => { describe('constructor', () => { it('should set properties correctly', () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); - expect(tool.name).toBe(serverToolName); expect(tool.schema.name).toBe(serverToolName); expect(tool.schema.description).toBe(baseDescription); @@ -105,7 +107,7 @@ describe('DiscoveredMCPTool', () => { it('should accept and store a custom timeout', () => { const customTimeout = 5000; - const tool = new DiscoveredMCPTool( + const toolWithTimeout = new DiscoveredMCPTool( mockCallableToolInstance, serverName, serverToolName, @@ -113,19 +115,12 @@ describe('DiscoveredMCPTool', () => { inputSchema, customTimeout, ); - expect(tool.timeout).toBe(customTimeout); + expect(toolWithTimeout.timeout).toBe(customTimeout); }); }); describe('execute', () => { it('should call mcpTool.callTool with correct parameters and format display output', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { param: 'testValue' }; const mockToolSuccessResultObject = { success: true, @@ -147,7 +142,10 @@ describe('DiscoveredMCPTool', () => { ]; mockCallTool.mockResolvedValue(mockMcpToolResponseParts); - const toolResult: ToolResult = await tool.execute(params); + const invocation = tool.build(params); + const toolResult: ToolResult = await invocation.execute( + new AbortController().signal, + ); expect(mockCallTool).toHaveBeenCalledWith([ { name: serverToolName, args: params }, @@ -163,17 +161,13 @@ describe('DiscoveredMCPTool', () => { }); it('should handle empty result from getStringifiedResultForDisplay', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { param: 'testValue' }; const mockMcpToolResponsePartsEmpty: Part[] = []; mockCallTool.mockResolvedValue(mockMcpToolResponsePartsEmpty); - const toolResult: ToolResult = await tool.execute(params); + const invocation = tool.build(params); + const toolResult: ToolResult = await invocation.execute( + new AbortController().signal, + ); expect(toolResult.returnDisplay).toBe('```json\n[]\n```'); expect(toolResult.llmContent).toEqual([ { text: '[Error: Could not parse tool response]' }, @@ -181,28 +175,17 @@ describe('DiscoveredMCPTool', () => { }); it('should propagate rejection if mcpTool.callTool rejects', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { param: 'failCase' }; const expectedError = new Error('MCP call failed'); mockCallTool.mockRejectedValue(expectedError); - await expect(tool.execute(params)).rejects.toThrow(expectedError); + const invocation = tool.build(params); + await expect( + invocation.execute(new AbortController().signal), + ).rejects.toThrow(expectedError); }); it('should handle a simple text response correctly', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { query: 'test' }; const successMessage = 'This is a success message.'; @@ -221,7 +204,8 @@ describe('DiscoveredMCPTool', () => { ]; mockCallTool.mockResolvedValue(sdkResponse); - const toolResult = await tool.execute(params); + const invocation = tool.build(params); + const toolResult = await invocation.execute(new AbortController().signal); // 1. Assert that the llmContent sent to the scheduler is a clean Part array. expect(toolResult.llmContent).toEqual([{ text: successMessage }]); @@ -236,13 +220,6 @@ describe('DiscoveredMCPTool', () => { }); it('should handle an AudioBlock response', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { action: 'play' }; const sdkResponse: Part[] = [ { @@ -262,7 +239,8 @@ describe('DiscoveredMCPTool', () => { ]; mockCallTool.mockResolvedValue(sdkResponse); - const toolResult = await tool.execute(params); + const invocation = tool.build(params); + const toolResult = await invocation.execute(new AbortController().signal); expect(toolResult.llmContent).toEqual([ { @@ -279,13 +257,6 @@ describe('DiscoveredMCPTool', () => { }); it('should handle a ResourceLinkBlock response', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { resource: 'get' }; const sdkResponse: Part[] = [ { @@ -306,7 +277,8 @@ describe('DiscoveredMCPTool', () => { ]; mockCallTool.mockResolvedValue(sdkResponse); - const toolResult = await tool.execute(params); + const invocation = tool.build(params); + const toolResult = await invocation.execute(new AbortController().signal); expect(toolResult.llmContent).toEqual([ { @@ -319,13 +291,6 @@ describe('DiscoveredMCPTool', () => { }); it('should handle an embedded text ResourceBlock response', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { resource: 'get' }; const sdkResponse: Part[] = [ { @@ -348,7 +313,8 @@ describe('DiscoveredMCPTool', () => { ]; mockCallTool.mockResolvedValue(sdkResponse); - const toolResult = await tool.execute(params); + const invocation = tool.build(params); + const toolResult = await invocation.execute(new AbortController().signal); expect(toolResult.llmContent).toEqual([ { text: 'This is the text content.' }, @@ -357,13 +323,6 @@ describe('DiscoveredMCPTool', () => { }); it('should handle an embedded binary ResourceBlock response', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { resource: 'get' }; const sdkResponse: Part[] = [ { @@ -386,7 +345,8 @@ describe('DiscoveredMCPTool', () => { ]; mockCallTool.mockResolvedValue(sdkResponse); - const toolResult = await tool.execute(params); + const invocation = tool.build(params); + const toolResult = await invocation.execute(new AbortController().signal); expect(toolResult.llmContent).toEqual([ { @@ -405,13 +365,6 @@ describe('DiscoveredMCPTool', () => { }); it('should handle a mix of content block types', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { action: 'complex' }; const sdkResponse: Part[] = [ { @@ -433,7 +386,8 @@ describe('DiscoveredMCPTool', () => { ]; mockCallTool.mockResolvedValue(sdkResponse); - const toolResult = await tool.execute(params); + const invocation = tool.build(params); + const toolResult = await invocation.execute(new AbortController().signal); expect(toolResult.llmContent).toEqual([ { text: 'First part.' }, @@ -454,13 +408,6 @@ describe('DiscoveredMCPTool', () => { }); it('should ignore unknown content block types', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { action: 'test' }; const sdkResponse: Part[] = [ { @@ -477,7 +424,8 @@ describe('DiscoveredMCPTool', () => { ]; mockCallTool.mockResolvedValue(sdkResponse); - const toolResult = await tool.execute(params); + const invocation = tool.build(params); + const toolResult = await invocation.execute(new AbortController().signal); expect(toolResult.llmContent).toEqual([{ text: 'Valid part.' }]); expect(toolResult.returnDisplay).toBe( @@ -486,13 +434,6 @@ describe('DiscoveredMCPTool', () => { }); it('should handle a complex mix of content block types', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const params = { action: 'super-complex' }; const sdkResponse: Part[] = [ { @@ -527,7 +468,8 @@ describe('DiscoveredMCPTool', () => { ]; mockCallTool.mockResolvedValue(sdkResponse); - const toolResult = await tool.execute(params); + const invocation = tool.build(params); + const toolResult = await invocation.execute(new AbortController().signal); expect(toolResult.llmContent).toEqual([ { text: 'Here is a resource.' }, @@ -552,10 +494,8 @@ describe('DiscoveredMCPTool', () => { }); describe('shouldConfirmExecute', () => { - // beforeEach is already clearing allowlist - it('should return false if trust is true', async () => { - const tool = new DiscoveredMCPTool( + const trustedTool = new DiscoveredMCPTool( mockCallableToolInstance, serverName, serverToolName, @@ -564,50 +504,32 @@ describe('DiscoveredMCPTool', () => { undefined, true, ); + const invocation = trustedTool.build({}); expect( - await tool.shouldConfirmExecute({}, new AbortController().signal), + await invocation.shouldConfirmExecute(new AbortController().signal), ).toBe(false); }); it('should return false if server is allowlisted', async () => { - (DiscoveredMCPTool as any).allowlist.add(serverName); - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); + const invocation = tool.build({}) as any; + invocation.constructor.allowlist.add(serverName); expect( - await tool.shouldConfirmExecute({}, new AbortController().signal), + await invocation.shouldConfirmExecute(new AbortController().signal), ).toBe(false); }); it('should return false if tool is allowlisted', async () => { const toolAllowlistKey = `${serverName}.${serverToolName}`; - (DiscoveredMCPTool as any).allowlist.add(toolAllowlistKey); - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); + const invocation = tool.build({}) as any; + invocation.constructor.allowlist.add(toolAllowlistKey); expect( - await tool.shouldConfirmExecute({}, new AbortController().signal), + await invocation.shouldConfirmExecute(new AbortController().signal), ).toBe(false); }); it('should return confirmation details if not trusted and not allowlisted', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); - const confirmation = await tool.shouldConfirmExecute( - {}, + const invocation = tool.build({}); + const confirmation = await invocation.shouldConfirmExecute( new AbortController().signal, ); expect(confirmation).not.toBe(false); @@ -629,15 +551,8 @@ describe('DiscoveredMCPTool', () => { }); it('should add server to allowlist on ProceedAlwaysServer', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); - const confirmation = await tool.shouldConfirmExecute( - {}, + const invocation = tool.build({}) as any; + const confirmation = await invocation.shouldConfirmExecute( new AbortController().signal, ); expect(confirmation).not.toBe(false); @@ -650,7 +565,7 @@ describe('DiscoveredMCPTool', () => { await confirmation.onConfirm( ToolConfirmationOutcome.ProceedAlwaysServer, ); - expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe(true); + expect(invocation.constructor.allowlist.has(serverName)).toBe(true); } else { throw new Error( 'Confirmation details or onConfirm not in expected format', @@ -659,16 +574,9 @@ describe('DiscoveredMCPTool', () => { }); it('should add tool to allowlist on ProceedAlwaysTool', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); const toolAllowlistKey = `${serverName}.${serverToolName}`; - const confirmation = await tool.shouldConfirmExecute( - {}, + const invocation = tool.build({}) as any; + const confirmation = await invocation.shouldConfirmExecute( new AbortController().signal, ); expect(confirmation).not.toBe(false); @@ -679,7 +587,7 @@ describe('DiscoveredMCPTool', () => { typeof confirmation.onConfirm === 'function' ) { await confirmation.onConfirm(ToolConfirmationOutcome.ProceedAlwaysTool); - expect((DiscoveredMCPTool as any).allowlist.has(toolAllowlistKey)).toBe( + expect(invocation.constructor.allowlist.has(toolAllowlistKey)).toBe( true, ); } else { @@ -690,15 +598,8 @@ describe('DiscoveredMCPTool', () => { }); it('should handle Cancel confirmation outcome', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); - const confirmation = await tool.shouldConfirmExecute( - {}, + const invocation = tool.build({}) as any; + const confirmation = await invocation.shouldConfirmExecute( new AbortController().signal, ); expect(confirmation).not.toBe(false); @@ -710,11 +611,9 @@ describe('DiscoveredMCPTool', () => { ) { // Cancel should not add anything to allowlist await confirmation.onConfirm(ToolConfirmationOutcome.Cancel); - expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe( - false, - ); + expect(invocation.constructor.allowlist.has(serverName)).toBe(false); expect( - (DiscoveredMCPTool as any).allowlist.has( + invocation.constructor.allowlist.has( `${serverName}.${serverToolName}`, ), ).toBe(false); @@ -726,15 +625,8 @@ describe('DiscoveredMCPTool', () => { }); it('should handle ProceedOnce confirmation outcome', async () => { - const tool = new DiscoveredMCPTool( - mockCallableToolInstance, - serverName, - serverToolName, - baseDescription, - inputSchema, - ); - const confirmation = await tool.shouldConfirmExecute( - {}, + const invocation = tool.build({}) as any; + const confirmation = await invocation.shouldConfirmExecute( new AbortController().signal, ); expect(confirmation).not.toBe(false); @@ -746,11 +638,9 @@ describe('DiscoveredMCPTool', () => { ) { // ProceedOnce should not add anything to allowlist await confirmation.onConfirm(ToolConfirmationOutcome.ProceedOnce); - expect((DiscoveredMCPTool as any).allowlist.has(serverName)).toBe( - false, - ); + expect(invocation.constructor.allowlist.has(serverName)).toBe(false); expect( - (DiscoveredMCPTool as any).allowlist.has( + invocation.constructor.allowlist.has( `${serverName}.${serverToolName}`, ), ).toBe(false); diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 59f83db3..01a8d75c 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -5,14 +5,16 @@ */ import { - BaseTool, - ToolResult, + BaseDeclarativeTool, + BaseToolInvocation, + Kind, ToolCallConfirmationDetails, ToolConfirmationOutcome, + ToolInvocation, ToolMcpConfirmationDetails, - Kind, + ToolResult, } from './tools.js'; -import { CallableTool, Part, FunctionCall } from '@google/genai'; +import { CallableTool, FunctionCall, Part } from '@google/genai'; type ToolParams = Record; @@ -50,9 +52,84 @@ type McpContentBlock = | McpResourceBlock | McpResourceLinkBlock; -export class DiscoveredMCPTool extends BaseTool { +class DiscoveredMCPToolInvocation extends BaseToolInvocation< + ToolParams, + ToolResult +> { private static readonly allowlist: Set = new Set(); + constructor( + private readonly mcpTool: CallableTool, + readonly serverName: string, + readonly serverToolName: string, + readonly displayName: string, + readonly timeout?: number, + readonly trust?: boolean, + params: ToolParams = {}, + ) { + super(params); + } + + async shouldConfirmExecute( + _abortSignal: AbortSignal, + ): Promise { + const serverAllowListKey = this.serverName; + const toolAllowListKey = `${this.serverName}.${this.serverToolName}`; + + if (this.trust) { + return false; // server is trusted, no confirmation needed + } + + if ( + DiscoveredMCPToolInvocation.allowlist.has(serverAllowListKey) || + DiscoveredMCPToolInvocation.allowlist.has(toolAllowListKey) + ) { + return false; // server and/or tool already allowlisted + } + + const confirmationDetails: ToolMcpConfirmationDetails = { + type: 'mcp', + title: 'Confirm MCP Tool Execution', + serverName: this.serverName, + toolName: this.serverToolName, // Display original tool name in confirmation + toolDisplayName: this.displayName, // Display global registry name exposed to model and user + onConfirm: async (outcome: ToolConfirmationOutcome) => { + if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) { + DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey); + } else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) { + DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey); + } + }, + }; + return confirmationDetails; + } + + async execute(): Promise { + const functionCalls: FunctionCall[] = [ + { + name: this.serverToolName, + args: this.params, + }, + ]; + + const rawResponseParts = await this.mcpTool.callTool(functionCalls); + const transformedParts = transformMcpContentToParts(rawResponseParts); + + return { + llmContent: transformedParts, + returnDisplay: getStringifiedResultForDisplay(rawResponseParts), + }; + } + + getDescription(): string { + return this.displayName; + } +} + +export class DiscoveredMCPTool extends BaseDeclarativeTool< + ToolParams, + ToolResult +> { constructor( private readonly mcpTool: CallableTool, readonly serverName: string, @@ -87,56 +164,18 @@ export class DiscoveredMCPTool extends BaseTool { ); } - async shouldConfirmExecute( - _params: ToolParams, - _abortSignal: AbortSignal, - ): Promise { - const serverAllowListKey = this.serverName; - const toolAllowListKey = `${this.serverName}.${this.serverToolName}`; - - if (this.trust) { - return false; // server is trusted, no confirmation needed - } - - if ( - DiscoveredMCPTool.allowlist.has(serverAllowListKey) || - DiscoveredMCPTool.allowlist.has(toolAllowListKey) - ) { - return false; // server and/or tool already allowlisted - } - - const confirmationDetails: ToolMcpConfirmationDetails = { - type: 'mcp', - title: 'Confirm MCP Tool Execution', - serverName: this.serverName, - toolName: this.serverToolName, // Display original tool name in confirmation - toolDisplayName: this.name, // Display global registry name exposed to model and user - onConfirm: async (outcome: ToolConfirmationOutcome) => { - if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) { - DiscoveredMCPTool.allowlist.add(serverAllowListKey); - } else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) { - DiscoveredMCPTool.allowlist.add(toolAllowListKey); - } - }, - }; - return confirmationDetails; - } - - async execute(params: ToolParams): Promise { - const functionCalls: FunctionCall[] = [ - { - name: this.serverToolName, - args: params, - }, - ]; - - const rawResponseParts = await this.mcpTool.callTool(functionCalls); - const transformedParts = transformMcpContentToParts(rawResponseParts); - - return { - llmContent: transformedParts, - returnDisplay: getStringifiedResultForDisplay(rawResponseParts), - }; + protected createInvocation( + params: ToolParams, + ): ToolInvocation { + return new DiscoveredMCPToolInvocation( + this.mcpTool, + this.serverName, + this.serverToolName, + this.displayName, + this.timeout, + this.trust, + params, + ); } } diff --git a/packages/core/src/tools/memoryTool.test.ts b/packages/core/src/tools/memoryTool.test.ts index 2a5c4c39..0e382325 100644 --- a/packages/core/src/tools/memoryTool.test.ts +++ b/packages/core/src/tools/memoryTool.test.ts @@ -218,7 +218,8 @@ describe('MemoryTool', () => { it('should call performAddMemoryEntry with correct parameters and return success', async () => { const params = { fact: 'The sky is blue' }; - const result = await memoryTool.execute(params, mockAbortSignal); + const invocation = memoryTool.build(params); + const result = await invocation.execute(mockAbortSignal); // Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test const expectedFilePath = path.join( os.homedir(), @@ -247,14 +248,12 @@ describe('MemoryTool', () => { it('should return an error if fact is empty', async () => { const params = { fact: ' ' }; // Empty fact - const result = await memoryTool.execute(params, mockAbortSignal); - const errorMessage = 'Parameter "fact" must be a non-empty string.'; - - expect(performAddMemoryEntrySpy).not.toHaveBeenCalled(); - expect(result.llmContent).toBe( - JSON.stringify({ success: false, error: errorMessage }), + expect(memoryTool.validateToolParams(params)).toBe( + 'Parameter "fact" must be a non-empty string.', + ); + expect(() => memoryTool.build(params)).toThrow( + 'Parameter "fact" must be a non-empty string.', ); - expect(result.returnDisplay).toBe(`Error: ${errorMessage}`); }); it('should handle errors from performAddMemoryEntry', async () => { @@ -264,7 +263,8 @@ describe('MemoryTool', () => { ); performAddMemoryEntrySpy.mockRejectedValue(underlyingError); - const result = await memoryTool.execute(params, mockAbortSignal); + const invocation = memoryTool.build(params); + const result = await invocation.execute(mockAbortSignal); expect(result.llmContent).toBe( JSON.stringify({ @@ -284,17 +284,17 @@ describe('MemoryTool', () => { beforeEach(() => { memoryTool = new MemoryTool(); // Clear the allowlist before each test - (MemoryTool as unknown as { allowlist: Set }).allowlist.clear(); + const invocation = memoryTool.build({ fact: 'mock-fact' }); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (invocation.constructor as any).allowlist.clear(); // Mock fs.readFile to return empty string (file doesn't exist) vi.mocked(fs.readFile).mockResolvedValue(''); }); it('should return confirmation details when memory file is not allowlisted', async () => { const params = { fact: 'Test fact' }; - const result = await memoryTool.shouldConfirmExecute( - params, - mockAbortSignal, - ); + const invocation = memoryTool.build(params); + const result = await invocation.shouldConfirmExecute(mockAbortSignal); expect(result).toBeDefined(); expect(result).not.toBe(false); @@ -321,15 +321,12 @@ describe('MemoryTool', () => { getCurrentGeminiMdFilename(), ); + const invocation = memoryTool.build(params); // Add the memory file to the allowlist - (MemoryTool as unknown as { allowlist: Set }).allowlist.add( - memoryFilePath, - ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (invocation.constructor as any).allowlist.add(memoryFilePath); - const result = await memoryTool.shouldConfirmExecute( - params, - mockAbortSignal, - ); + const result = await invocation.shouldConfirmExecute(mockAbortSignal); expect(result).toBe(false); }); @@ -342,10 +339,8 @@ describe('MemoryTool', () => { getCurrentGeminiMdFilename(), ); - const result = await memoryTool.shouldConfirmExecute( - params, - mockAbortSignal, - ); + const invocation = memoryTool.build(params); + const result = await invocation.shouldConfirmExecute(mockAbortSignal); expect(result).toBeDefined(); expect(result).not.toBe(false); @@ -356,9 +351,8 @@ describe('MemoryTool', () => { // Check that the memory file was added to the allowlist expect( - (MemoryTool as unknown as { allowlist: Set }).allowlist.has( - memoryFilePath, - ), + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (invocation.constructor as any).allowlist.has(memoryFilePath), ).toBe(true); } }); @@ -371,10 +365,8 @@ describe('MemoryTool', () => { getCurrentGeminiMdFilename(), ); - const result = await memoryTool.shouldConfirmExecute( - params, - mockAbortSignal, - ); + const invocation = memoryTool.build(params); + const result = await invocation.shouldConfirmExecute(mockAbortSignal); expect(result).toBeDefined(); expect(result).not.toBe(false); @@ -382,18 +374,12 @@ describe('MemoryTool', () => { if (result && result.type === 'edit') { // Simulate the onConfirm callback with different outcomes await result.onConfirm(ToolConfirmationOutcome.ProceedOnce); - expect( - (MemoryTool as unknown as { allowlist: Set }).allowlist.has( - memoryFilePath, - ), - ).toBe(false); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const allowlist = (invocation.constructor as any).allowlist; + expect(allowlist.has(memoryFilePath)).toBe(false); await result.onConfirm(ToolConfirmationOutcome.Cancel); - expect( - (MemoryTool as unknown as { allowlist: Set }).allowlist.has( - memoryFilePath, - ), - ).toBe(false); + expect(allowlist.has(memoryFilePath)).toBe(false); } }); @@ -405,10 +391,8 @@ describe('MemoryTool', () => { // Mock fs.readFile to return existing content vi.mocked(fs.readFile).mockResolvedValue(existingContent); - const result = await memoryTool.shouldConfirmExecute( - params, - mockAbortSignal, - ); + const invocation = memoryTool.build(params); + const result = await invocation.shouldConfirmExecute(mockAbortSignal); expect(result).toBeDefined(); expect(result).not.toBe(false); diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index c8e88c97..a9d765c4 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -5,11 +5,12 @@ */ import { - BaseTool, + BaseDeclarativeTool, + BaseToolInvocation, Kind, - ToolResult, ToolEditConfirmationDetails, ToolConfirmationOutcome, + ToolResult, } from './tools.js'; import { FunctionDeclaration } from '@google/genai'; import * as fs from 'fs/promises'; @@ -19,6 +20,7 @@ import * as Diff from 'diff'; import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; import { tildeifyPath } from '../utils/paths.js'; import { ModifiableDeclarativeTool, ModifyContext } from './modifiable-tool.js'; +import { SchemaValidator } from '../utils/schemaValidator.js'; const memoryToolSchemaData: FunctionDeclaration = { name: 'save_memory', @@ -110,101 +112,86 @@ function ensureNewlineSeparation(currentContent: string): string { return '\n\n'; } -export class MemoryTool - extends BaseTool - implements ModifiableDeclarativeTool -{ - private static readonly allowlist: Set = new Set(); +/** + * Reads the current content of the memory file + */ +async function readMemoryFileContent(): Promise { + try { + return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8'); + } catch (err) { + const error = err as Error & { code?: string }; + if (!(error instanceof Error) || error.code !== 'ENOENT') throw err; + return ''; + } +} - static readonly Name: string = memoryToolSchemaData.name!; - constructor() { - super( - MemoryTool.Name, - 'Save Memory', - memoryToolDescription, - Kind.Think, - memoryToolSchemaData.parametersJsonSchema as Record, +/** + * Computes the new content that would result from adding a memory entry + */ +function computeNewContent(currentContent: string, fact: string): string { + let processedText = fact.trim(); + processedText = processedText.replace(/^(-+\s*)+/, '').trim(); + const newMemoryItem = `- ${processedText}`; + + const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER); + + if (headerIndex === -1) { + // Header not found, append header and then the entry + const separator = ensureNewlineSeparation(currentContent); + return ( + currentContent + + `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n` + ); + } else { + // Header found, find where to insert the new memory entry + const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length; + let endOfSectionIndex = currentContent.indexOf( + '\n## ', + startOfSectionContent, + ); + if (endOfSectionIndex === -1) { + endOfSectionIndex = currentContent.length; // End of file + } + + const beforeSectionMarker = currentContent + .substring(0, startOfSectionContent) + .trimEnd(); + let sectionContent = currentContent + .substring(startOfSectionContent, endOfSectionIndex) + .trimEnd(); + const afterSectionMarker = currentContent.substring(endOfSectionIndex); + + sectionContent += `\n${newMemoryItem}`; + return ( + `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() + + '\n' ); } +} - getDescription(_params: SaveMemoryParams): string { +class MemoryToolInvocation extends BaseToolInvocation< + SaveMemoryParams, + ToolResult +> { + private static readonly allowlist: Set = new Set(); + + getDescription(): string { const memoryFilePath = getGlobalMemoryFilePath(); return `in ${tildeifyPath(memoryFilePath)}`; } - /** - * Reads the current content of the memory file - */ - private async readMemoryFileContent(): Promise { - try { - return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8'); - } catch (err) { - const error = err as Error & { code?: string }; - if (!(error instanceof Error) || error.code !== 'ENOENT') throw err; - return ''; - } - } - - /** - * Computes the new content that would result from adding a memory entry - */ - private computeNewContent(currentContent: string, fact: string): string { - let processedText = fact.trim(); - processedText = processedText.replace(/^(-+\s*)+/, '').trim(); - const newMemoryItem = `- ${processedText}`; - - const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER); - - if (headerIndex === -1) { - // Header not found, append header and then the entry - const separator = ensureNewlineSeparation(currentContent); - return ( - currentContent + - `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n` - ); - } else { - // Header found, find where to insert the new memory entry - const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length; - let endOfSectionIndex = currentContent.indexOf( - '\n## ', - startOfSectionContent, - ); - if (endOfSectionIndex === -1) { - endOfSectionIndex = currentContent.length; // End of file - } - - const beforeSectionMarker = currentContent - .substring(0, startOfSectionContent) - .trimEnd(); - let sectionContent = currentContent - .substring(startOfSectionContent, endOfSectionIndex) - .trimEnd(); - const afterSectionMarker = currentContent.substring(endOfSectionIndex); - - sectionContent += `\n${newMemoryItem}`; - return ( - `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() + - '\n' - ); - } - } - async shouldConfirmExecute( - params: SaveMemoryParams, _abortSignal: AbortSignal, ): Promise { const memoryFilePath = getGlobalMemoryFilePath(); const allowlistKey = memoryFilePath; - if (MemoryTool.allowlist.has(allowlistKey)) { + if (MemoryToolInvocation.allowlist.has(allowlistKey)) { return false; } - // Read current content of the memory file - const currentContent = await this.readMemoryFileContent(); - - // Calculate the new content that will be written to the memory file - const newContent = this.computeNewContent(currentContent, params.fact); + const currentContent = await readMemoryFileContent(); + const newContent = computeNewContent(currentContent, this.params.fact); const fileName = path.basename(memoryFilePath); const fileDiff = Diff.createPatch( @@ -226,13 +213,107 @@ export class MemoryTool newContent, onConfirm: async (outcome: ToolConfirmationOutcome) => { if (outcome === ToolConfirmationOutcome.ProceedAlways) { - MemoryTool.allowlist.add(allowlistKey); + MemoryToolInvocation.allowlist.add(allowlistKey); } }, }; return confirmationDetails; } + async execute(_signal: AbortSignal): Promise { + const { fact, modified_by_user, modified_content } = this.params; + + try { + if (modified_by_user && modified_content !== undefined) { + // User modified the content in external editor, write it directly + await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), { + recursive: true, + }); + await fs.writeFile( + getGlobalMemoryFilePath(), + modified_content, + 'utf-8', + ); + const successMessage = `Okay, I've updated the memory file with your modifications.`; + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; + } else { + // Use the normal memory entry logic + await MemoryTool.performAddMemoryEntry( + fact, + getGlobalMemoryFilePath(), + { + readFile: fs.readFile, + writeFile: fs.writeFile, + mkdir: fs.mkdir, + }, + ); + const successMessage = `Okay, I've remembered that: "${fact}"`; + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; + } + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : String(error); + console.error( + `[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`, + ); + return { + llmContent: JSON.stringify({ + success: false, + error: `Failed to save memory. Detail: ${errorMessage}`, + }), + returnDisplay: `Error saving memory: ${errorMessage}`, + }; + } + } +} + +export class MemoryTool + extends BaseDeclarativeTool + implements ModifiableDeclarativeTool +{ + static readonly Name: string = memoryToolSchemaData.name!; + constructor() { + super( + MemoryTool.Name, + 'Save Memory', + memoryToolDescription, + Kind.Think, + memoryToolSchemaData.parametersJsonSchema as Record, + ); + } + + validateToolParams(params: SaveMemoryParams): string | null { + const errors = SchemaValidator.validate( + this.schema.parametersJsonSchema, + params, + ); + if (errors) { + return errors; + } + + if (params.fact.trim() === '') { + return 'Parameter "fact" must be a non-empty string.'; + } + + return null; + } + + protected createInvocation(params: SaveMemoryParams) { + return new MemoryToolInvocation(params); + } + static async performAddMemoryEntry( text: string, memoryFilePath: string, @@ -303,83 +384,14 @@ export class MemoryTool } } - async execute( - params: SaveMemoryParams, - _signal: AbortSignal, - ): Promise { - const { fact, modified_by_user, modified_content } = params; - - if (!fact || typeof fact !== 'string' || fact.trim() === '') { - const errorMessage = 'Parameter "fact" must be a non-empty string.'; - return { - llmContent: JSON.stringify({ success: false, error: errorMessage }), - returnDisplay: `Error: ${errorMessage}`, - }; - } - - try { - if (modified_by_user && modified_content !== undefined) { - // User modified the content in external editor, write it directly - await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), { - recursive: true, - }); - await fs.writeFile( - getGlobalMemoryFilePath(), - modified_content, - 'utf-8', - ); - const successMessage = `Okay, I've updated the memory file with your modifications.`; - return { - llmContent: JSON.stringify({ - success: true, - message: successMessage, - }), - returnDisplay: successMessage, - }; - } else { - // Use the normal memory entry logic - await MemoryTool.performAddMemoryEntry( - fact, - getGlobalMemoryFilePath(), - { - readFile: fs.readFile, - writeFile: fs.writeFile, - mkdir: fs.mkdir, - }, - ); - const successMessage = `Okay, I've remembered that: "${fact}"`; - return { - llmContent: JSON.stringify({ - success: true, - message: successMessage, - }), - returnDisplay: successMessage, - }; - } - } catch (error) { - const errorMessage = - error instanceof Error ? error.message : String(error); - console.error( - `[MemoryTool] Error executing save_memory for fact "${fact}": ${errorMessage}`, - ); - return { - llmContent: JSON.stringify({ - success: false, - error: `Failed to save memory. Detail: ${errorMessage}`, - }), - returnDisplay: `Error saving memory: ${errorMessage}`, - }; - } - } - getModifyContext(_abortSignal: AbortSignal): ModifyContext { return { getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(), getCurrentContent: async (_params: SaveMemoryParams): Promise => - this.readMemoryFileContent(), + readMemoryFileContent(), getProposedContent: async (params: SaveMemoryParams): Promise => { - const currentContent = await this.readMemoryFileContent(); - return this.computeNewContent(currentContent, params.fact); + const currentContent = await readMemoryFileContent(); + return computeNewContent(currentContent, params.fact); }, createUpdatedParams: ( _oldContent: string,