refactored manual selection into smaller bits,

easier to consume
This commit is contained in:
Karma Riuk
2025-05-16 09:58:14 +02:00
parent 63c6785b4d
commit 25161c4d46

View File

@ -1,10 +1,13 @@
from dataset import Dataset, Selection from typing import Optional
from dataset import Dataset, DatasetEntry, Selection
import argparse, os, re, click import argparse, os, re, click
from enum import Enum from enum import Enum
from utils import prompt_yes_no from utils import prompt_yes_no
HUNK_HEADER_REGEX = re.compile(r'^@@ -\d+(?:,\d+)? \+\d+(?:,\d+)? @@') HUNK_HEADER_REGEX = re.compile(r'^@@ -\d+(?:,\d+)? \+\d+(?:,\d+)? @@')
# TODO: %s/print/print/g
class ValidationMode(Enum): class ValidationMode(Enum):
COMMENT = "comment" COMMENT = "comment"
@ -24,8 +27,6 @@ def bold(line: str) -> str:
def pretty_diff(after: str) -> str: def pretty_diff(after: str) -> str:
if after is None:
return ""
lines = after.splitlines() lines = after.splitlines()
pretty_lines = [] pretty_lines = []
for line in lines: for line in lines:
@ -75,6 +76,78 @@ def edit_hunk(hunk: str) -> str:
return hunk return hunk
def display_header(i: int, total: int, n_good: int):
    """Print a banner with the running statistics of the review session.

    Args:
        i: 1-based index of the PR currently being reviewed.
        total: Total number of PRs to process.
        n_good: Number of PRs accepted so far.
    """
    # Local import keeps this function self-contained; shutil's variant falls
    # back to a default width instead of raising OSError when stdout is not
    # attached to a terminal (e.g. output piped to a file).
    import shutil

    cols = shutil.get_terminal_size().columns
    # Guard the ratios so a zero denominator cannot crash the banner.
    good_ratio = n_good / i if i else 0.0
    progress = i / total if total else 0.0
    print("#" * cols)
    print(f"# good PRs: {n_good}/{i} ({good_ratio:.2%})")
    print(f"Current PR: {i}/{total} ({progress:.2%})")
def display_pr_info(entry: DatasetEntry, i: int, total: int, n_good: int):
    """Print the progress banner followed by the GitHub URL of the entry's PR.

    Args:
        entry: Dataset entry whose ``metadata.repo`` and ``metadata.pr_number``
            are used to build the pull-request URL.
        i: 1-based index of the current PR (forwarded to ``display_header``).
        total: Total number of PRs to process.
        n_good: Number of PRs accepted so far.
    """
    display_header(i, total, n_good)
    meta = entry.metadata
    print(f"\nPull Request : https://github.com/{meta.repo}/pull/{meta.pr_number}\n")
def prompt_comment_suggestion(
    entry: DatasetEntry, sel: Optional[Selection], overwrite: bool
) -> bool:
    """Decide whether the entry's review comments suggest a change.

    Args:
        entry: Dataset entry whose ``comments`` are shown to the reviewer.
        sel: Previously stored selection for this entry, if any.
        overwrite: When true, ignore any stored answer and ask again.

    Returns:
        The stored ``sel.comment_suggests_change`` when it exists and
        ``overwrite`` is false; otherwise the reviewer's answer from
        ``prompt_yes_no``.
    """
    # Reuse the existing answer if available and not overwriting.
    if not overwrite and sel is not None and sel.comment_suggests_change is not None:
        return sel.comment_suggests_change

    # Show every comment before asking.
    for c in entry.comments:
        print(f"\nComment: {c.body}")
    return prompt_yes_no("Does the comment suggest a change?")
def show_diffs(entry):
    """Print every diff in ``entry.diffs_after`` under a per-file header.

    Falsy diffs (``None`` or an empty string) are reported as ``EMPTY DIFF``
    instead of being passed to ``pretty_diff``.
    """
    print("Diffs:")
    for path, patch in entry.diffs_after.items():
        print(f"--- {path} ---")
        if patch:
            print(pretty_diff(patch))
        else:
            print("EMPTY DIFF")
def ask_diff_relevance(entry: DatasetEntry, sel: Optional[Selection], overwrite: bool) -> bool:
    """Decide whether any diff of the entry relates to the review comment.

    Args:
        entry: Dataset entry whose diffs are displayed when a question is asked.
        sel: Previously stored selection for this entry, if any.
        overwrite: When true, ignore any stored answer and ask again.

    Returns:
        The stored ``sel.diff_after_address_change`` when it exists and
        ``overwrite`` is false; otherwise the reviewer's answer from
        ``prompt_yes_no``.
    """
    # Only consult the stored selection when we are allowed to reuse it.
    stored = None if (overwrite or sel is None) else sel.diff_after_address_change
    if stored is not None:
        return stored

    show_diffs(entry)
    question = f"Are {bold('any')} of these diffs related to the comment?"
    return prompt_yes_no(question)
def select_relevant_hunks(diff: str, comment: str) -> list[str]:
    """Walk the reviewer through each hunk of ``diff`` and collect the relevant ones.

    Args:
        diff: Diff text, split into hunks with ``split_into_hunks``.
        comment: Review comment body, printed next to each hunk for context.

    Returns:
        The hunks the reviewer marked as related, in their original order;
        a hunk the reviewer chose to edit is replaced by the result of
        ``edit_hunk``.
    """
    kept = []
    for number, hunk in enumerate(split_into_hunks(diff), 1):
        print(f"\nHunk #{number}:")
        print(pretty_diff(hunk))
        print(f"Comment: {comment}")

        # Skip hunks the reviewer does not consider related.
        if not prompt_yes_no(f"Is hunk #{number} related?", default=False):
            continue

        wants_edit = prompt_yes_no("Edit this hunk?", default=False)
        kept.append(edit_hunk(hunk) if wants_edit else hunk)
    return kept
def refine_entry(
    entry: DatasetEntry, sel: Optional[Selection], overwrite: bool, check_diff: bool
) -> bool:
    """Run the refinement step for one entry and report whether its diffs are relevant.

    Args:
        entry: Dataset entry under review. When ``check_diff`` is true its
            ``diffs_after`` is replaced by the hunks the reviewer selected.
        sel: Previously stored selection for this entry, if any.
        overwrite: When true, stored answers are ignored and re-asked.
        check_diff: When true, ask about each hunk individually instead of
            accepting all diffs as-is.

    Returns:
        ``False`` when the reviewer says no diff is related, or when
        hunk-level checking selected nothing; ``True`` otherwise.
    """
    diff_relevant = ask_diff_relevance(entry, sel, overwrite)
    if not diff_relevant:
        return False

    if check_diff:
        accumulated = {}
        for fname, diff in entry.diffs_after.items():
            if not diff:
                continue
            # Hunks are judged against the first comment of the entry.
            hunks = select_relevant_hunks(diff, entry.comments[0].body)
            if hunks:
                # Join back into one diff string for storage.
                accumulated[fname] = "\n".join(hunks)

        entry.diffs_after = accumulated
        # No hunk was kept: the diffs do not address the comment.
        if not accumulated:
            return False

    return True
