refactor: Improve editCorrector logic and type safety

- Refactor `ensureCorrectEdit` to clarify the correction flow for `old_string` and `new_string`.
- Only correct `new_string` if it was potentially escaped; otherwise, use the original.
- Introduce `CorrectedEditParams` and `CorrectedEditResult` interfaces for better type definition.
- Relocate `countOccurrences` for better logical grouping.

Part of https://github.com/google-gemini/gemini-cli/issues/484
This commit is contained in:
Taylor Mullen 2025-05-25 13:15:12 -07:00 committed by N. Taylor Mullen
parent 068b505d5e
commit fa4a04157f
1 changed files with 125 additions and 86 deletions

View File

@ -21,87 +21,128 @@ const EditConfig: GenerateContentConfig = {
}; };
/** /**
* Counts occurrences of a substring in a string * Defines the structure of the parameters within CorrectedEditResult
*/ */
export function countOccurrences(str: string, substr: string): number { interface CorrectedEditParams {
if (substr === '') { file_path: string;
return 0; old_string: string;
new_string: string;
} }
let count = 0;
let pos = str.indexOf(substr); /**
while (pos !== -1) { * Defines the result structure for ensureCorrectEdit.
count++; */
pos = str.indexOf(substr, pos + substr.length); // Start search after the current match export interface CorrectedEditResult {
} params: CorrectedEditParams;
return count; occurrences: number;
} }
/** /**
* Attempts to correct edit parameters if the original old_string is not found. * Attempts to correct edit parameters if the original old_string is not found.
* It tries unescaping, and then LLM-based correction. * It tries unescaping, and then LLM-based correction.
* Results are cached to avoid redundant processing.
* *
* @param currentContent The current content of the file. * @param currentContent The current content of the file.
* @param params The original EditToolParams. * @param originalParams The original EditToolParams
* @param client The GeminiClient for LLM calls. * @param client The GeminiClient for LLM calls.
* @returns A promise resolving to an object containing the (potentially corrected) EditToolParams and the final occurrences count. * @returns A promise resolving to an object containing the (potentially corrected)
* EditToolParams (as CorrectedEditParams) and the final occurrences count.
*/ */
export async function ensureCorrectEdit( export async function ensureCorrectEdit(
currentContent: string, currentContent: string,
originalParams: EditToolParams, originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
client: GeminiClient, client: GeminiClient,
): Promise<CorrectedEditResult> { ): Promise<CorrectedEditResult> {
let occurrences = countOccurrences(currentContent, originalParams.old_string); let finalNewString = originalParams.new_string;
const currentParams = { ...originalParams }; const newStringPotentiallyEscaped =
unescapeStringForGeminiBug(originalParams.new_string) !==
originalParams.new_string;
let finalOldString = originalParams.old_string;
let occurrences = countOccurrences(currentContent, finalOldString);
if (occurrences === 1) { if (occurrences === 1) {
return { params: currentParams, occurrences }; return { params: originalParams, occurrences };
} } else {
// occurrences is 0 or some other unexpected state initially
const unescapedOldString = unescapeStringForGeminiBug( const unescapedOldStringAttempt = unescapeStringForGeminiBug(
currentParams.old_string, originalParams.old_string,
); );
occurrences = countOccurrences(currentContent, unescapedOldString); occurrences = countOccurrences(currentContent, unescapedOldStringAttempt);
if (occurrences === 1) { if (occurrences === 1) {
currentParams.old_string = unescapedOldString; finalOldString = unescapedOldStringAttempt;
currentParams.new_string = unescapeStringForGeminiBug( finalNewString = unescapeStringForGeminiBug(originalParams.new_string);
currentParams.new_string,
);
} else if (occurrences === 0) { } else if (occurrences === 0) {
const llmCorrectedOldString = await correctOldStringMismatch( const llmCorrectedOldString = await correctOldStringMismatch(
client, client,
currentContent, currentContent,
unescapedOldString, unescapedOldStringAttempt,
); );
occurrences = countOccurrences(currentContent, llmCorrectedOldString); const llmOldOccurrences = countOccurrences(
currentContent,
if (occurrences === 1) {
const llmCorrectedNewString = await correctNewString(
client,
unescapedOldString,
llmCorrectedOldString, llmCorrectedOldString,
currentParams.new_string,
); );
currentParams.old_string = llmCorrectedOldString;
currentParams.new_string = llmCorrectedNewString; if (llmOldOccurrences === 1) {
} else { finalOldString = llmCorrectedOldString;
// If LLM correction also results in 0 or >1 occurrences, occurrences = llmOldOccurrences;
// return the original params and 0 occurrences,
// letting the caller handle the "still not found" case. if (newStringPotentiallyEscaped) {
return { params: originalParams, occurrences: 0 }; const baseNewStringForLLMCorrection = unescapeStringForGeminiBug(
originalParams.new_string,
);
finalNewString = await correctNewString(
client,
originalParams.old_string, // original old
llmCorrectedOldString, // corrected old
baseNewStringForLLMCorrection, // base new for correction
);
} }
} else { } else {
// If unescaping resulted in >1 occurrences, return original params and that count. // LLM correction also failed for old_string
return { params: originalParams, occurrences }; const result: CorrectedEditResult = {
params: { ...originalParams },
occurrences: 0, // Explicitly 0 as LLM failed
};
return result;
}
} else {
// Unescaping old_string resulted in > 1 occurrences
const result: CorrectedEditResult = {
params: { ...originalParams },
occurrences, // This will be > 1
};
return result;
}
} }
return { params: currentParams, occurrences }; // Final result construction
const result: CorrectedEditResult = {
params: {
file_path: originalParams.file_path,
old_string: finalOldString,
new_string: finalNewString,
},
occurrences: countOccurrences(currentContent, finalOldString), // Recalculate occurrences with the final old_string
};
return result;
} }
/** // Define the expected JSON schema for the LLM response for old_string correction
* Attempts to correct potential formatting/escaping issues in a snippet using an LLM call. const OLD_STRING_CORRECTION_SCHEMA: SchemaUnion = {
*/ type: Type.OBJECT,
async function correctOldStringMismatch( properties: {
corrected_target_snippet: {
type: Type.STRING,
description:
'The corrected version of the target snippet that exactly and uniquely matches a segment within the provided file content.',
},
},
required: ['corrected_target_snippet'],
};
export async function correctOldStringMismatch(
geminiClient: GeminiClient, geminiClient: GeminiClient,
fileContent: string, fileContent: string,
problematicSnippet: string, problematicSnippet: string,
@ -155,10 +196,23 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
} }
} }
// JSON schema that the LLM response for new_string correction is expected to follow:
// an object with a single string property, `corrected_new_string`.
const NEW_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_new_string: {
type: Type.STRING,
description:
'The original_new_string adjusted to be a suitable replacement for the corrected_old_string, while maintaining the original intent of the change.',
},
},
// The corrected string is mandatory in the response object.
required: ['corrected_new_string'],
};
/** /**
* Adjusts the new_string to align with a corrected old_string, maintaining the original intent. * Adjusts the new_string to align with a corrected old_string, maintaining the original intent.
*/ */
async function correctNewString( export async function correctNewString(
geminiClient: GeminiClient, geminiClient: GeminiClient,
originalOldString: string, originalOldString: string,
correctedOldString: string, correctedOldString: string,
@ -220,37 +274,6 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
} }
} }
/**
 * Result shape resolved by ensureCorrectEdit: the (potentially corrected)
 * edit parameters together with the final occurrences count.
 */
