Fix(write-file): Correct over-escaping and improve content generation
- Leveraged existing edit correction technology from `edit.ts` to address over-escaping issues in `write-file.ts`.
- Introduced `ensureCorrectFileContent` for correcting content in new files, where a simple "replace" isn't applicable. This uses a new LLM prompt tailored for correcting potentially problematic string escaping.
- Added caching for `ensureCorrectFileContent` to optimize performance.
- Refactored `write-file.ts` to integrate these corrections, improving the reliability of file content generation and modification.

Part of https://github.com/google-gemini/gemini-cli/issues/484
This commit is contained in:
parent
1a5fe16b22
commit
5097b5a656
|
@ -18,7 +18,12 @@ import {
|
||||||
} from './tools.js';
|
} from './tools.js';
|
||||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||||
import { isNodeError } from '../utils/errors.js';
|
import { getErrorMessage, isNodeError } from '../utils/errors.js';
|
||||||
|
import {
|
||||||
|
ensureCorrectEdit,
|
||||||
|
ensureCorrectFileContent,
|
||||||
|
} from '../utils/editCorrector.js';
|
||||||
|
import { GeminiClient } from '../core/client.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parameters for the WriteFile tool
|
* Parameters for the WriteFile tool
|
||||||
|
@ -35,11 +40,19 @@ export interface WriteFileToolParams {
|
||||||
content: string;
|
content: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Outcome of reading the target file and running the proposed content through
 * the escaping-correction step.
 */
interface GetCorrectedFileContentResult {
  /** Text currently on disk; empty when the file is missing or could not be read. */
  originalContent: string;
  /** Content to write; equals the proposed content unchanged when the read failed. */
  correctedContent: string;
  /** False only when the read failed with ENOENT; true for a readable or unreadable file. */
  fileExists: boolean;
  /** Present only when the file could not be read for a reason other than ENOENT. */
  error?: {
    message: string;
    code?: string;
  };
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Implementation of the WriteFile tool logic
|
* Implementation of the WriteFile tool logic
|
||||||
*/
|
*/
|
||||||
export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
static readonly Name: string = 'write_file';
|
static readonly Name: string = 'write_file';
|
||||||
|
private readonly client: GeminiClient;
|
||||||
|
|
||||||
constructor(private readonly config: Config) {
|
constructor(private readonly config: Config) {
|
||||||
super(
|
super(
|
||||||
|
@ -62,6 +75,8 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
type: 'object',
|
type: 'object',
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
this.client = new GeminiClient(this.config);
|
||||||
}
|
}
|
||||||
|
|
||||||
private isWithinRoot(pathToCheck: string): boolean {
|
private isWithinRoot(pathToCheck: string): boolean {
|
||||||
|
@ -135,23 +150,27 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const correctedContentResult = await this._getCorrectedFileContent(
|
||||||
|
params.file_path,
|
||||||
|
params.content,
|
||||||
|
);
|
||||||
|
|
||||||
|
if (correctedContentResult.error) {
|
||||||
|
// If file exists but couldn't be read, we can't show a diff for confirmation.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { originalContent, correctedContent } = correctedContentResult;
|
||||||
const relativePath = makeRelative(
|
const relativePath = makeRelative(
|
||||||
params.file_path,
|
params.file_path,
|
||||||
this.config.getTargetDir(),
|
this.config.getTargetDir(),
|
||||||
);
|
);
|
||||||
const fileName = path.basename(params.file_path);
|
const fileName = path.basename(params.file_path);
|
||||||
|
|
||||||
let currentContent = '';
|
|
||||||
try {
|
|
||||||
currentContent = fs.readFileSync(params.file_path, 'utf8');
|
|
||||||
} catch {
|
|
||||||
// File might not exist, that's okay for write/create
|
|
||||||
}
|
|
||||||
|
|
||||||
const fileDiff = Diff.createPatch(
|
const fileDiff = Diff.createPatch(
|
||||||
fileName,
|
fileName,
|
||||||
currentContent,
|
originalContent, // Original content (empty if new file or unreadable)
|
||||||
params.content,
|
correctedContent, // Content after potential correction
|
||||||
'Current',
|
'Current',
|
||||||
'Proposed',
|
'Proposed',
|
||||||
{ context: 3 },
|
{ context: 3 },
|
||||||
|
@ -183,37 +202,53 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
let currentContent = '';
|
const correctedContentResult = await this._getCorrectedFileContent(
|
||||||
let isNewFile = false;
|
params.file_path,
|
||||||
try {
|
params.content,
|
||||||
currentContent = fs.readFileSync(params.file_path, 'utf8');
|
);
|
||||||
} catch (err: unknown) {
|
|
||||||
if (isNodeError(err) && err.code === 'ENOENT') {
|
if (correctedContentResult.error) {
|
||||||
isNewFile = true;
|
const errDetails = correctedContentResult.error;
|
||||||
} else {
|
const errorMsg = `Error checking existing file: ${errDetails.message}`;
|
||||||
// Rethrow other read errors (permissions etc.)
|
return {
|
||||||
const errorMsg = `Error checking existing file: ${err instanceof Error ? err.message : String(err)}`;
|
llmContent: `Error checking existing file ${params.file_path}: ${errDetails.message}`,
|
||||||
return {
|
returnDisplay: errorMsg,
|
||||||
llmContent: `Error checking existing file ${params.file_path}: ${errorMsg}`,
|
};
|
||||||
returnDisplay: `Error: ${errorMsg}`,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const {
|
||||||
|
originalContent,
|
||||||
|
correctedContent: fileContent,
|
||||||
|
fileExists,
|
||||||
|
} = correctedContentResult;
|
||||||
|
// fileExists is true if the file existed (and was readable or unreadable but caught by readError).
|
||||||
|
// fileExists is false if the file did not exist (ENOENT).
|
||||||
|
const isNewFile =
|
||||||
|
!fileExists ||
|
||||||
|
(correctedContentResult.error !== undefined &&
|
||||||
|
!correctedContentResult.fileExists);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const dirName = path.dirname(params.file_path);
|
const dirName = path.dirname(params.file_path);
|
||||||
if (!fs.existsSync(dirName)) {
|
if (!fs.existsSync(dirName)) {
|
||||||
fs.mkdirSync(dirName, { recursive: true });
|
fs.mkdirSync(dirName, { recursive: true });
|
||||||
}
|
}
|
||||||
|
|
||||||
fs.writeFileSync(params.file_path, params.content, 'utf8');
|
fs.writeFileSync(params.file_path, fileContent, 'utf8');
|
||||||
|
|
||||||
// Generate diff for display result
|
// Generate diff for display result
|
||||||
const fileName = path.basename(params.file_path);
|
const fileName = path.basename(params.file_path);
|
||||||
|
// If there was a readError, originalContent in correctedContentResult is '',
|
||||||
|
// but for the diff, we want to show the original content as it was before the write if possible.
|
||||||
|
// However, if it was unreadable, currentContentForDiff will be empty.
|
||||||
|
const currentContentForDiff = correctedContentResult.error
|
||||||
|
? '' // Or some indicator of unreadable content
|
||||||
|
: originalContent;
|
||||||
|
|
||||||
const fileDiff = Diff.createPatch(
|
const fileDiff = Diff.createPatch(
|
||||||
fileName,
|
fileName,
|
||||||
currentContent, // Empty if it was a new file
|
currentContentForDiff,
|
||||||
params.content,
|
fileContent,
|
||||||
'Original',
|
'Original',
|
||||||
'Written',
|
'Written',
|
||||||
{ context: 3 },
|
{ context: 3 },
|
||||||
|
@ -237,4 +272,58 @@ export class WriteFileTool extends BaseTool<WriteFileToolParams, ToolResult> {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Reads the file at `filePath` (if any) and returns the proposed content
 * after the escaping-correction pass.
 *
 * - Readable file: the whole current text is treated as the `old_string` of
 *   an edit and `ensureCorrectEdit` supplies the corrected `new_string`.
 * - Missing file (ENOENT): `ensureCorrectFileContent` corrects the proposed
 *   text on its own.
 * - Any other read failure: no correction is attempted; the proposed content
 *   is returned untouched together with an `error` describing the failure.
 *
 * @param filePath Path of the file that is about to be written.
 * @param proposedContent Content supplied for the write.
 * @returns Original text, corrected text, existence flag and optional read error.
 */
private async _getCorrectedFileContent(
  filePath: string,
  proposedContent: string,
): Promise<GetCorrectedFileContentResult> {
  // Stays undefined only when the file does not exist (ENOENT).
  let existingText: string | undefined;

  try {
    existingText = fs.readFileSync(filePath, 'utf8');
  } catch (readFailure) {
    const isMissing = isNodeError(readFailure) && readFailure.code === 'ENOENT';
    if (!isMissing) {
      // The file is there but unreadable (permissions, etc.): report the
      // failure and hand back the proposed content uncorrected.
      return {
        originalContent: '',
        correctedContent: proposedContent,
        fileExists: true,
        error: {
          message: getErrorMessage(readFailure),
          code: isNodeError(readFailure) ? readFailure.code : undefined,
        },
      };
    }
  }

  if (existingText === undefined) {
    // New file: nothing to diff against, so correct the content in isolation.
    const freshContent = await ensureCorrectFileContent(
      proposedContent,
      this.client,
    );
    return {
      originalContent: '',
      correctedContent: freshContent,
      fileExists: false,
    };
  }

  // Existing file: model the write as replacing the entire current content.
  const wholeFileEdit = {
    old_string: existingText,
    new_string: proposedContent,
    file_path: filePath,
  };
  const { params: fixedEdit } = await ensureCorrectEdit(
    existingText,
    wholeFileEdit,
    this.client,
  );
  return {
    originalContent: existingText,
    correctedContent: fixedEdit.new_string,
    fileExists: true,
  };
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,9 @@ const editCorrectionCache = new LruCache<string, CorrectedEditResult>(
|
||||||
MAX_CACHE_SIZE,
|
MAX_CACHE_SIZE,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Cache for ensureCorrectFileContent results
|
||||||
|
const fileContentCorrectionCache = new LruCache<string, string>(MAX_CACHE_SIZE);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Defines the structure of the parameters within CorrectedEditResult
|
* Defines the structure of the parameters within CorrectedEditResult
|
||||||
*/
|
*/
|
||||||
|
@ -174,6 +177,27 @@ export async function ensureCorrectEdit(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Returns `content` with over-escaping corrected, consulting
 * `fileContentCorrectionCache` before doing any work.
 *
 * The model is only asked (via `correctStringEscaping`) when
 * `unescapeStringForGeminiBug` would actually change the text; otherwise the
 * content is returned as-is. Either way the outcome is cached under the
 * original content.
 *
 * @param content Full text proposed for a file.
 * @param client Client forwarded to `correctStringEscaping`.
 * @returns The corrected (or unchanged) content.
 */
export async function ensureCorrectFileContent(
  content: string,
  client: GeminiClient,
): Promise<string> {
  const memoized = fileContentCorrectionCache.get(content);
  if (memoized) {
    return memoized;
  }

  // If un-escaping is a no-op there is nothing suspicious to correct.
  const looksUntouched = unescapeStringForGeminiBug(content) === content;
  const resolved = looksUntouched
    ? content
    : await correctStringEscaping(content, client);

  fileContentCorrectionCache.set(content, resolved);
  return resolved;
}
|
||||||
|
|
||||||
// Define the expected JSON schema for the LLM response for old_string correction
|
// Define the expected JSON schema for the LLM response for old_string correction
|
||||||
const OLD_STRING_CORRECTION_SCHEMA: SchemaUnion = {
|
const OLD_STRING_CORRECTION_SCHEMA: SchemaUnion = {
|
||||||
type: Type.OBJECT,
|
type: Type.OBJECT,
|
||||||
|
@ -385,6 +409,66 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * JSON schema for the model's reply to the escaping-correction prompt: an
 * object with a single required string property, `corrected_string_escaping`.
 */
const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = {
  type: Type.OBJECT,
  properties: {
    corrected_string_escaping: {
      type: Type.STRING,
      description:
        'The string with corrected escaping, ensuring it is valid, specially considering potential over-escaping issues from previous LLM generations.',
    },
  },
  required: ['corrected_string_escaping'],
};

/**
 * Asks the model to repair over-escaping (e.g. `\\n` where a newline was
 * meant, or `\\"` where a plain quote was meant) in a generated string.
 *
 * The reply is read from the `corrected_string_escaping` key, the same key
 * declared in `CORRECT_STRING_ESCAPING_SCHEMA` and named in the prompt.
 *
 * @param potentiallyProblematicString Text that may contain bad escaping.
 * @param client Client used to issue the `generateJson` request.
 * @returns The corrected string, or the input unchanged when the reply has no
 *   non-empty string under the expected key or when the request throws.
 */
export async function correctStringEscaping(
  potentiallyProblematicString: string,
  client: GeminiClient,
): Promise<string> {
  const prompt = `
Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").

potentially_problematic_string (this text MIGHT have bad escaping, or might be entirely correct):
\`\`\`
${potentiallyProblematicString}
\`\`\`

Task: Analyze the potentially_problematic_string. If it's syntactically invalid due to incorrect escaping (e.g., "\n", "\t", "\\", "\\'", "\\""), correct the invalid syntax. The goal is to ensure the text will be a valid and correctly interpreted.

For example, if potentially_problematic_string is "bar\\nbaz", the corrected_string_escaping should be "bar\nbaz".
If potentially_problematic_string is console.log(\\"Hello World\\"), it should be console.log("Hello World").

Return ONLY the corrected string in the specified JSON format with the key 'corrected_string_escaping'. If no escaping correction is needed, return the original potentially_problematic_string.
`.trim();

  const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];

  try {
    const result = await client.generateJson(
      contents,
      CORRECT_STRING_ESCAPING_SCHEMA,
      EditModel,
      EditConfig,
    );

    // Must match the property name required by CORRECT_STRING_ESCAPING_SCHEMA;
    // reading any other key silently discards the model's correction.
    if (
      result &&
      typeof result.corrected_string_escaping === 'string' &&
      result.corrected_string_escaping.length > 0
    ) {
      return result.corrected_string_escaping;
    }

    // Missing or empty answer: keep the caller's text rather than writing ''.
    return potentiallyProblematicString;
  } catch (error) {
    // Best-effort correction: log and fall back to the uncorrected input.
    console.error(
      'Error during LLM call for string escaping correction:',
      error,
    );
    return potentiallyProblematicString;
  }
}
|
||||||
|
|
||||||
function trimPairIfPossible(
|
function trimPairIfPossible(
|
||||||
target: string,
|
target: string,
|
||||||
trimIfTargetTrims: string,
|
trimIfTargetTrims: string,
|
||||||
|
@ -470,4 +554,5 @@ export function countOccurrences(str: string, substr: string): number {
|
||||||
|
|
||||||
export function resetEditCorrectorCaches_TEST_ONLY() {
|
export function resetEditCorrectorCaches_TEST_ONLY() {
|
||||||
editCorrectionCache.clear();
|
editCorrectionCache.clear();
|
||||||
|
fileContentCorrectionCache.clear();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue