diff --git a/packages/cli/src/services/CommandService.test.ts b/packages/cli/src/services/CommandService.test.ts index 084f603b..6ae52b52 100644 --- a/packages/cli/src/services/CommandService.test.ts +++ b/packages/cli/src/services/CommandService.test.ts @@ -26,6 +26,7 @@ import { mcpCommand } from '../ui/commands/mcpCommand.js'; import { editorCommand } from '../ui/commands/editorCommand.js'; import { bugCommand } from '../ui/commands/bugCommand.js'; import { quitCommand } from '../ui/commands/quitCommand.js'; +import { restoreCommand } from '../ui/commands/restoreCommand.js'; // Mock the command modules to isolate the service from the command implementations. vi.mock('../ui/commands/memoryCommand.js', () => ({ @@ -79,6 +80,9 @@ vi.mock('../ui/commands/bugCommand.js', () => ({ vi.mock('../ui/commands/quitCommand.js', () => ({ quitCommand: { name: 'quit', description: 'Mock Quit' }, })); +vi.mock('../ui/commands/restoreCommand.js', () => ({ + restoreCommand: vi.fn(), +})); describe('CommandService', () => { const subCommandLen = 17; @@ -87,8 +91,10 @@ describe('CommandService', () => { beforeEach(() => { mockConfig = { getIdeMode: vi.fn(), + getCheckpointingEnabled: vi.fn(), } as unknown as Mocked; vi.mocked(ideCommand).mockReturnValue(null); + vi.mocked(restoreCommand).mockReturnValue(null); }); describe('when using default production loader', () => { @@ -151,6 +157,20 @@ describe('CommandService', () => { expect(commandNames).toContain('quit'); }); + it('should include restore command when checkpointing is on', async () => { + mockConfig.getCheckpointingEnabled.mockReturnValue(true); + vi.mocked(restoreCommand).mockReturnValue({ + name: 'restore', + description: 'Mock Restore', + }); + await commandService.loadCommands(); + const tree = commandService.getCommands(); + + expect(tree.length).toBe(subCommandLen + 1); + const commandNames = tree.map((cmd) => cmd.name); + expect(commandNames).toContain('restore'); + }); + it('should overwrite any existing commands when called again', async () => { // Load once await commandService.loadCommands(); diff --git a/packages/cli/src/services/CommandService.ts b/packages/cli/src/services/CommandService.ts index 773f5b31..611b0a7b 100644 --- a/packages/cli/src/services/CommandService.ts +++ b/packages/cli/src/services/CommandService.ts @@ -24,6 +24,7 @@ import { compressCommand } from '../ui/commands/compressCommand.js'; import { ideCommand } from '../ui/commands/ideCommand.js'; import { bugCommand } from '../ui/commands/bugCommand.js'; import { quitCommand } from '../ui/commands/quitCommand.js'; +import { restoreCommand } from '../ui/commands/restoreCommand.js'; const loadBuiltInCommands = async ( config: Config | null, @@ -44,6 +45,7 @@ const loadBuiltInCommands = async ( memoryCommand, privacyCommand, quitCommand, + restoreCommand(config), statsCommand, themeCommand, toolsCommand, diff --git a/packages/cli/src/test-utils/mockCommandContext.ts b/packages/cli/src/test-utils/mockCommandContext.ts index 899d5747..3fb33b3f 100644 --- a/packages/cli/src/test-utils/mockCommandContext.ts +++ b/packages/cli/src/test-utils/mockCommandContext.ts @@ -46,6 +46,7 @@ export const createMockCommandContext = ( setDebugMessage: vi.fn(), pendingItem: null, setPendingItem: vi.fn(), + loadHistory: vi.fn(), }, session: { stats: { diff --git a/packages/cli/src/ui/commands/restoreCommand.test.ts b/packages/cli/src/ui/commands/restoreCommand.test.ts new file mode 100644 index 00000000..53cd7d18 --- /dev/null +++ b/packages/cli/src/ui/commands/restoreCommand.test.ts @@ -0,0 +1,237 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + vi, + describe, + it, + expect, + beforeEach, + afterEach, + Mocked, + Mock, +} from 'vitest'; +import * as fs from 'fs/promises'; +import { restoreCommand } from './restoreCommand.js'; +import { type CommandContext } from './types.js'; +import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; +import { Config, GitService } from '@google/gemini-cli-core'; + +vi.mock('fs/promises', () => ({ + readdir: vi.fn(), + readFile: vi.fn(), + mkdir: vi.fn(), +})); + +describe('restoreCommand', () => { + let mockContext: CommandContext; + let mockConfig: Config; + let mockGitService: GitService; + const mockFsPromises = fs as Mocked; + let mockSetHistory: ReturnType; + + beforeEach(() => { + mockSetHistory = vi.fn().mockResolvedValue(undefined); + mockGitService = { + restoreProjectFromSnapshot: vi.fn().mockResolvedValue(undefined), + } as unknown as GitService; + + mockConfig = { + getCheckpointingEnabled: vi.fn().mockReturnValue(true), + getProjectTempDir: vi.fn().mockReturnValue('/tmp/gemini'), + getGeminiClient: vi.fn().mockReturnValue({ + setHistory: mockSetHistory, + }), + } as unknown as Config; + + mockContext = createMockCommandContext({ + services: { + config: mockConfig, + git: mockGitService, + }, + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should return null if checkpointing is not enabled', () => { + (mockConfig.getCheckpointingEnabled as Mock).mockReturnValue(false); + const command = restoreCommand(mockConfig); + expect(command).toBeNull(); + }); + + it('should return the command if checkpointing is enabled', () => { + const command = restoreCommand(mockConfig); + expect(command).not.toBeNull(); + expect(command?.name).toBe('restore'); + expect(command?.description).toBeDefined(); + expect(command?.action).toBeDefined(); + expect(command?.completion).toBeDefined(); + }); + + describe('action', () => { + it('should return an error if temp dir is not found', async () => { + (mockConfig.getProjectTempDir as Mock).mockReturnValue(undefined); + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, ''); + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'Could not determine the .gemini directory path.', + }); + }); + + it('should inform when no checkpoints are found if no args are passed', async () => { + mockFsPromises.readdir.mockResolvedValue([]); + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, ''); + expect(result).toEqual({ + type: 'message', + messageType: 'info', + content: 'No restorable tool calls found.', + }); + expect(mockFsPromises.mkdir).toHaveBeenCalledWith( + '/tmp/gemini/checkpoints', + { + recursive: true, + }, + ); + }); + + it('should list available checkpoints if no args are passed', async () => { + mockFsPromises.readdir.mockResolvedValue([ + 'test1.json', + 'test2.json', + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any); + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, ''); + expect(result).toEqual({ + type: 'message', + messageType: 'info', + content: 'Available tool calls to restore:\n\ntest1\ntest2', + }); + }); + + it('should return an error if the specified file is not found', async () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockFsPromises.readdir.mockResolvedValue(['test1.json'] as any); + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, 'test2'); + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: 'File not found: test2.json', + }); + }); + + it('should handle file read errors gracefully', async () => { + const readError = new Error('Read failed'); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockFsPromises.readdir.mockResolvedValue(['test1.json'] as any); + mockFsPromises.readFile.mockRejectedValue(readError); + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, 'test1'); + expect(result).toEqual({ + type: 'message', + messageType: 'error', + content: `Could not read restorable tool calls. This is the error: ${readError}`, + }); + }); + + it('should restore a tool call and project state', async () => { + const toolCallData = { + history: [{ type: 'user', text: 'do a thing' }], + clientHistory: [{ role: 'user', parts: [{ text: 'do a thing' }] }], + commitHash: 'abcdef123', + toolCall: { name: 'run_shell_command', args: 'ls' }, + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockFsPromises.readdir.mockResolvedValue(['my-checkpoint.json'] as any); + mockFsPromises.readFile.mockResolvedValue(JSON.stringify(toolCallData)); + + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, 'my-checkpoint'); + + // Check history restoration + expect(mockContext.ui.loadHistory).toHaveBeenCalledWith( + toolCallData.history, + ); + expect(mockSetHistory).toHaveBeenCalledWith(toolCallData.clientHistory); + + // Check git restoration + expect(mockGitService.restoreProjectFromSnapshot).toHaveBeenCalledWith( + toolCallData.commitHash, + ); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: 'info', + text: 'Restored project to the state before the tool call.', + }, + expect.any(Number), + ); + + // Check returned action + expect(result).toEqual({ + type: 'tool', + toolName: 'run_shell_command', + toolArgs: 'ls', + }); + }); + + it('should restore even if only toolCall is present', async () => { + const toolCallData = { + toolCall: { name: 'run_shell_command', args: 'ls' }, + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + mockFsPromises.readdir.mockResolvedValue(['my-checkpoint.json'] as any); + mockFsPromises.readFile.mockResolvedValue(JSON.stringify(toolCallData)); + + const command = restoreCommand(mockConfig); + const result = await command?.action?.(mockContext, 'my-checkpoint'); + + expect(mockContext.ui.loadHistory).not.toHaveBeenCalled(); + expect(mockSetHistory).not.toHaveBeenCalled(); + expect(mockGitService.restoreProjectFromSnapshot).not.toHaveBeenCalled(); + + expect(result).toEqual({ + type: 'tool', + toolName: 'run_shell_command', + toolArgs: 'ls', + }); + }); + }); + + describe('completion', () => { + it('should return an empty array if temp dir is not found', async () => { + (mockConfig.getProjectTempDir as Mock).mockReturnValue(undefined); + const command = restoreCommand(mockConfig); + const result = await command?.completion?.(mockContext, ''); + expect(result).toEqual([]); + }); + + it('should return an empty array on readdir error', async () => { + mockFsPromises.readdir.mockRejectedValue(new Error('ENOENT')); + const command = restoreCommand(mockConfig); + const result = await command?.completion?.(mockContext, ''); + expect(result).toEqual([]); + }); + + it('should return a list of checkpoint names', async () => { + mockFsPromises.readdir.mockResolvedValue([ + 'test1.json', + 'test2.json', + 'not-a-checkpoint.txt', + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any); + const command = restoreCommand(mockConfig); + const result = await command?.completion?.(mockContext, ''); + expect(result).toEqual(['test1', 'test2']); + }); + }); +}); diff --git a/packages/cli/src/ui/commands/restoreCommand.ts b/packages/cli/src/ui/commands/restoreCommand.ts new file mode 100644 index 00000000..3d744189 --- /dev/null +++ b/packages/cli/src/ui/commands/restoreCommand.ts @@ -0,0 +1,155 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as fs from 'fs/promises'; +import path from 'path'; +import { + type CommandContext, + type SlashCommand, + type SlashCommandActionReturn, +} from './types.js'; +import { Config } from '@google/gemini-cli-core'; + +async function restoreAction( + context: CommandContext, + args: string, +): Promise { + const { services, ui } = context; + const { config, git: gitService } = services; + const { addItem, loadHistory } = ui; + + const checkpointDir = config?.getProjectTempDir() + ? path.join(config.getProjectTempDir(), 'checkpoints') + : undefined; + + if (!checkpointDir) { + return { + type: 'message', + messageType: 'error', + content: 'Could not determine the .gemini directory path.', + }; + } + + try { + // Ensure the directory exists before trying to read it. + await fs.mkdir(checkpointDir, { recursive: true }); + const files = await fs.readdir(checkpointDir); + const jsonFiles = files.filter((file) => file.endsWith('.json')); + + if (!args) { + if (jsonFiles.length === 0) { + return { + type: 'message', + messageType: 'info', + content: 'No restorable tool calls found.', + }; + } + const truncatedFiles = jsonFiles.map((file) => { + const components = file.split('.'); + if (components.length <= 1) { + return file; + } + components.pop(); + return components.join('.'); + }); + const fileList = truncatedFiles.join('\n'); + return { + type: 'message', + messageType: 'info', + content: `Available tool calls to restore:\n\n${fileList}`, + }; + } + + const selectedFile = args.endsWith('.json') ? args : `${args}.json`; + + if (!jsonFiles.includes(selectedFile)) { + return { + type: 'message', + messageType: 'error', + content: `File not found: ${selectedFile}`, + }; + } + + const filePath = path.join(checkpointDir, selectedFile); + const data = await fs.readFile(filePath, 'utf-8'); + const toolCallData = JSON.parse(data); + + if (toolCallData.history) { + if (!loadHistory) { + // This should not happen + return { + type: 'message', + messageType: 'error', + content: 'loadHistory function is not available.', + }; + } + loadHistory(toolCallData.history); + } + + if (toolCallData.clientHistory) { + await config?.getGeminiClient()?.setHistory(toolCallData.clientHistory); + } + + if (toolCallData.commitHash) { + await gitService?.restoreProjectFromSnapshot(toolCallData.commitHash); + addItem( + { + type: 'info', + text: 'Restored project to the state before the tool call.', + }, + Date.now(), + ); + } + + return { + type: 'tool', + toolName: toolCallData.toolCall.name, + toolArgs: toolCallData.toolCall.args, + }; + } catch (error) { + return { + type: 'message', + messageType: 'error', + content: `Could not read restorable tool calls. This is the error: ${error}`, + }; + } +} + +async function completion( + context: CommandContext, + _partialArg: string, +): Promise { + const { services } = context; + const { config } = services; + const checkpointDir = config?.getProjectTempDir() + ? path.join(config.getProjectTempDir(), 'checkpoints') + : undefined; + if (!checkpointDir) { + return []; + } + try { + const files = await fs.readdir(checkpointDir); + return files + .filter((file) => file.endsWith('.json')) + .map((file) => file.replace('.json', '')); + } catch (_err) { + return []; + } +} + +export const restoreCommand = (config: Config | null): SlashCommand | null => { + if (!config?.getCheckpointingEnabled()) { + return null; + } + + return { + name: 'restore', + description: + 'Restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested', + action: restoreAction, + completion, + }; +}; diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts index d3d5ee8a..51b66fb4 100644 --- a/packages/cli/src/ui/commands/types.ts +++ b/packages/cli/src/ui/commands/types.ts @@ -41,6 +41,12 @@ export interface CommandContext { * @param item The history item to display as pending, or `null` to clear. */ setPendingItem: (item: HistoryItemWithoutId | null) => void; + /** + * Loads a new set of history items, replacing the current history. + * + * @param history The array of history items to load. + */ + loadHistory: UseHistoryManagerReturn['loadHistory']; }; // Session-specific data session: { diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index 125d051e..295d1c50 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -18,8 +18,6 @@ import { HistoryItem, SlashCommandProcessorResult, } from '../types.js'; -import { promises as fs } from 'fs'; -import path from 'path'; import { LoadedSettings } from '../../config/settings.js'; import { type CommandContext, @@ -155,6 +153,7 @@ export const useSlashCommandProcessor = ( console.clear(); refreshStatic(); }, + loadHistory, setDebugMessage: onDebugMessage, pendingItem: pendingCompressionItemRef.current, setPendingItem: setPendingCompressionItem, @@ -168,6 +167,7 @@ export const useSlashCommandProcessor = ( settings, gitService, logger, + loadHistory, addItem, clearItems, refreshStatic, @@ -203,128 +203,8 @@ export const useSlashCommandProcessor = ( }, ]; - if (config?.getCheckpointingEnabled()) { - commands.push({ - name: 'restore', - description: - 'restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested', - completion: async () => { - const checkpointDir = config?.getProjectTempDir() - ? path.join(config.getProjectTempDir(), 'checkpoints') - : undefined; - if (!checkpointDir) { - return []; - } - try { - const files = await fs.readdir(checkpointDir); - return files - .filter((file) => file.endsWith('.json')) - .map((file) => file.replace('.json', '')); - } catch (_err) { - return []; - } - }, - action: async (_mainCommand, subCommand, _args) => { - const checkpointDir = config?.getProjectTempDir() - ? path.join(config.getProjectTempDir(), 'checkpoints') - : undefined; - - if (!checkpointDir) { - addMessage({ - type: MessageType.ERROR, - content: 'Could not determine the .gemini directory path.', - timestamp: new Date(), - }); - return; - } - - try { - // Ensure the directory exists before trying to read it. - await fs.mkdir(checkpointDir, { recursive: true }); - const files = await fs.readdir(checkpointDir); - const jsonFiles = files.filter((file) => file.endsWith('.json')); - - if (!subCommand) { - if (jsonFiles.length === 0) { - addMessage({ - type: MessageType.INFO, - content: 'No restorable tool calls found.', - timestamp: new Date(), - }); - return; - } - const truncatedFiles = jsonFiles.map((file) => { - const components = file.split('.'); - if (components.length <= 1) { - return file; - } - components.pop(); - return components.join('.'); - }); - const fileList = truncatedFiles.join('\n'); - addMessage({ - type: MessageType.INFO, - content: `Available tool calls to restore:\n\n${fileList}`, - timestamp: new Date(), - }); - return; - } - - const selectedFile = subCommand.endsWith('.json') - ? subCommand - : `${subCommand}.json`; - - if (!jsonFiles.includes(selectedFile)) { - addMessage({ - type: MessageType.ERROR, - content: `File not found: ${selectedFile}`, - timestamp: new Date(), - }); - return; - } - - const filePath = path.join(checkpointDir, selectedFile); - const data = await fs.readFile(filePath, 'utf-8'); - const toolCallData = JSON.parse(data); - - if (toolCallData.history) { - loadHistory(toolCallData.history); - } - - if (toolCallData.clientHistory) { - await config - ?.getGeminiClient() - ?.setHistory(toolCallData.clientHistory); - } - - if (toolCallData.commitHash) { - await gitService?.restoreProjectFromSnapshot( - toolCallData.commitHash, - ); - addMessage({ - type: MessageType.INFO, - content: `Restored project to the state before the tool call.`, - timestamp: new Date(), - }); - } - - return { - type: 'tool', - toolName: toolCallData.toolCall.name, - toolArgs: toolCallData.toolCall.args, - }; - } catch (error) { - addMessage({ - type: MessageType.ERROR, - content: `Could not read restorable tool calls. This is the error: ${error}`, - timestamp: new Date(), - }); - } - }, - }); - } return commands; - }, [addMessage, toggleCorgiMode, config, gitService, loadHistory]); + }, [toggleCorgiMode]); const handleSlashCommand = useCallback( async (