export interface CorrectedEditResult {
// The edit parameters to use, typed as EditToolParams.
params: EditToolParams;
// Occurrence count reported alongside the params — presumably how many times
// old_string matched the file content; confirm against the producer.
occurrences: number;
}
// JSON schema that the LLM response for old_string correction is expected to follow:
// an object with a single string property, `corrected_target_snippet`.
const OLD_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_target_snippet: {
type: Type.STRING,
description:
'The corrected version of the target snippet that exactly and uniquely matches a segment within the provided file content.',
},
},
// The corrected snippet is mandatory in the response object.
required: ['corrected_target_snippet'],
};
// JSON schema that the LLM response for new_string correction is expected to follow:
// an object with a single string property, `corrected_new_string`.
const NEW_STRING_CORRECTION_SCHEMA: SchemaUnion = {
type: Type.OBJECT,
properties: {
corrected_new_string: {
type: Type.STRING,
description:
'The original_new_string adjusted to be a suitable replacement for the corrected_old_string, while maintaining the original intent of the change.',
},
},
// The corrected string is mandatory in the response object.
required: ['corrected_new_string'],
};
/** /**
* Unescapes a string that might have been overly escaped by an LLM. * Unescapes a string that might have been overly escaped by an LLM.
*/ */
@ -290,3 +313,19 @@ export function unescapeStringForGeminiBug(inputString: string): string {
} }
}); });
} }
/**
 * Returns how many non-overlapping times `substr` appears in `str`.
 *
 * Matches are found left to right, and each search resumes immediately after
 * the previous match, so overlapping candidates are not counted twice
 * (e.g. 'aa' occurs once in 'aaa').
 *
 * @param str The text to scan.
 * @param substr The substring to look for.
 * @returns The number of matches; always 0 when `substr` is empty.
 */
export function countOccurrences(str: string, substr: string): number {
  // An empty needle would match at every index; treat it as "no occurrences".
  if (substr.length === 0) {
    return 0;
  }

  const step = substr.length;
  let total = 0;
  for (
    let at = str.indexOf(substr);
    at >= 0;
    at = str.indexOf(substr, at + step) // resume right after this match
  ) {
    total += 1;
  }
  return total;
}