diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index 9f18fe55..e3c77ad0 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -39,6 +39,7 @@ import { AuthInProgress } from './components/AuthInProgress.js'; import { EditorSettingsDialog } from './components/EditorSettingsDialog.js'; import { FolderTrustDialog } from './components/FolderTrustDialog.js'; import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js'; +import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js'; import { Colors } from './colors.js'; import { loadHierarchicalGeminiMemory } from '../config/config.js'; import { LoadedSettings, SettingScope } from '../config/settings.js'; @@ -488,6 +489,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { pendingHistoryItems: pendingSlashCommandHistoryItems, commandContext, shellConfirmationRequest, + confirmationRequest, } = useSlashCommandProcessor( config, settings, @@ -912,6 +914,21 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { ) : shellConfirmationRequest ? ( + ) : confirmationRequest ? ( + + {confirmationRequest.prompt} + + { + confirmationRequest.onConfirm(value); + }} + /> + + ) : isThemeDialogOpen ? ( {themeError && ( diff --git a/packages/cli/src/ui/commands/chatCommand.test.ts b/packages/cli/src/ui/commands/chatCommand.test.ts index aad0897c..ccdfd4b2 100644 --- a/packages/cli/src/ui/commands/chatCommand.test.ts +++ b/packages/cli/src/ui/commands/chatCommand.test.ts @@ -168,8 +168,12 @@ describe('chatCommand', () => { describe('save subcommand', () => { let saveCommand: SlashCommand; const tag = 'my-tag'; + let mockCheckpointExists: ReturnType; + beforeEach(() => { saveCommand = getSubCommand('save'); + mockCheckpointExists = vi.fn().mockResolvedValue(false); + mockContext.services.logger.checkpointExists = mockCheckpointExists; }); it('should return an error if tag is missing', async () => { @@ -191,7 +195,7 @@ describe('chatCommand', () => { }); }); - it('should save the conversation', async () => { + it('should save the conversation if checkpoint does not exist', async () => { const history: HistoryItemWithoutId[] = [ { type: 'user', @@ -199,8 +203,52 @@ describe('chatCommand', () => { }, ]; mockGetHistory.mockReturnValue(history); + mockCheckpointExists.mockResolvedValue(false); + const result = await saveCommand?.action?.(mockContext, tag); + expect(mockCheckpointExists).toHaveBeenCalledWith(tag); + expect(mockSaveCheckpoint).toHaveBeenCalledWith(history, tag); + expect(result).toEqual({ + type: 'message', + messageType: 'info', + content: `Conversation checkpoint saved with tag: ${tag}.`, + }); + }); + + it('should return confirm_action if checkpoint already exists', async () => { + mockCheckpointExists.mockResolvedValue(true); + mockContext.invocation = { + raw: `/chat save ${tag}`, + name: 'save', + args: tag, + }; + + const result = await saveCommand?.action?.(mockContext, tag); + + expect(mockCheckpointExists).toHaveBeenCalledWith(tag); + expect(mockSaveCheckpoint).not.toHaveBeenCalled(); + expect(result).toMatchObject({ + type: 'confirm_action', + originalInvocation: { raw: `/chat save ${tag}` }, + }); + // Check that prompt is a React element + expect(result).toHaveProperty('prompt'); + }); + + it('should save the conversation if overwrite is confirmed', async () => { + const history: HistoryItemWithoutId[] = [ + { + type: 'user', + text: 'hello', + }, + ]; + mockGetHistory.mockReturnValue(history); + mockContext.overwriteConfirmed = true; + + const result = await saveCommand?.action?.(mockContext, tag); + + expect(mockCheckpointExists).not.toHaveBeenCalled(); // Should skip existence check expect(mockSaveCheckpoint).toHaveBeenCalledWith(history, tag); expect(result).toEqual({ type: 'message', diff --git a/packages/cli/src/ui/commands/chatCommand.ts b/packages/cli/src/ui/commands/chatCommand.ts index a5fa13da..56eebe1a 100644 --- a/packages/cli/src/ui/commands/chatCommand.ts +++ b/packages/cli/src/ui/commands/chatCommand.ts @@ -5,11 +5,15 @@ */ import * as fsPromises from 'fs/promises'; +import React from 'react'; +import { Text } from 'ink'; +import { Colors } from '../colors.js'; import { CommandContext, SlashCommand, MessageActionReturn, CommandKind, + SlashCommandActionReturn, } from './types.js'; import path from 'path'; import { HistoryItemWithoutId, MessageType } from '../types.js'; @@ -96,7 +100,7 @@ const saveCommand: SlashCommand = { description: 'Save the current conversation as a checkpoint. Usage: /chat save ', kind: CommandKind.BUILT_IN, - action: async (context, args): Promise => { + action: async (context, args): Promise => { const tag = args.trim(); if (!tag) { return { @@ -108,6 +112,26 @@ const saveCommand: SlashCommand = { const { logger, config } = context.services; await logger.initialize(); + + if (!context.overwriteConfirmed) { + const exists = await logger.checkpointExists(tag); + if (exists) { + return { + type: 'confirm_action', + prompt: React.createElement( + Text, + null, + 'A checkpoint with the tag ', + React.createElement(Text, { color: Colors.AccentPurple }, tag), + ' already exists. Do you want to overwrite it?', + ), + originalInvocation: { + raw: context.invocation?.raw || `/chat save ${tag}`, + }, + }; + } + } + const chat = await config?.getGeminiClient()?.getChat(); if (!chat) { return { diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts index 09d79e9d..529f4eb8 100644 --- a/packages/cli/src/ui/commands/types.ts +++ b/packages/cli/src/ui/commands/types.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { type ReactNode } from 'react'; import { Content } from '@google/genai'; import { HistoryItemWithoutId } from '../types.js'; import { Config, GitService, Logger } from '@google/gemini-cli-core'; @@ -67,6 +68,8 @@ export interface CommandContext { /** A transient list of shell commands the user has approved for this session. */ sessionShellAllowlist: Set; }; + // Flag to indicate if an overwrite has been confirmed + overwriteConfirmed?: boolean; } /** @@ -135,6 +138,16 @@ export interface ConfirmShellCommandsActionReturn { }; } +export interface ConfirmActionReturn { + type: 'confirm_action'; + /** The React node to display as the confirmation prompt. */ + prompt: ReactNode; + /** The original invocation context to be re-run after confirmation. */ + originalInvocation: { + raw: string; + }; +} + export type SlashCommandActionReturn = | ToolActionReturn | MessageActionReturn @@ -142,7 +155,8 @@ export type SlashCommandActionReturn = | OpenDialogActionReturn | LoadHistoryActionReturn | SubmitPromptActionReturn - | ConfirmShellCommandsActionReturn; + | ConfirmShellCommandsActionReturn + | ConfirmActionReturn; export enum CommandKind { BUILT_IN = 'built-in', diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index 9f4bbf90..ca08abb1 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -64,6 +64,11 @@ export const useSlashCommandProcessor = ( approvedCommands?: string[], ) => void; }>(null); + const [confirmationRequest, setConfirmationRequest] = useState void; + }>(null); + const [sessionShellAllowlist, setSessionShellAllowlist] = useState( new Set(), ); @@ -220,6 +225,7 @@ export const useSlashCommandProcessor = ( async ( rawQuery: PartListUnion, oneTimeShellAllowlist?: Set, + overwriteConfirmed?: boolean, ): Promise => { setIsProcessing(true); try { @@ -299,6 +305,7 @@ export const useSlashCommandProcessor = ( name: commandToExecute.name, args, }, + overwriteConfirmed, }; // If a one-time list is provided for a "Proceed" action, temporarily @@ -422,6 +429,36 @@ export const useSlashCommandProcessor = ( new Set(approvedCommands), ); } + case 'confirm_action': { + const { confirmed } = await new Promise<{ + confirmed: boolean; + }>((resolve) => { + setConfirmationRequest({ + prompt: result.prompt, + onConfirm: (resolvedConfirmed) => { + setConfirmationRequest(null); + resolve({ confirmed: resolvedConfirmed }); + }, + }); + }); + + if (!confirmed) { + addItem( + { + type: MessageType.INFO, + text: 'Operation cancelled.', + }, + Date.now(), + ); + return { type: 'handled' }; + } + + return await handleSlashCommand( + result.originalInvocation.raw, + undefined, + true, + ); + } default: { const unhandled: never = result; throw new Error( @@ -478,6 +515,7 @@ export const useSlashCommandProcessor = ( setShellConfirmationRequest, setSessionShellAllowlist, setIsProcessing, + setConfirmationRequest, ], ); @@ -487,5 +525,6 @@ export const useSlashCommandProcessor = ( pendingHistoryItems, commandContext, shellConfirmationRequest, + confirmationRequest, }; }; diff --git a/packages/core/src/core/logger.test.ts b/packages/core/src/core/logger.test.ts index 3f243b52..d032e2d4 100644 --- a/packages/core/src/core/logger.test.ts +++ b/packages/core/src/core/logger.test.ts @@ -565,6 +565,52 @@ describe('Logger', () => { }); }); + describe('checkpointExists', () => { + const tag = 'exists-test'; + let taggedFilePath: string; + + beforeEach(() => { + taggedFilePath = path.join(TEST_GEMINI_DIR, `checkpoint-${tag}.json`); + }); + + it('should return true if the checkpoint file exists', async () => { + await fs.writeFile(taggedFilePath, '{}'); + const exists = await logger.checkpointExists(tag); + expect(exists).toBe(true); + }); + + it('should return false if the checkpoint file does not exist', async () => { + const exists = await logger.checkpointExists('non-existent-tag'); + expect(exists).toBe(false); + }); + + it('should throw an error if logger is not initialized', async () => { + const uninitializedLogger = new Logger(testSessionId); + uninitializedLogger.close(); + + await expect(uninitializedLogger.checkpointExists(tag)).rejects.toThrow( + 'Logger not initialized. Cannot check for checkpoint existence.', + ); + }); + + it('should re-throw an error if fs.access fails for reasons other than not existing', async () => { + vi.spyOn(fs, 'access').mockRejectedValueOnce( + new Error('EACCES: permission denied'), + ); + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + + await expect(logger.checkpointExists(tag)).rejects.toThrow( + 'EACCES: permission denied', + ); + expect(consoleErrorSpy).toHaveBeenCalledWith( + `Failed to check checkpoint existence for ${taggedFilePath}:`, + expect.any(Error), + ); + }); + }); + describe('close', () => { it('should reset logger state', async () => { await logger.logMessage(MessageSenderType.USER, 'A message'); diff --git a/packages/core/src/core/logger.ts b/packages/core/src/core/logger.ts index 9f4622e7..f4857f47 100644 --- a/packages/core/src/core/logger.ts +++ b/packages/core/src/core/logger.ts @@ -310,6 +310,29 @@ export class Logger { } } + async checkpointExists(tag: string): Promise { + if (!this.initialized) { + throw new Error( + 'Logger not initialized. Cannot check for checkpoint existence.', + ); + } + const filePath = this._checkpointPath(tag); + try { + await fs.access(filePath); + return true; + } catch (error) { + const nodeError = error as NodeJS.ErrnoException; + if (nodeError.code === 'ENOENT') { + return false; + } + console.error( + `Failed to check checkpoint existence for ${filePath}:`, + error, + ); + throw error; + } + } + close(): void { this.initialized = false; this.logFilePath = undefined;