migrate restore command (#4388)

This commit is contained in:
Abhi 2025-07-17 19:23:17 -04:00 committed by GitHub
parent f0dc9690b7
commit 5df6c9fb66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 424 additions and 123 deletions

View File

@ -26,6 +26,7 @@ import { mcpCommand } from '../ui/commands/mcpCommand.js';
import { editorCommand } from '../ui/commands/editorCommand.js';
import { bugCommand } from '../ui/commands/bugCommand.js';
import { quitCommand } from '../ui/commands/quitCommand.js';
import { restoreCommand } from '../ui/commands/restoreCommand.js';
// Mock the command modules to isolate the service from the command implementations.
vi.mock('../ui/commands/memoryCommand.js', () => ({
@ -79,6 +80,9 @@ vi.mock('../ui/commands/bugCommand.js', () => ({
vi.mock('../ui/commands/quitCommand.js', () => ({
quitCommand: { name: 'quit', description: 'Mock Quit' },
}));
vi.mock('../ui/commands/restoreCommand.js', () => ({
restoreCommand: vi.fn(),
}));
describe('CommandService', () => {
const subCommandLen = 17;
@ -87,8 +91,10 @@ describe('CommandService', () => {
beforeEach(() => {
mockConfig = {
getIdeMode: vi.fn(),
getCheckpointingEnabled: vi.fn(),
} as unknown as Mocked<Config>;
vi.mocked(ideCommand).mockReturnValue(null);
vi.mocked(restoreCommand).mockReturnValue(null);
});
describe('when using default production loader', () => {
@ -151,6 +157,20 @@ describe('CommandService', () => {
expect(commandNames).toContain('quit');
});
it('should include restore command when checkpointing is on', async () => {
mockConfig.getCheckpointingEnabled.mockReturnValue(true);
vi.mocked(restoreCommand).mockReturnValue({
name: 'restore',
description: 'Mock Restore',
});
await commandService.loadCommands();
const tree = commandService.getCommands();
expect(tree.length).toBe(subCommandLen + 1);
const commandNames = tree.map((cmd) => cmd.name);
expect(commandNames).toContain('restore');
});
it('should overwrite any existing commands when called again', async () => {
// Load once
await commandService.loadCommands();

View File

@ -24,6 +24,7 @@ import { compressCommand } from '../ui/commands/compressCommand.js';
import { ideCommand } from '../ui/commands/ideCommand.js';
import { bugCommand } from '../ui/commands/bugCommand.js';
import { quitCommand } from '../ui/commands/quitCommand.js';
import { restoreCommand } from '../ui/commands/restoreCommand.js';
const loadBuiltInCommands = async (
config: Config | null,
@ -44,6 +45,7 @@ const loadBuiltInCommands = async (
memoryCommand,
privacyCommand,
quitCommand,
restoreCommand(config),
statsCommand,
themeCommand,
toolsCommand,

View File

@ -46,6 +46,7 @@ export const createMockCommandContext = (
setDebugMessage: vi.fn(),
pendingItem: null,
setPendingItem: vi.fn(),
loadHistory: vi.fn(),
},
session: {
stats: {

View File

@ -0,0 +1,237 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
vi,
describe,
it,
expect,
beforeEach,
afterEach,
Mocked,
Mock,
} from 'vitest';
import * as fs from 'fs/promises';
import { restoreCommand } from './restoreCommand.js';
import { type CommandContext } from './types.js';
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
import { Config, GitService } from '@google/gemini-cli-core';
vi.mock('fs/promises', () => ({
readdir: vi.fn(),
readFile: vi.fn(),
mkdir: vi.fn(),
}));
describe('restoreCommand', () => {
let mockContext: CommandContext;
let mockConfig: Config;
let mockGitService: GitService;
const mockFsPromises = fs as Mocked<typeof fs>;
let mockSetHistory: ReturnType<typeof vi.fn>;
beforeEach(() => {
mockSetHistory = vi.fn().mockResolvedValue(undefined);
mockGitService = {
restoreProjectFromSnapshot: vi.fn().mockResolvedValue(undefined),
} as unknown as GitService;
mockConfig = {
getCheckpointingEnabled: vi.fn().mockReturnValue(true),
getProjectTempDir: vi.fn().mockReturnValue('/tmp/gemini'),
getGeminiClient: vi.fn().mockReturnValue({
setHistory: mockSetHistory,
}),
} as unknown as Config;
mockContext = createMockCommandContext({
services: {
config: mockConfig,
git: mockGitService,
},
});
});
afterEach(() => {
vi.restoreAllMocks();
});
it('should return null if checkpointing is not enabled', () => {
(mockConfig.getCheckpointingEnabled as Mock).mockReturnValue(false);
const command = restoreCommand(mockConfig);
expect(command).toBeNull();
});
it('should return the command if checkpointing is enabled', () => {
const command = restoreCommand(mockConfig);
expect(command).not.toBeNull();
expect(command?.name).toBe('restore');
expect(command?.description).toBeDefined();
expect(command?.action).toBeDefined();
expect(command?.completion).toBeDefined();
});
describe('action', () => {
it('should return an error if temp dir is not found', async () => {
(mockConfig.getProjectTempDir as Mock).mockReturnValue(undefined);
const command = restoreCommand(mockConfig);
const result = await command?.action?.(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'Could not determine the .gemini directory path.',
});
});
it('should inform when no checkpoints are found if no args are passed', async () => {
mockFsPromises.readdir.mockResolvedValue([]);
const command = restoreCommand(mockConfig);
const result = await command?.action?.(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'info',
content: 'No restorable tool calls found.',
});
expect(mockFsPromises.mkdir).toHaveBeenCalledWith(
'/tmp/gemini/checkpoints',
{
recursive: true,
},
);
});
it('should list available checkpoints if no args are passed', async () => {
mockFsPromises.readdir.mockResolvedValue([
'test1.json',
'test2.json',
// eslint-disable-next-line @typescript-eslint/no-explicit-any
] as any);
const command = restoreCommand(mockConfig);
const result = await command?.action?.(mockContext, '');
expect(result).toEqual({
type: 'message',
messageType: 'info',
content: 'Available tool calls to restore:\n\ntest1\ntest2',
});
});
it('should return an error if the specified file is not found', async () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
mockFsPromises.readdir.mockResolvedValue(['test1.json'] as any);
const command = restoreCommand(mockConfig);
const result = await command?.action?.(mockContext, 'test2');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: 'File not found: test2.json',
});
});
it('should handle file read errors gracefully', async () => {
const readError = new Error('Read failed');
// eslint-disable-next-line @typescript-eslint/no-explicit-any
mockFsPromises.readdir.mockResolvedValue(['test1.json'] as any);
mockFsPromises.readFile.mockRejectedValue(readError);
const command = restoreCommand(mockConfig);
const result = await command?.action?.(mockContext, 'test1');
expect(result).toEqual({
type: 'message',
messageType: 'error',
content: `Could not read restorable tool calls. This is the error: ${readError}`,
});
});
it('should restore a tool call and project state', async () => {
const toolCallData = {
history: [{ type: 'user', text: 'do a thing' }],
clientHistory: [{ role: 'user', parts: [{ text: 'do a thing' }] }],
commitHash: 'abcdef123',
toolCall: { name: 'run_shell_command', args: 'ls' },
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
mockFsPromises.readdir.mockResolvedValue(['my-checkpoint.json'] as any);
mockFsPromises.readFile.mockResolvedValue(JSON.stringify(toolCallData));
const command = restoreCommand(mockConfig);
const result = await command?.action?.(mockContext, 'my-checkpoint');
// Check history restoration
expect(mockContext.ui.loadHistory).toHaveBeenCalledWith(
toolCallData.history,
);
expect(mockSetHistory).toHaveBeenCalledWith(toolCallData.clientHistory);
// Check git restoration
expect(mockGitService.restoreProjectFromSnapshot).toHaveBeenCalledWith(
toolCallData.commitHash,
);
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
{
type: 'info',
text: 'Restored project to the state before the tool call.',
},
expect.any(Number),
);
// Check returned action
expect(result).toEqual({
type: 'tool',
toolName: 'run_shell_command',
toolArgs: 'ls',
});
});
it('should restore even if only toolCall is present', async () => {
const toolCallData = {
toolCall: { name: 'run_shell_command', args: 'ls' },
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
mockFsPromises.readdir.mockResolvedValue(['my-checkpoint.json'] as any);
mockFsPromises.readFile.mockResolvedValue(JSON.stringify(toolCallData));
const command = restoreCommand(mockConfig);
const result = await command?.action?.(mockContext, 'my-checkpoint');
expect(mockContext.ui.loadHistory).not.toHaveBeenCalled();
expect(mockSetHistory).not.toHaveBeenCalled();
expect(mockGitService.restoreProjectFromSnapshot).not.toHaveBeenCalled();
expect(result).toEqual({
type: 'tool',
toolName: 'run_shell_command',
toolArgs: 'ls',
});
});
});
describe('completion', () => {
it('should return an empty array if temp dir is not found', async () => {
(mockConfig.getProjectTempDir as Mock).mockReturnValue(undefined);
const command = restoreCommand(mockConfig);
const result = await command?.completion?.(mockContext, '');
expect(result).toEqual([]);
});
it('should return an empty array on readdir error', async () => {
mockFsPromises.readdir.mockRejectedValue(new Error('ENOENT'));
const command = restoreCommand(mockConfig);
const result = await command?.completion?.(mockContext, '');
expect(result).toEqual([]);
});
it('should return a list of checkpoint names', async () => {
mockFsPromises.readdir.mockResolvedValue([
'test1.json',
'test2.json',
'not-a-checkpoint.txt',
// eslint-disable-next-line @typescript-eslint/no-explicit-any
] as any);
const command = restoreCommand(mockConfig);
const result = await command?.completion?.(mockContext, '');
expect(result).toEqual(['test1', 'test2']);
});
});
});

View File

@ -0,0 +1,155 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'fs/promises';
import path from 'path';
import {
type CommandContext,
type SlashCommand,
type SlashCommandActionReturn,
} from './types.js';
import { Config } from '@google/gemini-cli-core';
async function restoreAction(
context: CommandContext,
args: string,
): Promise<void | SlashCommandActionReturn> {
const { services, ui } = context;
const { config, git: gitService } = services;
const { addItem, loadHistory } = ui;
const checkpointDir = config?.getProjectTempDir()
? path.join(config.getProjectTempDir(), 'checkpoints')
: undefined;
if (!checkpointDir) {
return {
type: 'message',
messageType: 'error',
content: 'Could not determine the .gemini directory path.',
};
}
try {
// Ensure the directory exists before trying to read it.
await fs.mkdir(checkpointDir, { recursive: true });
const files = await fs.readdir(checkpointDir);
const jsonFiles = files.filter((file) => file.endsWith('.json'));
if (!args) {
if (jsonFiles.length === 0) {
return {
type: 'message',
messageType: 'info',
content: 'No restorable tool calls found.',
};
}
const truncatedFiles = jsonFiles.map((file) => {
const components = file.split('.');
if (components.length <= 1) {
return file;
}
components.pop();
return components.join('.');
});
const fileList = truncatedFiles.join('\n');
return {
type: 'message',
messageType: 'info',
content: `Available tool calls to restore:\n\n${fileList}`,
};
}
const selectedFile = args.endsWith('.json') ? args : `${args}.json`;
if (!jsonFiles.includes(selectedFile)) {
return {
type: 'message',
messageType: 'error',
content: `File not found: ${selectedFile}`,
};
}
const filePath = path.join(checkpointDir, selectedFile);
const data = await fs.readFile(filePath, 'utf-8');
const toolCallData = JSON.parse(data);
if (toolCallData.history) {
if (!loadHistory) {
// This should not happen
return {
type: 'message',
messageType: 'error',
content: 'loadHistory function is not available.',
};
}
loadHistory(toolCallData.history);
}
if (toolCallData.clientHistory) {
await config?.getGeminiClient()?.setHistory(toolCallData.clientHistory);
}
if (toolCallData.commitHash) {
await gitService?.restoreProjectFromSnapshot(toolCallData.commitHash);
addItem(
{
type: 'info',
text: 'Restored project to the state before the tool call.',
},
Date.now(),
);
}
return {
type: 'tool',
toolName: toolCallData.toolCall.name,
toolArgs: toolCallData.toolCall.args,
};
} catch (error) {
return {
type: 'message',
messageType: 'error',
content: `Could not read restorable tool calls. This is the error: ${error}`,
};
}
}
async function completion(
context: CommandContext,
_partialArg: string,
): Promise<string[]> {
const { services } = context;
const { config } = services;
const checkpointDir = config?.getProjectTempDir()
? path.join(config.getProjectTempDir(), 'checkpoints')
: undefined;
if (!checkpointDir) {
return [];
}
try {
const files = await fs.readdir(checkpointDir);
return files
.filter((file) => file.endsWith('.json'))
.map((file) => file.replace('.json', ''));
} catch (_err) {
return [];
}
}
export const restoreCommand = (config: Config | null): SlashCommand | null => {
if (!config?.getCheckpointingEnabled()) {
return null;
}
return {
name: 'restore',
description:
'Restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested',
action: restoreAction,
completion,
};
};

View File

@ -41,6 +41,12 @@ export interface CommandContext {
* @param item The history item to display as pending, or `null` to clear.
*/
setPendingItem: (item: HistoryItemWithoutId | null) => void;
/**
* Loads a new set of history items, replacing the current history.
*
* @param history The array of history items to load.
*/
loadHistory: UseHistoryManagerReturn['loadHistory'];
};
// Session-specific data
session: {

View File

@ -18,8 +18,6 @@ import {
HistoryItem,
SlashCommandProcessorResult,
} from '../types.js';
import { promises as fs } from 'fs';
import path from 'path';
import { LoadedSettings } from '../../config/settings.js';
import {
type CommandContext,
@ -155,6 +153,7 @@ export const useSlashCommandProcessor = (
console.clear();
refreshStatic();
},
loadHistory,
setDebugMessage: onDebugMessage,
pendingItem: pendingCompressionItemRef.current,
setPendingItem: setPendingCompressionItem,
@ -168,6 +167,7 @@ export const useSlashCommandProcessor = (
settings,
gitService,
logger,
loadHistory,
addItem,
clearItems,
refreshStatic,
@ -203,128 +203,8 @@ export const useSlashCommandProcessor = (
},
];
if (config?.getCheckpointingEnabled()) {
commands.push({
name: 'restore',
description:
'restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested',
completion: async () => {
const checkpointDir = config?.getProjectTempDir()
? path.join(config.getProjectTempDir(), 'checkpoints')
: undefined;
if (!checkpointDir) {
return [];
}
try {
const files = await fs.readdir(checkpointDir);
return files
.filter((file) => file.endsWith('.json'))
.map((file) => file.replace('.json', ''));
} catch (_err) {
return [];
}
},
action: async (_mainCommand, subCommand, _args) => {
const checkpointDir = config?.getProjectTempDir()
? path.join(config.getProjectTempDir(), 'checkpoints')
: undefined;
if (!checkpointDir) {
addMessage({
type: MessageType.ERROR,
content: 'Could not determine the .gemini directory path.',
timestamp: new Date(),
});
return;
}
try {
// Ensure the directory exists before trying to read it.
await fs.mkdir(checkpointDir, { recursive: true });
const files = await fs.readdir(checkpointDir);
const jsonFiles = files.filter((file) => file.endsWith('.json'));
if (!subCommand) {
if (jsonFiles.length === 0) {
addMessage({
type: MessageType.INFO,
content: 'No restorable tool calls found.',
timestamp: new Date(),
});
return;
}
const truncatedFiles = jsonFiles.map((file) => {
const components = file.split('.');
if (components.length <= 1) {
return file;
}
components.pop();
return components.join('.');
});
const fileList = truncatedFiles.join('\n');
addMessage({
type: MessageType.INFO,
content: `Available tool calls to restore:\n\n${fileList}`,
timestamp: new Date(),
});
return;
}
const selectedFile = subCommand.endsWith('.json')
? subCommand
: `${subCommand}.json`;
if (!jsonFiles.includes(selectedFile)) {
addMessage({
type: MessageType.ERROR,
content: `File not found: ${selectedFile}`,
timestamp: new Date(),
});
return;
}
const filePath = path.join(checkpointDir, selectedFile);
const data = await fs.readFile(filePath, 'utf-8');
const toolCallData = JSON.parse(data);
if (toolCallData.history) {
loadHistory(toolCallData.history);
}
if (toolCallData.clientHistory) {
await config
?.getGeminiClient()
?.setHistory(toolCallData.clientHistory);
}
if (toolCallData.commitHash) {
await gitService?.restoreProjectFromSnapshot(
toolCallData.commitHash,
);
addMessage({
type: MessageType.INFO,
content: `Restored project to the state before the tool call.`,
timestamp: new Date(),
});
}
return {
type: 'tool',
toolName: toolCallData.toolCall.name,
toolArgs: toolCallData.toolCall.args,
};
} catch (error) {
addMessage({
type: MessageType.ERROR,
content: `Could not read restorable tool calls. This is the error: ${error}`,
timestamp: new Date(),
});
}
},
});
}
return commands;
}, [addMessage, toggleCorgiMode, config, gitService, loadHistory]);
}, [toggleCorgiMode]);
const handleSlashCommand = useCallback(
async (