def main( def main(
dataset_path: str, dataset_path: str,
output: str, output: str,
@ -95,6 +168,7 @@ def main(
"Running in REFINEMENT VALIDATION mode - checking both comment suggestions and implementation" "Running in REFINEMENT VALIDATION mode - checking both comment suggestions and implementation"
) )
total = len(entries_to_process)
try: try:
n_good = 0 n_good = 0
for i, entry in enumerate(entries_to_process, 1): for i, entry in enumerate(entries_to_process, 1):
@ -114,122 +188,27 @@ def main(
# We'll re-ask diffs if needed below # We'll re-ask diffs if needed below
# If selection exists but incomplete for this mode, proceed # If selection exists but incomplete for this mode, proceed
# Header info display_pr_info(entry, i, total, n_good)
print("#" * os.get_terminal_size().columns)
print(f"# good PRs: {n_good}/{i} ({n_good/i:.2%})")
print(f"Current PR: {i}/{len(entries_to_process)} ({i/len(entries_to_process):.2%})")
pr_url = f"https://github.com/{entry.metadata.repo}/pull/{entry.metadata.pr_number}"
print(f"\nPull Request : {pr_url}")
is_code_related = any(file.file.endswith('.java') for file in entry.comments) is_code_related = any(file.file.endswith('.java') for file in entry.comments)
for comment in entry.comments: suggests = prompt_comment_suggestion(entry, sel, overwrite)
print("\nComment:", comment.body)
# Comment suggestion check
if not overwrite and sel is not None and sel.comment_suggests_change is not None:
suggests = sel.comment_suggests_change
else:
suggests = prompt_yes_no("Does this comment suggest a change?")
if not suggests: if not suggests:
print("Doesn't suggest any change, skipping...") entry.metadata.selection = Selection(False, None, is_code_related)
entry.metadata.selection = Selection( continue
comment_suggests_change=False,
diff_after_address_change=None,
is_code_related=any(file.file.endswith('.java') for file in entry.comments),
)
break
if validation_mode == ValidationMode.COMMENT: if validation_mode == ValidationMode.COMMENT:
entry.metadata.selection = Selection( entry.metadata.selection = Selection(
comment_suggests_change=True, True,
diff_after_address_change=sel.diff_after_address_change sel.diff_after_address_change if sel is not None else None,
if sel is not None is_code_related,
else None,
is_code_related=is_code_related,
) )
n_good += 1 n_good += 1
break elif validation_mode == ValidationMode.REFINEMENT:
diff_relevant = refine_entry(entry, sel, overwrite, check_diff_relevance)
# REFINEMENT mode: show all diffs first entry.metadata.selection = Selection(True, diff_relevant, is_code_related)
if diff_relevant:
# Initial relevance query
if not overwrite and sel is not None and sel.diff_after_address_change is not None:
any_relevant = sel.diff_after_address_change
else:
print("Diffs:")
for f, diff in entry.diffs_after.items():
if diff is None:
print(f, "EMPTY DIFF")
continue
print(f"--- {f} ---")
print(pretty_diff(diff))
any_relevant = prompt_yes_no("Are any of these diffs related to the comment?")
if not any_relevant:
print("No diffs relevant, skipping...")
entry.metadata.selection = Selection(
comment_suggests_change=True,
diff_after_address_change=False,
is_code_related=is_code_related,
)
break
# Ask which diffs if detailed relevance requested
relevant_diffs = {}
if check_diff_relevance:
for f, diff in entry.diffs_after.items():
if diff is None:
continue
hunks = split_into_hunks(diff)
if not hunks:
continue
print(f"\n--- {f} has {len(hunks)} hunks ---")
selected_hunks: list[str] = []
for idx, hunk in enumerate(hunks, 1):
print(f"\nHunk #{idx}:")
print(pretty_diff(hunk))
print(f"Comment: {comment.body}")
if prompt_yes_no(f" → Is hunk #{idx} related to the comment?"):
if prompt_yes_no(
f" → Do you want to edit this hunk?", default=False
):
new_hunk = edit_hunk(hunk)
selected_hunks.append(new_hunk)
else:
selected_hunks.append(hunk)
if len(selected_hunks) > 0:
# join back into one diff string for storage
relevant_diffs[f] = "\n".join(selected_hunks)
if len(relevant_diffs) == 0:
print("No relevant diffs found, skipping...")
entry.metadata.selection = Selection(
comment_suggests_change=True,
diff_after_address_change=False,
is_code_related=is_code_related,
)
break
print("\nRelevant diffs:")
for f, d in relevant_diffs.items():
print(f"--- {f} ---")
print(pretty_diff(d))
else:
relevant_diffs = entry.diffs_after
entry.diffs_after = relevant_diffs
entry.metadata.selection = Selection(
comment_suggests_change=True,
diff_after_address_change=True,
is_code_related=is_code_related,
)
if len(relevant_diffs) > 0:
n_good += 1 n_good += 1
break
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nInterrupted.") print("\nInterrupted.")
finally: finally: