diff --git a/manual_selection.py b/manual_selection.py
index 9bb5ea5..d257648 100644
--- a/manual_selection.py
+++ b/manual_selection.py
@@ -1,8 +1,11 @@
 from dataset import Dataset, Selection
-import argparse, os
+import argparse, os, re, click
 from enum import Enum
 from utils import prompt_yes_no
 
+# Shape of a unified-diff hunk header, e.g. "@@ -38,10 +40,46 @@"
+HUNK_HEADER_REGEX = re.compile(r'^@@ -\d+(?:,\d+)? \+\d+(?:,\d+)? @@')
+
 
 class ValidationMode(Enum):
     COMMENT = "comment"
@@ -38,10 +41,51 @@ def pretty_diff(after: str) -> str:
     return "\n".join(pretty_lines)
 
 
+def split_into_hunks(diff: str) -> list[str]:
+    """Split a unified diff into its hunks.
+
+    Each returned chunk starts with its hunk header ("@@ -... +... @@") and
+    runs up to, but not including, the next header. Anything before the first
+    header (e.g. the "---"/"+++" file lines) is dropped.
+
+    Returns an empty list when *diff* is empty or contains no hunk header.
+    """
+    if not diff:
+        return []
+    # The capturing group makes re.split keep each header, yielding
+    # ['<preamble>', header1, body1, header2, body2, ...].
+    parts = re.split(r'(?m)(^@@ .*@@)', diff)
+    return [header + body for header, body in zip(parts[1::2], parts[2::2])]
+
+
+def edit_hunk(hunk: str) -> str:
+    """Open *hunk* in the user's editor and return the edited text.
+
+    The edit is accepted only if its first line still has the shape of a
+    hunk header. The original hunk is returned unchanged when the editor is
+    closed without changes or when the user gives up after an invalid edit.
+    """
+    while True:
+        edited = click.edit(hunk)
+        if edited is None:
+            # click.edit returns None when the editor is closed without changes
+            print("Edit aborted, keeping original hunk")
+            return hunk
+        lines = edited.splitlines()
+        # Validate that the first line is still a hunk header
+        if lines and HUNK_HEADER_REGEX.match(lines[0]):
+            return edited
+        print(red("Invalid hunk header! Hunk header must not be modified."))
+        if not prompt_yes_no("Edit again?", default=False):
+            print("Keeping original hunk")
+            return hunk
+
+
 def main(
     dataset_path: str,
     overwrite: bool = False,
     validation_mode: ValidationMode = ValidationMode.REFINEMENT,
+    check_diff_relevance: bool = False,
 ):
     dataset = Dataset.from_json(dataset_path)
 
