From 25161c4d463263c07ebfd7d036e572c1b339a78a Mon Sep 17 00:00:00 2001
From: Karma Riuk
Date: Fri, 16 May 2025 09:58:14 +0200
Subject: [PATCH] refactored manual selection into smaller bits, easier to consume

---
Review notes (text in this section is commentary and is not part of the commit):
- refine_entry: the "no hunk selected" check was inverted (`if len(accumulated)`),
  so entries with selected hunks were rejected; it now reads `if not accumulated`,
  matching the `len(relevant_diffs) == 0` check this patch removes from main().
- prompt_comment_suggestion: `Selection | None` replaced by `Optional[Selection]`
  to agree with the other helpers and the `Optional` import added here.
- Docstrings/annotations added to the new helpers; hunk offsets and the diffstat
  were recomputed accordingly (the post-image blob id on the index line is stale).

 manual_selection.py | 215 ++++++++++++++++++++------------------------
 1 file changed, 103 insertions(+), 112 deletions(-)

diff --git a/manual_selection.py b/manual_selection.py
index f228f41..694ca2d 100644
--- a/manual_selection.py
+++ b/manual_selection.py
@@ -1,10 +1,13 @@
-from dataset import Dataset, Selection
+from typing import Optional
+from dataset import Dataset, DatasetEntry, Selection
 import argparse, os, re, click
 from enum import Enum
 from utils import prompt_yes_no
 
 HUNK_HEADER_REGEX = re.compile(r'^@@ -\d+(?:,\d+)? \+\d+(?:,\d+)? @@')
 
+# TODO: %s/print/print/g
+
 
 class ValidationMode(Enum):
     COMMENT = "comment"
@@ -24,8 +27,6 @@ def bold(line: str) -> str:
 
 
 def pretty_diff(after: str) -> str:
-    if after is None:
-        return ""
     lines = after.splitlines()
     pretty_lines = []
     for line in lines:
@@ -75,6 +76,90 @@ def edit_hunk(hunk: str) -> str:
     return hunk
 
 
