when selecting, we can now choose which diff hunk

is relevant and modify it if necessary
This commit is contained in:
Karma Riuk
2025-05-07 10:39:21 +02:00
parent 959184b2a8
commit a701dc236c

View File

@ -1,8 +1,10 @@
from dataset import Dataset, Selection from dataset import Dataset, Selection
import argparse, os import argparse, os, re, click
from enum import Enum from enum import Enum
from utils import prompt_yes_no from utils import prompt_yes_no
HUNK_HEADER_REGEX = re.compile(r'^@@ -\d+(?:,\d+)? \+\d+(?:,\d+)? @@')
class ValidationMode(Enum): class ValidationMode(Enum):
COMMENT = "comment" COMMENT = "comment"
@ -38,10 +40,46 @@ def pretty_diff(after: str) -> str:
return "\n".join(pretty_lines) return "\n".join(pretty_lines)
def split_into_hunks(diff: str) -> list[str]:
"""
Given a unified diff string, split it into chunks, each starting with a
hunk header (“@@ -… +… @@”) and including all context lines for that hunk.
"""
if not diff:
return []
# The regex will keep the “@@ … @@” lines as the start of each hunk.
parts = re.split(r'(?m)(^@@ .*@@)', diff)
# re.split returns something like ['', header1, body1, header2, body2, …]
hunks = []
for i in range(1, len(parts), 2):
header = parts[i]
body = parts[i + 1]
hunks.append(header + body)
return hunks
def edit_hunk(hunk: str) -> str:
while True:
edited = click.edit(hunk)
if edited is None:
print("Edit aborted, keeping original hunk")
return hunk
lines = edited.splitlines()
# Validate that the hunk header remains intact
if lines and HUNK_HEADER_REGEX.match(lines[0]):
return edited
else:
print(red("Invalid hunk header! Hunk header must not be modified."))
if not prompt_yes_no("Edit again?", default=False):
print("Keeping original hunk")
return hunk
def main( def main(
dataset_path: str, dataset_path: str,
overwrite: bool = False, overwrite: bool = False,
validation_mode: ValidationMode = ValidationMode.REFINEMENT, validation_mode: ValidationMode = ValidationMode.REFINEMENT,
check_diff_relevance: bool = False,
): ):
dataset = Dataset.from_json(dataset_path) dataset = Dataset.from_json(dataset_path)
@ -59,22 +97,41 @@ def main(
try: try:
n_good = 0 n_good = 0
for i, entry in enumerate(entries_to_process, 1): for i, entry in enumerate(entries_to_process, 1):
if entry.metadata.selection and not overwrite: sel = entry.metadata.selection
if entry.metadata.selection.good: # Skip or count already processed entries if not overwriting
n_good += 1 if not overwrite and sel is not None:
continue # Skip already processed if (
print("#" * os.get_terminal_size().columns) validation_mode == ValidationMode.COMMENT
and sel.comment_suggests_change is not None
):
n_good += int(sel.comment_suggests_change)
continue
if (
validation_mode == ValidationMode.REFINEMENT
and sel.diff_after_address_change is not None
):
n_good += int(sel.diff_after_address_change)
# We'll re-ask diffs if needed below
# If selection exists but incomplete for this mode, proceed
pr_url = f"https://github.com/{entry.metadata.repo}/pull/{entry.metadata.pr_number}" # Header info
print("#" * os.get_terminal_size().columns)
print(f"# good PRs: {n_good}/{i} ({n_good/i:.2%})") print(f"# good PRs: {n_good}/{i} ({n_good/i:.2%})")
print(f"Current PR: {i}/{len(entries_to_process)} ({i/len(entries_to_process):.2%})") print(f"Current PR: {i}/{len(entries_to_process)} ({i/len(entries_to_process):.2%})")
pr_url = f"https://github.com/{entry.metadata.repo}/pull/{entry.metadata.pr_number}"
print(f"\nPull Request : {pr_url}") print(f"\nPull Request : {pr_url}")
for comment in entry.comments: for comment in entry.comments:
print("\nComment:", comment.body) print("\nComment:", comment.body)
change = prompt_yes_no("Does this comment suggest a change?")
if not change: # Comment suggestion check
if not overwrite and sel is not None and sel.comment_suggests_change is not None:
suggests = sel.comment_suggests_change
else:
suggests = prompt_yes_no("Does this comment suggest a change?")
if not suggests:
print("Doesn't suggest any change, skipping...")
entry.metadata.selection = Selection( entry.metadata.selection = Selection(
comment_suggests_change=False, comment_suggests_change=False,
diff_after_address_change=None, diff_after_address_change=None,
@ -83,7 +140,6 @@ def main(
break break
if validation_mode == ValidationMode.COMMENT: if validation_mode == ValidationMode.COMMENT:
# In comment validation mode, we only check if the comment suggests a change
entry.metadata.selection = Selection( entry.metadata.selection = Selection(
comment_suggests_change=True, comment_suggests_change=True,
diff_after_address_change=None, diff_after_address_change=None,
@ -91,24 +147,86 @@ def main(
) )
n_good += 1 n_good += 1
break break
# REFINEMENT mode: show all diffs first
# Initial relevance query
if not overwrite and sel is not None and sel.diff_after_address_change is not None:
any_relevant = sel.diff_after_address_change
else: else:
# In refinement validation mode, we also check if the diff implements the change print("Diffs:")
for file, diff in entry.diffs_after.items(): for f, diff in entry.diffs_after.items():
if diff is None: if diff is None:
print(file, "EMPTY DIFF") print(f, "EMPTY DIFF")
continue continue
print(file, pretty_diff(diff)) print(f"--- {f} ---")
print(pretty_diff(diff))
applied = prompt_yes_no("Does this diff implement the change suggested?") any_relevant = prompt_yes_no("Are any of these diffs related to the comment?")
if not any_relevant:
print("No diffs relevant, skipping...")
entry.metadata.selection = Selection( entry.metadata.selection = Selection(
comment_suggests_change=True, comment_suggests_change=True,
diff_after_address_change=applied, diff_after_address_change=False,
good=applied, good=False,
) )
if applied: break
n_good += 1
# Ask which diffs if detailed relevance requested
relevant_diffs = {}
if check_diff_relevance:
for f, diff in entry.diffs_after.items():
if diff is None:
continue
hunks = split_into_hunks(diff)
if not hunks:
continue
print(f"\n--- {f} has {len(hunks)} hunks ---")
selected_hunks: list[str] = []
for idx, hunk in enumerate(hunks, 1):
print(f"\nHunk #{idx}:")
print(pretty_diff(hunk))
print(f"Comment: {comment.body}")
if prompt_yes_no(f" → Is hunk #{idx} related to the comment?"):
if prompt_yes_no(
f" → Do you want to edit this hunk?", default=False
):
new_hunk = edit_hunk(hunk)
selected_hunks.append(new_hunk)
else:
selected_hunks.append(hunk)
if len(selected_hunks) > 0:
# join back into one diff string for storage
relevant_diffs[f] = "\n".join(selected_hunks)
if len(relevant_diffs) == 0:
print("No relevant diffs found, skipping...")
entry.metadata.selection = Selection(
comment_suggests_change=True,
diff_after_address_change=False,
good=False,
)
break
print("\nRelevant diffs:")
for f, d in relevant_diffs.items():
print(f"--- {f} ---")
print(pretty_diff(d))
else:
relevant_diffs = entry.diffs_after
entry.diffs_after = relevant_diffs
entry.metadata.selection = Selection(
comment_suggests_change=True,
diff_after_address_change=True,
good=True,
)
if len(relevant_diffs) > 0:
n_good += 1
break
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nInterrupted.") print("\nInterrupted.")
finally: finally:
@ -126,10 +244,19 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"-m", "-m",
"--mode", "--mode",
# type=lambda x: ValidationMode(x), choices=[mode.value for mode in ValidationMode],
# choices=[mode.value for mode in ValidationMode],
default='comment', default='comment',
help="Validation mode: 'comment' to only check if comments suggest changes, 'refinement' to check both comment suggestions and implementation. Default is 'comment'", help="Validation mode: 'comment' to only check if comments suggest changes, 'refinement' to check both comment suggestions and implementation. Default is 'comment'",
) )
parser.add_argument(
"--check-diff-relevance",
action="store_true",
help="Check if each diff is related to the comment before asking if it implements the change",
)
args = parser.parse_args() args = parser.parse_args()
main(args.dataset, overwrite=args.overwrite, validation_mode=args.mode) main(
args.dataset,
overwrite=args.overwrite,
validation_mode=ValidationMode(args.mode),
check_diff_relevance=args.check_diff_relevance,
)