@@ -59,22 +103,38 @@ def main(
     try:
         n_good = 0
         for i, entry in enumerate(entries_to_process, 1):
-            if entry.metadata.selection and not overwrite:
-                if entry.metadata.selection.good:
-                    n_good += 1
-                continue  # Skip already processed
-            print("#" * os.get_terminal_size().columns)
+            sel = entry.metadata.selection
+            # Skip or count already processed entries if not overwriting
+            if not overwrite and sel is not None:
+                if (
+                    validation_mode == ValidationMode.COMMENT
+                    and sel.comment_suggests_change is not None
+                ):
+                    n_good += int(sel.comment_suggests_change)
+                    continue
+                # A stored REFINEMENT verdict is deliberately not counted here:
+                # the entry falls through, the verdict is re-used below, and
+                # n_good is incremented once there (counting here too doubled it)
+                # If selection exists but incomplete for this mode, proceed
 
-            pr_url = f"https://github.com/{entry.metadata.repo}/pull/{entry.metadata.pr_number}"
+            # Header info
+            print("#" * os.get_terminal_size().columns)
             print(f"# good PRs: {n_good}/{i} ({n_good/i:.2%})")
             print(f"Current PR: {i}/{len(entries_to_process)} ({i/len(entries_to_process):.2%})")
 
+            pr_url = f"https://github.com/{entry.metadata.repo}/pull/{entry.metadata.pr_number}"
             print(f"\nPull Request : {pr_url}")
 
             for comment in entry.comments:
                 print("\nComment:", comment.body)
-                change = prompt_yes_no("Does this comment suggest a change?")
-                if not change:
+                # Comment suggestion check
+                if not overwrite and sel is not None and sel.comment_suggests_change is not None:
+                    suggests = sel.comment_suggests_change
+                else:
+                    suggests = prompt_yes_no("Does this comment suggest a change?")
+
+                if not suggests:
+                    print("Doesn't suggest any change, skipping...")
                     entry.metadata.selection = Selection(
                         comment_suggests_change=False,
                         diff_after_address_change=None,
@@ -83,7 +143,6 @@ def main(
                     break
 
                 if validation_mode == ValidationMode.COMMENT:
-                    # In comment validation mode, we only check if the comment suggests a change
                     entry.metadata.selection = Selection(
                         comment_suggests_change=True,
                         diff_after_address_change=None,
@@ -91,24 +150,89 @@ def main(
                     )
                     n_good += 1
                     break
+
+                # REFINEMENT mode: show all diffs first
+
+                # Initial relevance query
+                if not overwrite and sel is not None and sel.diff_after_address_change is not None:
+                    any_relevant = sel.diff_after_address_change
                 else:
-                    # In refinement validation mode, we also check if the diff implements the change
-                    for file, diff in entry.diffs_after.items():
+                    print("Diffs:")
+                    for f, diff in entry.diffs_after.items():
                         if diff is None:
-                            print(file, "EMPTY DIFF")
+                            print(f, "EMPTY DIFF")
                             continue
-                        print(file, pretty_diff(diff))
-
-                    applied = prompt_yes_no("Does this diff implement the change suggested?")
+                        print(f"--- {f} ---")
+                        print(pretty_diff(diff))
+                    any_relevant = prompt_yes_no("Are any of these diffs related to the comment?")
+                if not any_relevant:
+                    print("No diffs relevant, skipping...")
                     entry.metadata.selection = Selection(
                         comment_suggests_change=True,
-                        diff_after_address_change=applied,
-                        good=applied,
+                        diff_after_address_change=False,
+                        good=False,
                     )
-                    if applied:
-                        n_good += 1
+                    break
 
+                # Ask which diffs if detailed relevance requested
+                relevant_diffs = {}
+                if check_diff_relevance:
+                    for f, diff in entry.diffs_after.items():
+                        if diff is None:
+                            continue
+                        hunks = split_into_hunks(diff)
+                        if not hunks:
+                            continue
+
+                        print(f"\n--- {f} has {len(hunks)} hunks ---")
+
+                        selected_hunks: list[str] = []
+                        for idx, hunk in enumerate(hunks, 1):
+                            print(f"\nHunk #{idx}:")
+                            print(pretty_diff(hunk))
+                            print(f"Comment: {comment.body}")
+                            if prompt_yes_no(f" → Is hunk #{idx} related to the comment?"):
+                                if prompt_yes_no(
+                                    " → Do you want to edit this hunk?", default=False
+                                ):
+                                    new_hunk = edit_hunk(hunk)
+                                    selected_hunks.append(new_hunk)
+                                else:
+                                    selected_hunks.append(hunk)
+
+                        if len(selected_hunks) > 0:
+                            # Hunks normally keep their trailing newline; add one only
+                            # where missing so no blank line lands between hunks
+                            relevant_diffs[f] = "".join(
+                                h if h.endswith("\n") else h + "\n" for h in selected_hunks
+                            )
+
+                    if len(relevant_diffs) == 0:
+                        print("No relevant diffs found, skipping...")
+                        entry.metadata.selection = Selection(
+                            comment_suggests_change=True,
+                            diff_after_address_change=False,
+                            good=False,
+                        )
+                        break
+
+                    print("\nRelevant diffs:")
+                    for f, d in relevant_diffs.items():
+                        print(f"--- {f} ---")
+                        print(pretty_diff(d))
+                else:
+                    relevant_diffs = entry.diffs_after
+
+                entry.diffs_after = relevant_diffs
+                entry.metadata.selection = Selection(
+                    comment_suggests_change=True,
+                    diff_after_address_change=True,
+                    good=True,
+                )
+                if len(relevant_diffs) > 0:
+                    n_good += 1
+                break
 
     except KeyboardInterrupt:
         print("\nInterrupted.")
     finally:
@@ -126,10 +250,19 @@ if __name__ == "__main__":
     parser.add_argument(
         "-m",
         "--mode",
-        # type=lambda x: ValidationMode(x),
-        # choices=[mode.value for mode in ValidationMode],
+        choices=[mode.value for mode in ValidationMode],
         default='comment',
         help="Validation mode: 'comment' to only check if comments suggest changes, 'refinement' to check both comment suggestions and implementation. Default is 'comment'",
     )
+    parser.add_argument(
+        "--check-diff-relevance",
+        action="store_true",
+        help="In refinement mode, review each diff hunk by hunk and keep only the hunks related to the comment",
+    )
     args = parser.parse_args()
-    main(args.dataset, overwrite=args.overwrite, validation_mode=args.mode)
+    main(
+        args.dataset,
+        overwrite=args.overwrite,
+        validation_mode=ValidationMode(args.mode),
+        check_diff_relevance=args.check_diff_relevance,
+    )