diff --git a/packages/core/src/tools/memoryTool.test.ts b/packages/core/src/tools/memoryTool.test.ts index aff0cc2e..5a9b5f26 100644 --- a/packages/core/src/tools/memoryTool.test.ts +++ b/packages/core/src/tools/memoryTool.test.ts @@ -15,6 +15,7 @@ import { import * as fs from 'fs/promises'; import * as path from 'path'; import * as os from 'os'; +import { ToolConfirmationOutcome } from './tools.js'; // Mock dependencies vi.mock('fs/promises'); @@ -46,7 +47,7 @@ describe('MemoryTool', () => { }; beforeEach(() => { - vi.mocked(os.homedir).mockReturnValue('/mock/home'); + vi.mocked(os.homedir).mockReturnValue(path.join('/mock', 'home')); mockFsAdapter.readFile.mockReset(); mockFsAdapter.writeFile.mockReset().mockResolvedValue(undefined); mockFsAdapter.mkdir @@ -85,11 +86,15 @@ describe('MemoryTool', () => { }); describe('performAddMemoryEntry (static method)', () => { - const testFilePath = path.join( - '/mock/home', - '.gemini', - DEFAULT_CONTEXT_FILENAME, // Use the default for basic tests - ); + let testFilePath: string; + + beforeEach(() => { + testFilePath = path.join( + os.homedir(), + '.gemini', + DEFAULT_CONTEXT_FILENAME, + ); + }); it('should create section and save a fact if file does not exist', async () => { mockFsAdapter.readFile.mockRejectedValue({ code: 'ENOENT' }); // Simulate file not found @@ -206,7 +211,7 @@ describe('MemoryTool', () => { const result = await memoryTool.execute(params, mockAbortSignal); // Use getCurrentGeminiMdFilename for the default expectation before any setGeminiMdFilename calls in a test const expectedFilePath = path.join( - '/mock/home', + os.homedir(), '.gemini', getCurrentGeminiMdFilename(), // This will be DEFAULT_CONTEXT_FILENAME unless changed by a test ); @@ -262,4 +267,151 @@ describe('MemoryTool', () => { ); }); }); + + describe('shouldConfirmExecute', () => { + let memoryTool: MemoryTool; + + beforeEach(() => { + memoryTool = new MemoryTool(); + // Clear the allowlist before each test + (MemoryTool as unknown as { allowlist: Set }).allowlist.clear(); + // Mock fs.readFile to return empty string (file doesn't exist) + vi.mocked(fs.readFile).mockResolvedValue(''); + }); + + it('should return confirmation details when memory file is not allowlisted', async () => { + const params = { fact: 'Test fact' }; + const result = await memoryTool.shouldConfirmExecute( + params, + mockAbortSignal, + ); + + expect(result).toBeDefined(); + expect(result).not.toBe(false); + + if (result && result.type === 'edit') { + const expectedPath = path.join('~', '.gemini', 'GEMINI.md'); + expect(result.title).toBe(`Confirm Memory Save: ${expectedPath}`); + expect(result.fileName).toContain(path.join('mock', 'home', '.gemini')); + expect(result.fileName).toContain('GEMINI.md'); + expect(result.fileDiff).toContain('Index: GEMINI.md'); + expect(result.fileDiff).toContain('+## Gemini Added Memories'); + expect(result.fileDiff).toContain('+- Test fact'); + expect(result.originalContent).toBe(''); + expect(result.newContent).toContain('## Gemini Added Memories'); + expect(result.newContent).toContain('- Test fact'); + } + }); + + it('should return false when memory file is already allowlisted', async () => { + const params = { fact: 'Test fact' }; + const memoryFilePath = path.join( + os.homedir(), + '.gemini', + getCurrentGeminiMdFilename(), + ); + + // Add the memory file to the allowlist + (MemoryTool as unknown as { allowlist: Set }).allowlist.add( + memoryFilePath, + ); + + const result = await memoryTool.shouldConfirmExecute( + params, + mockAbortSignal, + ); + + expect(result).toBe(false); + }); + + it('should add memory file to allowlist when ProceedAlways is confirmed', async () => { + const params = { fact: 'Test fact' }; + const memoryFilePath = path.join( + os.homedir(), + '.gemini', + getCurrentGeminiMdFilename(), + ); + + const result = await memoryTool.shouldConfirmExecute( + params, + mockAbortSignal, + ); + + expect(result).toBeDefined(); + expect(result).not.toBe(false); + + if (result && result.type === 'edit') { + // Simulate the onConfirm callback + await result.onConfirm(ToolConfirmationOutcome.ProceedAlways); + + // Check that the memory file was added to the allowlist + expect( + (MemoryTool as unknown as { allowlist: Set }).allowlist.has( + memoryFilePath, + ), + ).toBe(true); + } + }); + + it('should not add memory file to allowlist when other outcomes are confirmed', async () => { + const params = { fact: 'Test fact' }; + const memoryFilePath = path.join( + os.homedir(), + '.gemini', + getCurrentGeminiMdFilename(), + ); + + const result = await memoryTool.shouldConfirmExecute( + params, + mockAbortSignal, + ); + + expect(result).toBeDefined(); + expect(result).not.toBe(false); + + if (result && result.type === 'edit') { + // Simulate the onConfirm callback with different outcomes + await result.onConfirm(ToolConfirmationOutcome.ProceedOnce); + expect( + (MemoryTool as unknown as { allowlist: Set }).allowlist.has( + memoryFilePath, + ), + ).toBe(false); + + await result.onConfirm(ToolConfirmationOutcome.Cancel); + expect( + (MemoryTool as unknown as { allowlist: Set }).allowlist.has( + memoryFilePath, + ), + ).toBe(false); + } + }); + + it('should handle existing memory file with content', async () => { + const params = { fact: 'New fact' }; + const existingContent = + 'Some existing content.\n\n## Gemini Added Memories\n- Old fact\n'; + + // Mock fs.readFile to return existing content + vi.mocked(fs.readFile).mockResolvedValue(existingContent); + + const result = await memoryTool.shouldConfirmExecute( + params, + mockAbortSignal, + ); + + expect(result).toBeDefined(); + expect(result).not.toBe(false); + + if (result && result.type === 'edit') { + const expectedPath = path.join('~', '.gemini', 'GEMINI.md'); + expect(result.title).toBe(`Confirm Memory Save: ${expectedPath}`); + expect(result.fileDiff).toContain('Index: GEMINI.md'); + expect(result.fileDiff).toContain('+- New fact'); + expect(result.originalContent).toBe(existingContent); + expect(result.newContent).toContain('- Old fact'); + expect(result.newContent).toContain('- New fact'); + } + }); + }); }); diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index f0f1e16b..96509f79 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -4,11 +4,21 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { BaseTool, Icon, ToolResult } from './tools.js'; +import { + BaseTool, + ToolResult, + ToolEditConfirmationDetails, + ToolConfirmationOutcome, + Icon, +} from './tools.js'; import { FunctionDeclaration, Type } from '@google/genai'; import * as fs from 'fs/promises'; import * as path from 'path'; import { homedir } from 'os'; +import * as Diff from 'diff'; +import { DEFAULT_DIFF_OPTIONS } from './diffOptions.js'; +import { tildeifyPath } from '../utils/paths.js'; +import { ModifiableTool, ModifyContext } from './modifiable-tool.js'; const memoryToolSchemaData: FunctionDeclaration = { name: 'save_memory', @@ -80,6 +90,8 @@ export function getAllGeminiMdFilenames(): string[] { interface SaveMemoryParams { fact: string; + modified_by_user?: boolean; + modified_content?: string; } function getGlobalMemoryFilePath(): string { @@ -98,7 +110,12 @@ function ensureNewlineSeparation(currentContent: string): string { return '\n\n'; } -export class MemoryTool extends BaseTool { +export class MemoryTool + extends BaseTool + implements ModifiableTool +{ + private static readonly allowlist: Set = new Set(); + static readonly Name: string = memoryToolSchemaData.name!; constructor() { super( @@ -110,6 +127,111 @@ export class MemoryTool extends BaseTool { ); } + getDescription(_params: SaveMemoryParams): string { + const memoryFilePath = getGlobalMemoryFilePath(); + return `in ${tildeifyPath(memoryFilePath)}`; + } + + /** + * Reads the current content of the memory file + */ + private async readMemoryFileContent(): Promise { + try { + return await fs.readFile(getGlobalMemoryFilePath(), 'utf-8'); + } catch (err) { + const error = err as Error & { code?: string }; + if (!(error instanceof Error) || error.code !== 'ENOENT') throw err; + return ''; + } + } + + /** + * Computes the new content that would result from adding a memory entry + */ + private computeNewContent(currentContent: string, fact: string): string { + let processedText = fact.trim(); + processedText = processedText.replace(/^(-+\s*)+/, '').trim(); + const newMemoryItem = `- ${processedText}`; + + const headerIndex = currentContent.indexOf(MEMORY_SECTION_HEADER); + + if (headerIndex === -1) { + // Header not found, append header and then the entry + const separator = ensureNewlineSeparation(currentContent); + return ( + currentContent + + `${separator}${MEMORY_SECTION_HEADER}\n${newMemoryItem}\n` + ); + } else { + // Header found, find where to insert the new memory entry + const startOfSectionContent = headerIndex + MEMORY_SECTION_HEADER.length; + let endOfSectionIndex = currentContent.indexOf( + '\n## ', + startOfSectionContent, + ); + if (endOfSectionIndex === -1) { + endOfSectionIndex = currentContent.length; // End of file + } + + const beforeSectionMarker = currentContent + .substring(0, startOfSectionContent) + .trimEnd(); + let sectionContent = currentContent + .substring(startOfSectionContent, endOfSectionIndex) + .trimEnd(); + const afterSectionMarker = currentContent.substring(endOfSectionIndex); + + sectionContent += `\n${newMemoryItem}`; + return ( + `${beforeSectionMarker}\n${sectionContent.trimStart()}\n${afterSectionMarker}`.trimEnd() + + '\n' + ); + } + } + + async shouldConfirmExecute( + params: SaveMemoryParams, + _abortSignal: AbortSignal, + ): Promise { + const memoryFilePath = getGlobalMemoryFilePath(); + const allowlistKey = memoryFilePath; + + if (MemoryTool.allowlist.has(allowlistKey)) { + return false; + } + + // Read current content of the memory file + const currentContent = await this.readMemoryFileContent(); + + // Calculate the new content that will be written to the memory file + const newContent = this.computeNewContent(currentContent, params.fact); + + const fileName = path.basename(memoryFilePath); + const fileDiff = Diff.createPatch( + fileName, + currentContent, + newContent, + 'Current', + 'Proposed', + DEFAULT_DIFF_OPTIONS, + ); + + const confirmationDetails: ToolEditConfirmationDetails = { + type: 'edit', + title: `Confirm Memory Save: ${tildeifyPath(memoryFilePath)}`, + fileName: memoryFilePath, + fileDiff, + originalContent: currentContent, + newContent, + onConfirm: async (outcome: ToolConfirmationOutcome) => { + if (outcome === ToolConfirmationOutcome.ProceedAlways) { + MemoryTool.allowlist.add(allowlistKey); + } + }, + }; + return confirmationDetails; + } + static async performAddMemoryEntry( text: string, memoryFilePath: string, @@ -184,7 +306,7 @@ export class MemoryTool extends BaseTool { params: SaveMemoryParams, _signal: AbortSignal, ): Promise { - const { fact } = params; + const { fact, modified_by_user, modified_content } = params; if (!fact || typeof fact !== 'string' || fact.trim() === '') { const errorMessage = 'Parameter "fact" must be a non-empty string.'; @@ -195,17 +317,44 @@ export class MemoryTool extends BaseTool { } try { - // Use the static method with actual fs promises - await MemoryTool.performAddMemoryEntry(fact, getGlobalMemoryFilePath(), { - readFile: fs.readFile, - writeFile: fs.writeFile, - mkdir: fs.mkdir, - }); - const successMessage = `Okay, I've remembered that: "${fact}"`; - return { - llmContent: JSON.stringify({ success: true, message: successMessage }), - returnDisplay: successMessage, - }; + if (modified_by_user && modified_content !== undefined) { + // User modified the content in external editor, write it directly + await fs.mkdir(path.dirname(getGlobalMemoryFilePath()), { + recursive: true, + }); + await fs.writeFile( + getGlobalMemoryFilePath(), + modified_content, + 'utf-8', + ); + const successMessage = `Okay, I've updated the memory file with your modifications.`; + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; + } else { + // Use the normal memory entry logic + await MemoryTool.performAddMemoryEntry( + fact, + getGlobalMemoryFilePath(), + { + readFile: fs.readFile, + writeFile: fs.writeFile, + mkdir: fs.mkdir, + }, + ); + const successMessage = `Okay, I've remembered that: "${fact}"`; + return { + llmContent: JSON.stringify({ + success: true, + message: successMessage, + }), + returnDisplay: successMessage, + }; + } } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error); @@ -221,4 +370,25 @@ export class MemoryTool extends BaseTool { }; } } + + getModifyContext(_abortSignal: AbortSignal): ModifyContext { + return { + getFilePath: (_params: SaveMemoryParams) => getGlobalMemoryFilePath(), + getCurrentContent: async (_params: SaveMemoryParams): Promise => + this.readMemoryFileContent(), + getProposedContent: async (params: SaveMemoryParams): Promise => { + const currentContent = await this.readMemoryFileContent(); + return this.computeNewContent(currentContent, params.fact); + }, + createUpdatedParams: ( + _oldContent: string, + modifiedProposedContent: string, + originalParams: SaveMemoryParams, + ): SaveMemoryParams => ({ + ...originalParams, + modified_by_user: true, + modified_content: modifiedProposedContent, + }), + }; + } }