refactor: Improve editCorrector logic and type safety

- Refactor `ensureCorrectEdit` to clarify the correction flow for `old_string` and `new_string`.
- Only correct `new_string` if it was potentially escaped; otherwise, use the original.
- Introduce `CorrectedEditParams` and `CorrectedEditResult` interfaces for better type definition.
- Relocate `countOccurrences` for better logical grouping.

Part of https://github.com/google-gemini/gemini-cli/issues/484
This commit is contained in:
Taylor Mullen 2025-05-25 13:15:12 -07:00 committed by N. Taylor Mullen
parent 068b505d5e
commit fa4a04157f
1 changed files with 125 additions and 86 deletions

View File

@ -21,87 +21,128 @@ const EditConfig: GenerateContentConfig = {
};
/**
* Counts occurrences of a substring in a string
* Defines the structure of the parameters within CorrectedEditResult
*/
export function countOccurrences(str: string, substr: string): number {
if (substr === '') {
return 0;
}
let count = 0;
let pos = str.indexOf(substr);
while (pos !== -1) {
count++;
pos = str.indexOf(substr, pos + substr.length); // Start search after the current match
}
return count;
interface CorrectedEditParams {
// Path of the file being edited; ensureCorrectEdit copies it from the original params.
file_path: string;
// Text to search for in the file (the original, or a corrected version of it).
old_string: string;
// Replacement text (the original, or a corrected version of it).
new_string: string;
}
/**
* Result returned by ensureCorrectEdit.
*
* `params` holds the edit parameters the caller should use (either the
* original values or corrected ones). `occurrences` is the number of times
* `params.old_string` was counted in the current file content; on the failure
* paths it is 0 (no match could be produced) or greater than 1 (ambiguous match).
*/
export interface CorrectedEditResult {
params: CorrectedEditParams;
occurrences: number;
}
/**
* Attempts to correct edit parameters if the original old_string is not found.
* It tries unescaping, and then LLM-based correction.
* Results are cached to avoid redundant processing.
*
* @param currentContent The current content of the file.
* @param params The original EditToolParams.
* @param originalParams The original EditToolParams
* @param client The GeminiClient for LLM calls.
* @returns A promise resolving to an object containing the (potentially corrected) EditToolParams and the final occurrences count.
* @returns A promise resolving to an object containing the (potentially corrected)
* EditToolParams (as CorrectedEditParams) and the final occurrences count.
*/
export async function ensureCorrectEdit(
currentContent: string,
originalParams: EditToolParams,
originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
client: GeminiClient,
): Promise<CorrectedEditResult> {
let occurrences = countOccurrences(currentContent, originalParams.old_string);
const currentParams = { ...originalParams };
let finalNewString = originalParams.new_string;
const newStringPotentiallyEscaped =
unescapeStringForGeminiBug(originalParams.new_string) !==
originalParams.new_string;
let finalOldString = originalParams.old_string;
let occurrences = countOccurrences(currentContent, finalOldString);
if (occurrences === 1) {
return { params: currentParams, occurrences };
}
const unescapedOldString = unescapeStringForGeminiBug(
currentParams.old_string,
);
occurrences = countOccurrences(currentContent, unescapedOldString);
if (occurrences === 1) {
currentParams.old_string = unescapedOldString;
currentParams.new_string = unescapeStringForGeminiBug(
currentParams.new_string,
return { params: originalParams, occurrences };
} else {
// occurrences is 0 or some other unexpected state initially
const unescapedOldStringAttempt = unescapeStringForGeminiBug(
originalParams.old_string,
);
} else if (occurrences === 0) {
const llmCorrectedOldString = await correctOldStringMismatch(
client,
currentContent,
unescapedOldString,
);
occurrences = countOccurrences(currentContent, llmCorrectedOldString);
occurrences = countOccurrences(currentContent, unescapedOldStringAttempt);
if (occurrences === 1) {
const llmCorrectedNewString = await correctNewString(
finalOldString = unescapedOldStringAttempt;
finalNewString = unescapeStringForGeminiBug(originalParams.new_string);
} else if (occurrences === 0) {
const llmCorrectedOldString = await correctOldStringMismatch(
client,
unescapedOldString,
llmCorrectedOldString,
currentParams.new_string,
currentContent,
unescapedOldStringAttempt,
);
currentParams.old_string = llmCorrectedOldString;
currentParams.new_string = llmCorrectedNewString;
const llmOldOccurrences = countOccurrences(
currentContent,
llmCorrectedOldString,
);
if (llmOldOccurrences === 1) {
finalOldString = llmCorrectedOldString;
occurrences = llmOldOccurrences;
if (newStringPotentiallyEscaped) {
const baseNewStringForLLMCorrection = unescapeStringForGeminiBug(
originalParams.new_string,
);
finalNewString = await correctNewString(
client,
originalParams.old_string, // original old
llmCorrectedOldString, // corrected old
baseNewStringForLLMCorrection, // base new for correction
);
}
} else {
// LLM correction also failed for old_string
const result: CorrectedEditResult = {
params: { ...originalParams },
occurrences: 0, // Explicitly 0 as LLM failed
};
return result;
}
} else {
// If LLM correction also results in 0 or >1 occurrences,
// return the original params and 0 occurrences,
// letting the caller handle the "still not found" case.
return { params: originalParams, occurrences: 0 };
// Unescaping old_string resulted in > 1 occurrences
const result: CorrectedEditResult = {
params: { ...originalParams },
occurrences, // This will be > 1
};
return result;
}
} else {
// If unescaping resulted in >1 occurrences, return original params and that count.
return { params: originalParams, occurrences };
}
return { params: currentParams, occurrences };
// Final result construction
const result: CorrectedEditResult = {
params: {
file_path: originalParams.file_path,
old_string: finalOldString,
new_string: finalNewString,
},
occurrences: countOccurrences(currentContent, finalOldString), // Recalculate occurrences with the final old_string
};
return result;
}
/**
* Attempts to correct potential formatting/escaping issues in a snippet using an LLM call.
*/
async function correctOldStringMismatch(
// JSON schema for the LLM response to an old_string correction request:
// an object with a single required string field, `corrected_target_snippet`.
const OLD_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_target_snippet: {
type: Type.STRING,
description:
'The corrected version of the target snippet that exactly and uniquely matches a segment within the provided file content.',
},
},
required: ['corrected_target_snippet'],
};
export async function correctOldStringMismatch(
geminiClient: GeminiClient,
fileContent: string,
problematicSnippet: string,
@ -155,10 +196,23 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
}
}
// JSON schema for the LLM response to a new_string correction request:
// an object with a single required string field, `corrected_new_string`.
const NEW_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_new_string: {
type: Type.STRING,
description:
'The original_new_string adjusted to be a suitable replacement for the corrected_old_string, while maintaining the original intent of the change.',
},
},
required: ['corrected_new_string'],
};
/**
* Adjusts the new_string to align with a corrected old_string, maintaining the original intent.
*/
async function correctNewString(
export async function correctNewString(
geminiClient: GeminiClient,
originalOldString: string,
correctedOldString: string,
@ -220,37 +274,6 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
}
}
export interface CorrectedEditResult {
params: EditToolParams;
occurrences: number;
}
// Define the expected JSON schema for the LLM response for old_string correction
const OLD_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_target_snippet: {
type: Type.STRING,
description:
'The corrected version of the target snippet that exactly and uniquely matches a segment within the provided file content.',
},
},
required: ['corrected_target_snippet'],
};
// Define the expected JSON schema for the new_string correction LLM response
const NEW_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_new_string: {
type: Type.STRING,
description:
'The original_new_string adjusted to be a suitable replacement for the corrected_old_string, while maintaining the original intent of the change.',
},
},
required: ['corrected_new_string'],
};
/**
* Unescapes a string that might have been overly escaped by an LLM.
*/
@ -290,3 +313,19 @@ export function unescapeStringForGeminiBug(inputString: string): string {
}
});
}
/**
 * Returns how many non-overlapping times `substr` appears in `str`,
 * scanning left to right. An empty `substr` is defined to occur 0 times.
 */
export function countOccurrences(str: string, substr: string): number {
  const step = substr.length;
  if (step === 0) {
    return 0;
  }
  let total = 0;
  // After each hit, resume the search just past the matched text so that
  // overlapping matches are not counted twice.
  for (
    let at = str.indexOf(substr);
    at !== -1;
    at = str.indexOf(substr, at + step)
  ) {
    total += 1;
  }
  return total;
}