mirror of
https://github.com/karma-riuk/crab.git
synced 2025-07-04 21:28:12 +02:00
when selecting, we can now choose which diff hunk
is relevant and modify it if necessary
This commit is contained in:
@ -1,8 +1,10 @@
|
|||||||
from dataset import Dataset, Selection
|
from dataset import Dataset, Selection
|
||||||
import argparse, os
|
import argparse, os, re, click
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from utils import prompt_yes_no
|
from utils import prompt_yes_no
|
||||||
|
|
||||||
|
HUNK_HEADER_REGEX = re.compile(r'^@@ -\d+(?:,\d+)? \+\d+(?:,\d+)? @@')
|
||||||
|
|
||||||
|
|
||||||
class ValidationMode(Enum):
|
class ValidationMode(Enum):
|
||||||
COMMENT = "comment"
|
COMMENT = "comment"
|
||||||
@ -38,10 +40,46 @@ def pretty_diff(after: str) -> str:
|
|||||||
return "\n".join(pretty_lines)
|
return "\n".join(pretty_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def split_into_hunks(diff: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Given a unified diff string, split it into chunks, each starting with a
|
||||||
|
hunk header (“@@ -… +… @@”) and including all context lines for that hunk.
|
||||||
|
"""
|
||||||
|
if not diff:
|
||||||
|
return []
|
||||||
|
# The regex will keep the “@@ … @@” lines as the start of each hunk.
|
||||||
|
parts = re.split(r'(?m)(^@@ .*@@)', diff)
|
||||||
|
# re.split returns something like ['', header1, body1, header2, body2, …]
|
||||||
|
hunks = []
|
||||||
|
for i in range(1, len(parts), 2):
|
||||||
|
header = parts[i]
|
||||||
|
body = parts[i + 1]
|
||||||
|
hunks.append(header + body)
|
||||||
|
return hunks
|
||||||
|
|
||||||
|
|
||||||
|
def edit_hunk(hunk: str) -> str:
|
||||||
|
while True:
|
||||||
|
edited = click.edit(hunk)
|
||||||
|
if edited is None:
|
||||||
|
print("Edit aborted, keeping original hunk")
|
||||||
|
return hunk
|
||||||
|
lines = edited.splitlines()
|
||||||
|
# Validate that the hunk header remains intact
|
||||||
|
if lines and HUNK_HEADER_REGEX.match(lines[0]):
|
||||||
|
return edited
|
||||||
|
else:
|
||||||
|
print(red("Invalid hunk header! Hunk header must not be modified."))
|
||||||
|
if not prompt_yes_no("Edit again?", default=False):
|
||||||
|
print("Keeping original hunk")
|
||||||
|
return hunk
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
dataset_path: str,
|
dataset_path: str,
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
validation_mode: ValidationMode = ValidationMode.REFINEMENT,
|
validation_mode: ValidationMode = ValidationMode.REFINEMENT,
|
||||||
|
check_diff_relevance: bool = False,
|
||||||
):
|
):
|
||||||
dataset = Dataset.from_json(dataset_path)
|
dataset = Dataset.from_json(dataset_path)
|
||||||
|
|
||||||
@ -59,22 +97,41 @@ def main(
|
|||||||
try:
|
try:
|
||||||
n_good = 0
|
n_good = 0
|
||||||
for i, entry in enumerate(entries_to_process, 1):
|
for i, entry in enumerate(entries_to_process, 1):
|
||||||
if entry.metadata.selection and not overwrite:
|
sel = entry.metadata.selection
|
||||||
if entry.metadata.selection.good:
|
# Skip or count already processed entries if not overwriting
|
||||||
n_good += 1
|
if not overwrite and sel is not None:
|
||||||
continue # Skip already processed
|
if (
|
||||||
print("#" * os.get_terminal_size().columns)
|
validation_mode == ValidationMode.COMMENT
|
||||||
|
and sel.comment_suggests_change is not None
|
||||||
|
):
|
||||||
|
n_good += int(sel.comment_suggests_change)
|
||||||
|
continue
|
||||||
|
if (
|
||||||
|
validation_mode == ValidationMode.REFINEMENT
|
||||||
|
and sel.diff_after_address_change is not None
|
||||||
|
):
|
||||||
|
n_good += int(sel.diff_after_address_change)
|
||||||
|
# We'll re-ask diffs if needed below
|
||||||
|
# If selection exists but incomplete for this mode, proceed
|
||||||
|
|
||||||
pr_url = f"https://github.com/{entry.metadata.repo}/pull/{entry.metadata.pr_number}"
|
# Header info
|
||||||
|
print("#" * os.get_terminal_size().columns)
|
||||||
print(f"# good PRs: {n_good}/{i} ({n_good/i:.2%})")
|
print(f"# good PRs: {n_good}/{i} ({n_good/i:.2%})")
|
||||||
print(f"Current PR: {i}/{len(entries_to_process)} ({i/len(entries_to_process):.2%})")
|
print(f"Current PR: {i}/{len(entries_to_process)} ({i/len(entries_to_process):.2%})")
|
||||||
|
pr_url = f"https://github.com/{entry.metadata.repo}/pull/{entry.metadata.pr_number}"
|
||||||
print(f"\nPull Request : {pr_url}")
|
print(f"\nPull Request : {pr_url}")
|
||||||
|
|
||||||
for comment in entry.comments:
|
for comment in entry.comments:
|
||||||
print("\nComment:", comment.body)
|
print("\nComment:", comment.body)
|
||||||
change = prompt_yes_no("Does this comment suggest a change?")
|
|
||||||
|
|
||||||
if not change:
|
# Comment suggestion check
|
||||||
|
if not overwrite and sel is not None and sel.comment_suggests_change is not None:
|
||||||
|
suggests = sel.comment_suggests_change
|
||||||
|
else:
|
||||||
|
suggests = prompt_yes_no("Does this comment suggest a change?")
|
||||||
|
|
||||||
|
if not suggests:
|
||||||
|
print("Doesn't suggest any change, skipping...")
|
||||||
entry.metadata.selection = Selection(
|
entry.metadata.selection = Selection(
|
||||||
comment_suggests_change=False,
|
comment_suggests_change=False,
|
||||||
diff_after_address_change=None,
|
diff_after_address_change=None,
|
||||||
@ -83,7 +140,6 @@ def main(
|
|||||||
break
|
break
|
||||||
|
|
||||||
if validation_mode == ValidationMode.COMMENT:
|
if validation_mode == ValidationMode.COMMENT:
|
||||||
# In comment validation mode, we only check if the comment suggests a change
|
|
||||||
entry.metadata.selection = Selection(
|
entry.metadata.selection = Selection(
|
||||||
comment_suggests_change=True,
|
comment_suggests_change=True,
|
||||||
diff_after_address_change=None,
|
diff_after_address_change=None,
|
||||||
@ -91,24 +147,86 @@ def main(
|
|||||||
)
|
)
|
||||||
n_good += 1
|
n_good += 1
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# REFINEMENT mode: show all diffs first
|
||||||
|
|
||||||
|
# Initial relevance query
|
||||||
|
if not overwrite and sel is not None and sel.diff_after_address_change is not None:
|
||||||
|
any_relevant = sel.diff_after_address_change
|
||||||
else:
|
else:
|
||||||
# In refinement validation mode, we also check if the diff implements the change
|
print("Diffs:")
|
||||||
for file, diff in entry.diffs_after.items():
|
for f, diff in entry.diffs_after.items():
|
||||||
if diff is None:
|
if diff is None:
|
||||||
print(file, "EMPTY DIFF")
|
print(f, "EMPTY DIFF")
|
||||||
continue
|
continue
|
||||||
print(file, pretty_diff(diff))
|
print(f"--- {f} ---")
|
||||||
|
print(pretty_diff(diff))
|
||||||
applied = prompt_yes_no("Does this diff implement the change suggested?")
|
any_relevant = prompt_yes_no("Are any of these diffs related to the comment?")
|
||||||
|
|
||||||
|
if not any_relevant:
|
||||||
|
print("No diffs relevant, skipping...")
|
||||||
entry.metadata.selection = Selection(
|
entry.metadata.selection = Selection(
|
||||||
comment_suggests_change=True,
|
comment_suggests_change=True,
|
||||||
diff_after_address_change=applied,
|
diff_after_address_change=False,
|
||||||
good=applied,
|
good=False,
|
||||||
)
|
)
|
||||||
if applied:
|
break
|
||||||
n_good += 1
|
|
||||||
|
|
||||||
|
# Ask which diffs if detailed relevance requested
|
||||||
|
relevant_diffs = {}
|
||||||
|
if check_diff_relevance:
|
||||||
|
for f, diff in entry.diffs_after.items():
|
||||||
|
if diff is None:
|
||||||
|
continue
|
||||||
|
hunks = split_into_hunks(diff)
|
||||||
|
if not hunks:
|
||||||
|
continue
|
||||||
|
|
||||||
|
print(f"\n--- {f} has {len(hunks)} hunks ---")
|
||||||
|
|
||||||
|
selected_hunks: list[str] = []
|
||||||
|
for idx, hunk in enumerate(hunks, 1):
|
||||||
|
print(f"\nHunk #{idx}:")
|
||||||
|
print(pretty_diff(hunk))
|
||||||
|
print(f"Comment: {comment.body}")
|
||||||
|
if prompt_yes_no(f" → Is hunk #{idx} related to the comment?"):
|
||||||
|
if prompt_yes_no(
|
||||||
|
f" → Do you want to edit this hunk?", default=False
|
||||||
|
):
|
||||||
|
new_hunk = edit_hunk(hunk)
|
||||||
|
selected_hunks.append(new_hunk)
|
||||||
|
else:
|
||||||
|
selected_hunks.append(hunk)
|
||||||
|
|
||||||
|
if len(selected_hunks) > 0:
|
||||||
|
# join back into one diff string for storage
|
||||||
|
relevant_diffs[f] = "\n".join(selected_hunks)
|
||||||
|
|
||||||
|
if len(relevant_diffs) == 0:
|
||||||
|
print("No relevant diffs found, skipping...")
|
||||||
|
entry.metadata.selection = Selection(
|
||||||
|
comment_suggests_change=True,
|
||||||
|
diff_after_address_change=False,
|
||||||
|
good=False,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
print("\nRelevant diffs:")
|
||||||
|
for f, d in relevant_diffs.items():
|
||||||
|
print(f"--- {f} ---")
|
||||||
|
print(pretty_diff(d))
|
||||||
|
else:
|
||||||
|
relevant_diffs = entry.diffs_after
|
||||||
|
|
||||||
|
entry.diffs_after = relevant_diffs
|
||||||
|
entry.metadata.selection = Selection(
|
||||||
|
comment_suggests_change=True,
|
||||||
|
diff_after_address_change=True,
|
||||||
|
good=True,
|
||||||
|
)
|
||||||
|
if len(relevant_diffs) > 0:
|
||||||
|
n_good += 1
|
||||||
|
break
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nInterrupted.")
|
print("\nInterrupted.")
|
||||||
finally:
|
finally:
|
||||||
@ -126,10 +244,19 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-m",
|
"-m",
|
||||||
"--mode",
|
"--mode",
|
||||||
# type=lambda x: ValidationMode(x),
|
choices=[mode.value for mode in ValidationMode],
|
||||||
# choices=[mode.value for mode in ValidationMode],
|
|
||||||
default='comment',
|
default='comment',
|
||||||
help="Validation mode: 'comment' to only check if comments suggest changes, 'refinement' to check both comment suggestions and implementation. Default is 'comment'",
|
help="Validation mode: 'comment' to only check if comments suggest changes, 'refinement' to check both comment suggestions and implementation. Default is 'comment'",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--check-diff-relevance",
|
||||||
|
action="store_true",
|
||||||
|
help="Check if each diff is related to the comment before asking if it implements the change",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args.dataset, overwrite=args.overwrite, validation_mode=args.mode)
|
main(
|
||||||
|
args.dataset,
|
||||||
|
overwrite=args.overwrite,
|
||||||
|
validation_mode=ValidationMode(args.mode),
|
||||||
|
check_diff_relevance=args.check_diff_relevance,
|
||||||
|
)
|
||||||
|
Reference in New Issue
Block a user