+def display_header(i: int, total: int, n_good: int) -> None:
+    """Print a terminal-wide separator followed by the good-PR and progress counters."""
+    cols = os.get_terminal_size().columns
+    print("#" * cols)
+    print(f"# good PRs: {n_good}/{i} ({n_good / i:.2%})")
+    print(f"Current PR: {i}/{total} ({i / total:.2%})")
+
+
+def display_pr_info(entry: DatasetEntry, i: int, total: int, n_good: int) -> None:
+    """Print the progress header and the GitHub URL of the entry's pull request."""
+    display_header(i, total, n_good)
+    pr_url = f"https://github.com/{entry.metadata.repo}/pull/{entry.metadata.pr_number}"
+    print(f"\nPull Request : {pr_url}\n")
+
+
+def prompt_comment_suggestion(entry: DatasetEntry, sel: Optional[Selection], overwrite: bool) -> bool:
+    """Return whether the comments suggest a change, reusing a stored answer unless overwriting."""
+    # reuse existing if available and not overwriting
+    if not overwrite and sel is not None and sel.comment_suggests_change is not None:
+        return sel.comment_suggests_change
+    # show comments
+    for c in entry.comments:
+        print(f"\nComment: {c.body}")
+    return prompt_yes_no("Do the comment suggest a change?")
+
+
+def show_diffs(entry: DatasetEntry) -> None:
+    """Print every diff in `entry.diffs_after`, showing "EMPTY DIFF" for falsy ones."""
+    print("Diffs:")
+    for fname, diff in entry.diffs_after.items():
+        print(f"--- {fname} ---")
+        print(pretty_diff(diff) if diff else "EMPTY DIFF")
+
+
+def ask_diff_relevance(entry: DatasetEntry, sel: Optional[Selection], overwrite: bool) -> bool:
+    """Return whether any diff relates to the comment, reusing a stored answer unless overwriting."""
+    if not overwrite and sel is not None and sel.diff_after_address_change is not None:
+        return sel.diff_after_address_change
+    show_diffs(entry)
+    return prompt_yes_no(f"Are {bold('any')} of these diffs related to the comment?")
+
+
+def select_relevant_hunks(diff: str, comment: str) -> list[str]:
+    """Prompt for each hunk of `diff`; return those marked related (edited if requested)."""
+    hunks = split_into_hunks(diff)
+    selected = []
+    for idx, h in enumerate(hunks, 1):
+        print(f"\nHunk #{idx}:")
+        print(pretty_diff(h))
+        print(f"Comment: {comment}")
+        if prompt_yes_no(f"Is hunk #{idx} related?", default=False):
+            if prompt_yes_no("Edit this hunk?", default=False):
+                h = edit_hunk(h)
+            selected.append(h)
+    return selected
+
+
+def refine_entry(
+    entry: DatasetEntry, sel: Optional[Selection], overwrite: bool, check_diff: bool
+) -> bool:
+    """Return True when the entry's diffs address the comment.
+
+    With `check_diff`, `entry.diffs_after` is replaced by the hunks the user
+    marked as related, and False is returned when none were kept.
+    """
+    diff_relevant = ask_diff_relevance(entry, sel, overwrite)
+    if not diff_relevant:
+        return False
+
+    if check_diff:
+        accumulated = {}
+        for fname, diff in entry.diffs_after.items():
+            if not diff:
+                continue
+            hunks = select_relevant_hunks(diff, entry.comments[0].body)
+            if hunks:
+                accumulated[fname] = "\n".join(hunks)
+        entry.diffs_after = accumulated
+        # no hunk was kept: nothing in the diffs addresses the comment
+        if not accumulated:
+            return False
+    return True
+
+
 def main(
     dataset_path: str,
     output: str,
@@ -95,6 +180,7 @@ def main(
             "Running in REFINEMENT VALIDATION mode - checking both comment suggestions and implementation"
         )
 
+    total = len(entries_to_process)
     try:
         n_good = 0
         for i, entry in enumerate(entries_to_process, 1):
@@ -114,122 +200,27 @@ def main(
                     # We'll re-ask diffs if needed below
                 # If selection exists but incomplete for this mode, proceed
 
-            # Header info
-            print("#" * os.get_terminal_size().columns)
-            print(f"# good PRs: {n_good}/{i} ({n_good/i:.2%})")
-            print(f"Current PR: {i}/{len(entries_to_process)} ({i/len(entries_to_process):.2%})")
-            pr_url = f"https://github.com/{entry.metadata.repo}/pull/{entry.metadata.pr_number}"
-            print(f"\nPull Request : {pr_url}")
+            display_pr_info(entry, i, total, n_good)
 
             is_code_related = any(file.file.endswith('.java') for file in entry.comments)
-            for comment in entry.comments:
-                print("\nComment:", comment.body)
+            suggests = prompt_comment_suggestion(entry, sel, overwrite)
 
-                # Comment suggestion check
-                if not overwrite and sel is not None and sel.comment_suggests_change is not None:
-                    suggests = sel.comment_suggests_change
-                else:
-                    suggests = prompt_yes_no("Does this comment suggest a change?")
+            if not suggests:
+                entry.metadata.selection = Selection(False, None, is_code_related)
+                continue
 
-                if not suggests:
-                    print("Doesn't suggest any change, skipping...")
-                    entry.metadata.selection = Selection(
-                        comment_suggests_change=False,
-                        diff_after_address_change=None,
-                        is_code_related=any(file.file.endswith('.java') for file in entry.comments),
-                    )
-                    break
-
-                if validation_mode == ValidationMode.COMMENT:
-                    entry.metadata.selection = Selection(
-                        comment_suggests_change=True,
-                        diff_after_address_change=sel.diff_after_address_change
-                        if sel is not None
-                        else None,
-                        is_code_related=is_code_related,
-                    )
-                    n_good += 1
-                    break
-
-                # REFINEMENT mode: show all diffs first
-
-                # Initial relevance query
-                if not overwrite and sel is not None and sel.diff_after_address_change is not None:
-                    any_relevant = sel.diff_after_address_change
-                else:
-                    print("Diffs:")
-                    for f, diff in entry.diffs_after.items():
-                        if diff is None:
-                            print(f, "EMPTY DIFF")
-                            continue
-                        print(f"--- {f} ---")
-                        print(pretty_diff(diff))
-                    any_relevant = prompt_yes_no("Are any of these diffs related to the comment?")
-
-                if not any_relevant:
-                    print("No diffs relevant, skipping...")
-                    entry.metadata.selection = Selection(
-                        comment_suggests_change=True,
-                        diff_after_address_change=False,
-                        is_code_related=is_code_related,
-                    )
-                    break
-
-                # Ask which diffs if detailed relevance requested
-                relevant_diffs = {}
-                if check_diff_relevance:
-                    for f, diff in entry.diffs_after.items():
-                        if diff is None:
-                            continue
-                        hunks = split_into_hunks(diff)
-                        if not hunks:
-                            continue
-
-                        print(f"\n--- {f} has {len(hunks)} hunks ---")
-
-                        selected_hunks: list[str] = []
-                        for idx, hunk in enumerate(hunks, 1):
-                            print(f"\nHunk #{idx}:")
-                            print(pretty_diff(hunk))
-                            print(f"Comment: {comment.body}")
-                            if prompt_yes_no(f" → Is hunk #{idx} related to the comment?"):
-                                if prompt_yes_no(
-                                    f" → Do you want to edit this hunk?", default=False
-                                ):
-                                    new_hunk = edit_hunk(hunk)
-                                    selected_hunks.append(new_hunk)
-                                else:
-                                    selected_hunks.append(hunk)
-
-                        if len(selected_hunks) > 0:
-                            # join back into one diff string for storage
-                            relevant_diffs[f] = "\n".join(selected_hunks)
-
-                    if len(relevant_diffs) == 0:
-                        print("No relevant diffs found, skipping...")
-                        entry.metadata.selection = Selection(
-                            comment_suggests_change=True,
-                            diff_after_address_change=False,
-                            is_code_related=is_code_related,
-                        )
-                        break
-
-                    print("\nRelevant diffs:")
-                    for f, d in relevant_diffs.items():
-                        print(f"--- {f} ---")
-                        print(pretty_diff(d))
-                else:
-                    relevant_diffs = entry.diffs_after
-
-                entry.diffs_after = relevant_diffs
+            if validation_mode == ValidationMode.COMMENT:
                 entry.metadata.selection = Selection(
-                    comment_suggests_change=True,
-                    diff_after_address_change=True,
-                    is_code_related=is_code_related,
+                    True,
+                    sel.diff_after_address_change if sel is not None else None,
+                    is_code_related,
                 )
-                if len(relevant_diffs) > 0:
+                n_good += 1
+            elif validation_mode == ValidationMode.REFINEMENT:
+                diff_relevant = refine_entry(entry, sel, overwrite, check_diff_relevance)
+                entry.metadata.selection = Selection(True, diff_relevant, is_code_related)
+                if diff_relevant:
                     n_good += 1
-                break
     except KeyboardInterrupt:
         print("\nInterrupted.")
     finally: