feat(chat): Add overwrite confirmation dialog to `/chat save` (#5686)
Co-authored-by: Jacob Richman <jacob314@gmail.com>
This commit is contained in:
parent
191cc01bf5
commit
6487cc1689
|
@ -39,6 +39,7 @@ import { AuthInProgress } from './components/AuthInProgress.js';
|
||||||
import { EditorSettingsDialog } from './components/EditorSettingsDialog.js';
|
import { EditorSettingsDialog } from './components/EditorSettingsDialog.js';
|
||||||
import { FolderTrustDialog } from './components/FolderTrustDialog.js';
|
import { FolderTrustDialog } from './components/FolderTrustDialog.js';
|
||||||
import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js';
|
import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js';
|
||||||
|
import { RadioButtonSelect } from './components/shared/RadioButtonSelect.js';
|
||||||
import { Colors } from './colors.js';
|
import { Colors } from './colors.js';
|
||||||
import { loadHierarchicalGeminiMemory } from '../config/config.js';
|
import { loadHierarchicalGeminiMemory } from '../config/config.js';
|
||||||
import { LoadedSettings, SettingScope } from '../config/settings.js';
|
import { LoadedSettings, SettingScope } from '../config/settings.js';
|
||||||
|
@ -488,6 +489,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||||
pendingHistoryItems: pendingSlashCommandHistoryItems,
|
pendingHistoryItems: pendingSlashCommandHistoryItems,
|
||||||
commandContext,
|
commandContext,
|
||||||
shellConfirmationRequest,
|
shellConfirmationRequest,
|
||||||
|
confirmationRequest,
|
||||||
} = useSlashCommandProcessor(
|
} = useSlashCommandProcessor(
|
||||||
config,
|
config,
|
||||||
settings,
|
settings,
|
||||||
|
@ -912,6 +914,21 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => {
|
||||||
<FolderTrustDialog onSelect={handleFolderTrustSelect} />
|
<FolderTrustDialog onSelect={handleFolderTrustSelect} />
|
||||||
) : shellConfirmationRequest ? (
|
) : shellConfirmationRequest ? (
|
||||||
<ShellConfirmationDialog request={shellConfirmationRequest} />
|
<ShellConfirmationDialog request={shellConfirmationRequest} />
|
||||||
|
) : confirmationRequest ? (
|
||||||
|
<Box flexDirection="column">
|
||||||
|
{confirmationRequest.prompt}
|
||||||
|
<Box paddingY={1}>
|
||||||
|
<RadioButtonSelect
|
||||||
|
items={[
|
||||||
|
{ label: 'Yes', value: true },
|
||||||
|
{ label: 'No', value: false },
|
||||||
|
]}
|
||||||
|
onSelect={(value: boolean) => {
|
||||||
|
confirmationRequest.onConfirm(value);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
</Box>
|
||||||
) : isThemeDialogOpen ? (
|
) : isThemeDialogOpen ? (
|
||||||
<Box flexDirection="column">
|
<Box flexDirection="column">
|
||||||
{themeError && (
|
{themeError && (
|
||||||
|
|
|
@ -168,8 +168,12 @@ describe('chatCommand', () => {
|
||||||
describe('save subcommand', () => {
|
describe('save subcommand', () => {
|
||||||
let saveCommand: SlashCommand;
|
let saveCommand: SlashCommand;
|
||||||
const tag = 'my-tag';
|
const tag = 'my-tag';
|
||||||
|
let mockCheckpointExists: ReturnType<typeof vi.fn>;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
saveCommand = getSubCommand('save');
|
saveCommand = getSubCommand('save');
|
||||||
|
mockCheckpointExists = vi.fn().mockResolvedValue(false);
|
||||||
|
mockContext.services.logger.checkpointExists = mockCheckpointExists;
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return an error if tag is missing', async () => {
|
it('should return an error if tag is missing', async () => {
|
||||||
|
@ -191,7 +195,7 @@ describe('chatCommand', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should save the conversation', async () => {
|
it('should save the conversation if checkpoint does not exist', async () => {
|
||||||
const history: HistoryItemWithoutId[] = [
|
const history: HistoryItemWithoutId[] = [
|
||||||
{
|
{
|
||||||
type: 'user',
|
type: 'user',
|
||||||
|
@ -199,8 +203,52 @@ describe('chatCommand', () => {
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
mockGetHistory.mockReturnValue(history);
|
mockGetHistory.mockReturnValue(history);
|
||||||
|
mockCheckpointExists.mockResolvedValue(false);
|
||||||
|
|
||||||
const result = await saveCommand?.action?.(mockContext, tag);
|
const result = await saveCommand?.action?.(mockContext, tag);
|
||||||
|
|
||||||
|
expect(mockCheckpointExists).toHaveBeenCalledWith(tag);
|
||||||
|
expect(mockSaveCheckpoint).toHaveBeenCalledWith(history, tag);
|
||||||
|
expect(result).toEqual({
|
||||||
|
type: 'message',
|
||||||
|
messageType: 'info',
|
||||||
|
content: `Conversation checkpoint saved with tag: ${tag}.`,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return confirm_action if checkpoint already exists', async () => {
|
||||||
|
mockCheckpointExists.mockResolvedValue(true);
|
||||||
|
mockContext.invocation = {
|
||||||
|
raw: `/chat save ${tag}`,
|
||||||
|
name: 'save',
|
||||||
|
args: tag,
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await saveCommand?.action?.(mockContext, tag);
|
||||||
|
|
||||||
|
expect(mockCheckpointExists).toHaveBeenCalledWith(tag);
|
||||||
|
expect(mockSaveCheckpoint).not.toHaveBeenCalled();
|
||||||
|
expect(result).toMatchObject({
|
||||||
|
type: 'confirm_action',
|
||||||
|
originalInvocation: { raw: `/chat save ${tag}` },
|
||||||
|
});
|
||||||
|
// Check that prompt is a React element
|
||||||
|
expect(result).toHaveProperty('prompt');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should save the conversation if overwrite is confirmed', async () => {
|
||||||
|
const history: HistoryItemWithoutId[] = [
|
||||||
|
{
|
||||||
|
type: 'user',
|
||||||
|
text: 'hello',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
mockGetHistory.mockReturnValue(history);
|
||||||
|
mockContext.overwriteConfirmed = true;
|
||||||
|
|
||||||
|
const result = await saveCommand?.action?.(mockContext, tag);
|
||||||
|
|
||||||
|
expect(mockCheckpointExists).not.toHaveBeenCalled(); // Should skip existence check
|
||||||
expect(mockSaveCheckpoint).toHaveBeenCalledWith(history, tag);
|
expect(mockSaveCheckpoint).toHaveBeenCalledWith(history, tag);
|
||||||
expect(result).toEqual({
|
expect(result).toEqual({
|
||||||
type: 'message',
|
type: 'message',
|
||||||
|
|
|
@ -5,11 +5,15 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import * as fsPromises from 'fs/promises';
|
import * as fsPromises from 'fs/promises';
|
||||||
|
import React from 'react';
|
||||||
|
import { Text } from 'ink';
|
||||||
|
import { Colors } from '../colors.js';
|
||||||
import {
|
import {
|
||||||
CommandContext,
|
CommandContext,
|
||||||
SlashCommand,
|
SlashCommand,
|
||||||
MessageActionReturn,
|
MessageActionReturn,
|
||||||
CommandKind,
|
CommandKind,
|
||||||
|
SlashCommandActionReturn,
|
||||||
} from './types.js';
|
} from './types.js';
|
||||||
import path from 'path';
|
import path from 'path';
|
||||||
import { HistoryItemWithoutId, MessageType } from '../types.js';
|
import { HistoryItemWithoutId, MessageType } from '../types.js';
|
||||||
|
@ -96,7 +100,7 @@ const saveCommand: SlashCommand = {
|
||||||
description:
|
description:
|
||||||
'Save the current conversation as a checkpoint. Usage: /chat save <tag>',
|
'Save the current conversation as a checkpoint. Usage: /chat save <tag>',
|
||||||
kind: CommandKind.BUILT_IN,
|
kind: CommandKind.BUILT_IN,
|
||||||
action: async (context, args): Promise<MessageActionReturn> => {
|
action: async (context, args): Promise<SlashCommandActionReturn | void> => {
|
||||||
const tag = args.trim();
|
const tag = args.trim();
|
||||||
if (!tag) {
|
if (!tag) {
|
||||||
return {
|
return {
|
||||||
|
@ -108,6 +112,26 @@ const saveCommand: SlashCommand = {
|
||||||
|
|
||||||
const { logger, config } = context.services;
|
const { logger, config } = context.services;
|
||||||
await logger.initialize();
|
await logger.initialize();
|
||||||
|
|
||||||
|
if (!context.overwriteConfirmed) {
|
||||||
|
const exists = await logger.checkpointExists(tag);
|
||||||
|
if (exists) {
|
||||||
|
return {
|
||||||
|
type: 'confirm_action',
|
||||||
|
prompt: React.createElement(
|
||||||
|
Text,
|
||||||
|
null,
|
||||||
|
'A checkpoint with the tag ',
|
||||||
|
React.createElement(Text, { color: Colors.AccentPurple }, tag),
|
||||||
|
' already exists. Do you want to overwrite it?',
|
||||||
|
),
|
||||||
|
originalInvocation: {
|
||||||
|
raw: context.invocation?.raw || `/chat save ${tag}`,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const chat = await config?.getGeminiClient()?.getChat();
|
const chat = await config?.getGeminiClient()?.getChat();
|
||||||
if (!chat) {
|
if (!chat) {
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import { type ReactNode } from 'react';
|
||||||
import { Content } from '@google/genai';
|
import { Content } from '@google/genai';
|
||||||
import { HistoryItemWithoutId } from '../types.js';
|
import { HistoryItemWithoutId } from '../types.js';
|
||||||
import { Config, GitService, Logger } from '@google/gemini-cli-core';
|
import { Config, GitService, Logger } from '@google/gemini-cli-core';
|
||||||
|
@ -67,6 +68,8 @@ export interface CommandContext {
|
||||||
/** A transient list of shell commands the user has approved for this session. */
|
/** A transient list of shell commands the user has approved for this session. */
|
||||||
sessionShellAllowlist: Set<string>;
|
sessionShellAllowlist: Set<string>;
|
||||||
};
|
};
|
||||||
|
// Flag to indicate if an overwrite has been confirmed
|
||||||
|
overwriteConfirmed?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -135,6 +138,16 @@ export interface ConfirmShellCommandsActionReturn {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ConfirmActionReturn {
|
||||||
|
type: 'confirm_action';
|
||||||
|
/** The React node to display as the confirmation prompt. */
|
||||||
|
prompt: ReactNode;
|
||||||
|
/** The original invocation context to be re-run after confirmation. */
|
||||||
|
originalInvocation: {
|
||||||
|
raw: string;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
export type SlashCommandActionReturn =
|
export type SlashCommandActionReturn =
|
||||||
| ToolActionReturn
|
| ToolActionReturn
|
||||||
| MessageActionReturn
|
| MessageActionReturn
|
||||||
|
@ -142,7 +155,8 @@ export type SlashCommandActionReturn =
|
||||||
| OpenDialogActionReturn
|
| OpenDialogActionReturn
|
||||||
| LoadHistoryActionReturn
|
| LoadHistoryActionReturn
|
||||||
| SubmitPromptActionReturn
|
| SubmitPromptActionReturn
|
||||||
| ConfirmShellCommandsActionReturn;
|
| ConfirmShellCommandsActionReturn
|
||||||
|
| ConfirmActionReturn;
|
||||||
|
|
||||||
export enum CommandKind {
|
export enum CommandKind {
|
||||||
BUILT_IN = 'built-in',
|
BUILT_IN = 'built-in',
|
||||||
|
|
|
@ -64,6 +64,11 @@ export const useSlashCommandProcessor = (
|
||||||
approvedCommands?: string[],
|
approvedCommands?: string[],
|
||||||
) => void;
|
) => void;
|
||||||
}>(null);
|
}>(null);
|
||||||
|
const [confirmationRequest, setConfirmationRequest] = useState<null | {
|
||||||
|
prompt: React.ReactNode;
|
||||||
|
onConfirm: (confirmed: boolean) => void;
|
||||||
|
}>(null);
|
||||||
|
|
||||||
const [sessionShellAllowlist, setSessionShellAllowlist] = useState(
|
const [sessionShellAllowlist, setSessionShellAllowlist] = useState(
|
||||||
new Set<string>(),
|
new Set<string>(),
|
||||||
);
|
);
|
||||||
|
@ -220,6 +225,7 @@ export const useSlashCommandProcessor = (
|
||||||
async (
|
async (
|
||||||
rawQuery: PartListUnion,
|
rawQuery: PartListUnion,
|
||||||
oneTimeShellAllowlist?: Set<string>,
|
oneTimeShellAllowlist?: Set<string>,
|
||||||
|
overwriteConfirmed?: boolean,
|
||||||
): Promise<SlashCommandProcessorResult | false> => {
|
): Promise<SlashCommandProcessorResult | false> => {
|
||||||
setIsProcessing(true);
|
setIsProcessing(true);
|
||||||
try {
|
try {
|
||||||
|
@ -299,6 +305,7 @@ export const useSlashCommandProcessor = (
|
||||||
name: commandToExecute.name,
|
name: commandToExecute.name,
|
||||||
args,
|
args,
|
||||||
},
|
},
|
||||||
|
overwriteConfirmed,
|
||||||
};
|
};
|
||||||
|
|
||||||
// If a one-time list is provided for a "Proceed" action, temporarily
|
// If a one-time list is provided for a "Proceed" action, temporarily
|
||||||
|
@ -422,6 +429,36 @@ export const useSlashCommandProcessor = (
|
||||||
new Set(approvedCommands),
|
new Set(approvedCommands),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
case 'confirm_action': {
|
||||||
|
const { confirmed } = await new Promise<{
|
||||||
|
confirmed: boolean;
|
||||||
|
}>((resolve) => {
|
||||||
|
setConfirmationRequest({
|
||||||
|
prompt: result.prompt,
|
||||||
|
onConfirm: (resolvedConfirmed) => {
|
||||||
|
setConfirmationRequest(null);
|
||||||
|
resolve({ confirmed: resolvedConfirmed });
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!confirmed) {
|
||||||
|
addItem(
|
||||||
|
{
|
||||||
|
type: MessageType.INFO,
|
||||||
|
text: 'Operation cancelled.',
|
||||||
|
},
|
||||||
|
Date.now(),
|
||||||
|
);
|
||||||
|
return { type: 'handled' };
|
||||||
|
}
|
||||||
|
|
||||||
|
return await handleSlashCommand(
|
||||||
|
result.originalInvocation.raw,
|
||||||
|
undefined,
|
||||||
|
true,
|
||||||
|
);
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
const unhandled: never = result;
|
const unhandled: never = result;
|
||||||
throw new Error(
|
throw new Error(
|
||||||
|
@ -478,6 +515,7 @@ export const useSlashCommandProcessor = (
|
||||||
setShellConfirmationRequest,
|
setShellConfirmationRequest,
|
||||||
setSessionShellAllowlist,
|
setSessionShellAllowlist,
|
||||||
setIsProcessing,
|
setIsProcessing,
|
||||||
|
setConfirmationRequest,
|
||||||
],
|
],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -487,5 +525,6 @@ export const useSlashCommandProcessor = (
|
||||||
pendingHistoryItems,
|
pendingHistoryItems,
|
||||||
commandContext,
|
commandContext,
|
||||||
shellConfirmationRequest,
|
shellConfirmationRequest,
|
||||||
|
confirmationRequest,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
|
@ -565,6 +565,52 @@ describe('Logger', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('checkpointExists', () => {
|
||||||
|
const tag = 'exists-test';
|
||||||
|
let taggedFilePath: string;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
taggedFilePath = path.join(TEST_GEMINI_DIR, `checkpoint-${tag}.json`);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return true if the checkpoint file exists', async () => {
|
||||||
|
await fs.writeFile(taggedFilePath, '{}');
|
||||||
|
const exists = await logger.checkpointExists(tag);
|
||||||
|
expect(exists).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return false if the checkpoint file does not exist', async () => {
|
||||||
|
const exists = await logger.checkpointExists('non-existent-tag');
|
||||||
|
expect(exists).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error if logger is not initialized', async () => {
|
||||||
|
const uninitializedLogger = new Logger(testSessionId);
|
||||||
|
uninitializedLogger.close();
|
||||||
|
|
||||||
|
await expect(uninitializedLogger.checkpointExists(tag)).rejects.toThrow(
|
||||||
|
'Logger not initialized. Cannot check for checkpoint existence.',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should re-throw an error if fs.access fails for reasons other than not existing', async () => {
|
||||||
|
vi.spyOn(fs, 'access').mockRejectedValueOnce(
|
||||||
|
new Error('EACCES: permission denied'),
|
||||||
|
);
|
||||||
|
const consoleErrorSpy = vi
|
||||||
|
.spyOn(console, 'error')
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
|
||||||
|
await expect(logger.checkpointExists(tag)).rejects.toThrow(
|
||||||
|
'EACCES: permission denied',
|
||||||
|
);
|
||||||
|
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||||
|
`Failed to check checkpoint existence for ${taggedFilePath}:`,
|
||||||
|
expect.any(Error),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('close', () => {
|
describe('close', () => {
|
||||||
it('should reset logger state', async () => {
|
it('should reset logger state', async () => {
|
||||||
await logger.logMessage(MessageSenderType.USER, 'A message');
|
await logger.logMessage(MessageSenderType.USER, 'A message');
|
||||||
|
|
|
@ -310,6 +310,29 @@ export class Logger {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async checkpointExists(tag: string): Promise<boolean> {
|
||||||
|
if (!this.initialized) {
|
||||||
|
throw new Error(
|
||||||
|
'Logger not initialized. Cannot check for checkpoint existence.',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const filePath = this._checkpointPath(tag);
|
||||||
|
try {
|
||||||
|
await fs.access(filePath);
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
const nodeError = error as NodeJS.ErrnoException;
|
||||||
|
if (nodeError.code === 'ENOENT') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
console.error(
|
||||||
|
`Failed to check checkpoint existence for ${filePath}:`,
|
||||||
|
error,
|
||||||
|
);
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
close(): void {
|
close(): void {
|
||||||
this.initialized = false;
|
this.initialized = false;
|
||||||
this.logFilePath = undefined;
|
this.logFilePath = undefined;
|
||||||
|
|
Loading…
Reference in New Issue