mirror of
https://github.com/karma-riuk/crab.git
synced 2025-07-04 21:28:12 +02:00
254 lines
8.5 KiB
Python
254 lines
8.5 KiB
Python
from typing import Optional
|
|
from dataset import Dataset, DatasetEntry, Selection
|
|
import argparse, os, re, click
|
|
from enum import Enum
|
|
from utils import EnumChoicesAction, prompt_yes_no
|
|
|
|
# Matches the start of a unified-diff hunk header such as "@@ -12,5 +12,7 @@".
# The ",count" parts are optional, and because the pattern is not anchored at
# the end, any trailing section text after the closing "@@" is accepted.
HUNK_HEADER_REGEX = re.compile(r"^@@ -\d+(?:,\d+)? \+\d+(?:,\d+)? @@")
|
|
|
|
|
|
class ValidationMode(Enum):
    """Which labelling pass the script runs over the dataset."""

    # Only ask whether each review comment suggests a change.
    COMMENT = "comment"
    # Also ask whether the post-review diff implements the suggested change.
    REFINEMENT = "refinement"
|
|
|
|
|
|
def green(line: str) -> str:
    """Wrap *line* in the ANSI escape codes for green foreground text."""
    return "\x1b[32m{}\x1b[0m".format(line)
|
|
|
|
|
|
def red(line: str) -> str:
    """Wrap *line* in the ANSI escape codes for red foreground text."""
    return "\x1b[31m{}\x1b[0m".format(line)
|
|
|
|
|
|
def bold(line: str) -> str:
    """Wrap *line* in the ANSI escape codes for bold text."""
    return "\x1b[1m{}\x1b[0m".format(line)
|
|
|
|
|
|
def pretty_diff(after: str) -> str:
    """Return *after* with diff lines colourised for terminal display.

    Lines starting with "+" are green and lines starting with "-" are red.
    Lines starting with "@@" (hunk headers) are bold, and all other lines are
    unchanged. Lines are re-joined with "\\n", so a trailing newline in the
    input is dropped.
    """

    def _decorate(row: str) -> str:
        # Checked in this order: "+" first, then "-", then "@@".
        if row.startswith("+"):
            return green(row)
        if row.startswith("-"):
            return red(row)
        if row.startswith("@@"):
            return bold(row)
        return row

    return "\n".join(map(_decorate, after.splitlines()))
|
|
|
|
|
|
def split_into_hunks(diff: str) -> list[str]:
    """Split a unified diff into its hunks.

    Each returned chunk starts at a line beginning with "@@ " (the hunk header)
    and runs up to, but not including, the next such header or the end of the
    string. Any text before the first header (e.g. "---"/"+++" file lines) is
    discarded. An empty diff yields an empty list.
    """
    if not diff:
        return []

    header_pattern = re.compile(r'(?m)(^@@ .*@@)')
    # Every hunk begins where a header match begins; it ends where the next
    # hunk begins, or at the end of the diff for the last one.
    starts = [match.start() for match in header_pattern.finditer(diff)]
    stops = starts[1:] + [len(diff)]
    return [diff[begin:end] for begin, end in zip(starts, stops)]
|
|
|
|
|
|
def edit_hunk(hunk: str) -> str:
    """Let the user edit *hunk* in their editor and return the result.

    The edit is accepted only if its first line is still a valid hunk header
    (as matched by ``HUNK_HEADER_REGEX``). On an invalid header the user may
    retry; each retry reopens the editor on the original *hunk*. The original
    *hunk* is returned if the editor returns ``None`` or the user declines to
    retry.
    """
    while True:
        edited = click.edit(hunk)
        if edited is None:
            print("Edit aborted, keeping original hunk")
            return hunk

        # An empty edit has no first line; "" never matches the header regex.
        first_line = next(iter(edited.splitlines()), "")
        if HUNK_HEADER_REGEX.match(first_line):
            return edited

        print(red("Invalid hunk header! Hunk header must not be modified."))
        if not prompt_yes_no("Edit again?", default=False):
            print("Keeping original hunk")
            return hunk
|
|
|
|
|
|
def display_header(i: int, total: int, n_good: int):
    """Print a progress banner for the PR currently being reviewed.

    Args:
        i: 1-based position of the current PR among those being processed.
        total: Number of PRs being processed.
        n_good: Number of PRs judged "good" so far.
    """
    import shutil

    # shutil.get_terminal_size() falls back to a default size instead of
    # raising OSError when stdout is not a terminal (piped/captured output),
    # which os.get_terminal_size() does not.
    cols = shutil.get_terminal_size().columns
    print("#" * cols)
    # Guard both ratios so a zero denominator prints 0% instead of raising.
    good_ratio = n_good / i if i else 0.0
    progress = i / total if total else 0.0
    print(f"# good PRs: {n_good}/{i} ({good_ratio:.2%})")
    print(f"Current PR: {i}/{total} ({progress:.2%})")
|
|
|
|
|
|
def display_pr_info(entry: DatasetEntry, i: int, total: int, n_good: int):
    """Print the progress banner followed by the GitHub URL of *entry*'s PR."""
    display_header(i, total, n_good)
    meta = entry.metadata
    url = "https://github.com/{}/pull/{}".format(meta.repo, meta.pr_number)
    print(f"\nPull Request : {url}\n")
|
|
|
|
|
|
def prompt_comment_suggestion(
    entry: DatasetEntry, sel: Optional[Selection], overwrite: bool
) -> bool:
    """Return whether the entry's review comment suggests a code change.

    When not overwriting, a previously stored answer (``sel.comment_suggests_change``)
    is returned without prompting. Otherwise every comment body is printed
    and the user is asked.
    """
    if not overwrite and sel is not None:
        stored = sel.comment_suggests_change
        if stored is not None:
            return stored

    for comment in entry.comments:
        print(f"\nComment: {comment.body}")
    return prompt_yes_no("Do the comment suggest a change?")
|
|
|
|
|
|
def show_diffs(entry):
    """Print every file diff in ``entry.diffs_after``, colourised.

    Files whose diff is falsy (empty or missing) are shown as "EMPTY DIFF".
    """
    print("Diffs:")
    for filename, file_diff in entry.diffs_after.items():
        print(f"--- {filename} ---")
        text = "EMPTY DIFF"
        if file_diff:
            text = pretty_diff(file_diff)
        print(text)
|
|
|
|
|
|
def ask_diff_relevance(entry: DatasetEntry, sel: Optional[Selection], overwrite: bool) -> bool:
    """Return whether any of the entry's post-review diffs relate to the comment.

    A stored answer (``sel.diff_after_address_change``) is reused unless
    *overwrite* is set; otherwise the diffs are shown and the user is asked.
    """
    stored = None if (overwrite or sel is None) else sel.diff_after_address_change
    if stored is not None:
        return stored

    show_diffs(entry)
    return prompt_yes_no(f"Are {bold('any')} of these diffs related to the comment?")
|
|
|
|
|
|
def select_relevant_hunks(diff: str, comment: str) -> list[str]:
    """Walk the user through each hunk of *diff* and collect the related ones.

    Each hunk is printed next to *comment*. For a hunk the user marks as
    related, they may also edit it before it is kept. Returns the kept hunks
    in their original order.
    """
    chosen: list[str] = []
    for number, hunk in enumerate(split_into_hunks(diff), start=1):
        print(f"\nHunk #{number}:")
        print(pretty_diff(hunk))
        print(f"Comment: {comment}")
        if not prompt_yes_no(f"Is hunk #{number} related?", default=False):
            continue
        if prompt_yes_no("Edit this hunk?", default=False):
            hunk = edit_hunk(hunk)
        chosen.append(hunk)
    return chosen
|
|
|
|
|
|
def _join_hunks(hunks: list[str]) -> str:
    """Concatenate selected hunks back into one unified-diff string.

    Hunks from ``split_into_hunks`` normally already end with a newline, so a
    newline separator would inject blank lines between hunks and corrupt the
    diff. A newline is therefore inserted only after a hunk that lacks one.
    """
    if not hunks:
        return ""
    pieces = [h if h.endswith("\n") else h + "\n" for h in hunks[:-1]]
    pieces.append(hunks[-1])
    return "".join(pieces)


def refine_entry(
    entry: DatasetEntry, sel: Optional[Selection], overwrite: bool, check_diff: bool
) -> bool:
    """Decide whether the entry's post-review diff addresses its comment.

    Args:
        entry: The dataset entry under review.
        sel: Its previous selection, if any (used to reuse stored answers).
        overwrite: Re-ask questions even if answers are stored.
        check_diff: When True, also let the user pick the related hunks of each
            file. ``entry.diffs_after`` is replaced by only those hunks.

    Returns:
        True if the diff is judged to address the comment, else False. With
        *check_diff*, a diff where the user selects no hunk at all counts as
        not addressing the comment.
    """
    diff_relevant = ask_diff_relevance(entry, sel, overwrite)
    if not diff_relevant:
        return False

    if check_diff:
        # Avoid IndexError on entries without comments.
        comment = entry.comments[0].body if entry.comments else ""
        accumulated = {}
        for fname, diff in entry.diffs_after.items():
            if not diff:
                continue
            hunks = select_relevant_hunks(diff, comment)
            if hunks:
                accumulated[fname] = _join_hunks(hunks)
        entry.diffs_after = accumulated
        # Nothing selected means no part of the diff relates to the comment.
        # (The previous check was inverted and rejected entries WITH hunks.)
        if not accumulated:
            return False

    return True
|
|
|
|
|
|
def _previous_outcome(
    sel: Optional[Selection], validation_mode: ValidationMode
) -> Optional[bool]:
    """Return the stored verdict if *sel* is already complete for this mode.

    Returns None when the entry still needs (part of) an interactive review.
    """
    if sel is None:
        return None
    if validation_mode == ValidationMode.COMMENT:
        return sel.comment_suggests_change
    # REFINEMENT: a comment that suggests no change is already a final "no".
    if sel.comment_suggests_change is not None and not sel.comment_suggests_change:
        return False
    if sel.comment_suggests_change and sel.diff_after_address_change is not None:
        return sel.diff_after_address_change
    # Selection exists but is incomplete for this mode: it must be reviewed.
    return None


def main(
    dataset_path: str,
    output: str,
    overwrite: bool = False,
    validation_mode: ValidationMode = ValidationMode.REFINEMENT,
    check_diff_relevance: bool = False,
):
    """Interactively label dataset entries and write the dataset to *output*.

    Args:
        dataset_path: Path of the input dataset JSON.
        output: Path the (partially) labelled dataset is written to.
        overwrite: Re-review entries that already have a complete selection.
        validation_mode: COMMENT labels every entry's comment. REFINEMENT labels
            only entries whose ``metadata.successful`` is set, and also asks
            whether the diff addresses the comment.
        check_diff_relevance: In REFINEMENT mode, also pick related hunks.

    The dataset is saved even if the session is interrupted with Ctrl-C.
    """
    dataset = Dataset.from_json(dataset_path)

    if validation_mode == ValidationMode.COMMENT:
        # For comment validation, process all entries
        entries_to_process = dataset.entries
        print("Running in COMMENT VALIDATION mode - only checking if comments suggest changes")
    else:
        # For refinement validation, only process successful entries
        entries_to_process = [entry for entry in dataset.entries if entry.metadata.successful]
        print(
            "Running in REFINEMENT VALIDATION mode - checking both comment suggestions and implementation"
        )

    total = len(entries_to_process)
    try:
        n_good = 0
        for i, entry in enumerate(entries_to_process, 1):
            sel = entry.metadata.selection

            # Count already-complete entries once and skip them. Previously
            # they were counted here and then re-processed (reusing stored
            # answers) and counted a second time below.
            if not overwrite:
                previous = _previous_outcome(sel, validation_mode)
                if previous is not None:
                    n_good += int(previous)
                    continue

            display_pr_info(entry, i, total, n_good)

            suggests = prompt_comment_suggestion(entry, sel, overwrite)
            if not suggests:
                entry.metadata.selection = Selection(False, None)
                continue

            if validation_mode == ValidationMode.COMMENT:
                # Keep any diff verdict recorded by an earlier refinement pass.
                entry.metadata.selection = Selection(
                    True,
                    sel.diff_after_address_change if sel is not None else None,
                )
                n_good += 1
            elif validation_mode == ValidationMode.REFINEMENT:
                diff_relevant = refine_entry(entry, sel, overwrite, check_diff_relevance)
                entry.metadata.selection = Selection(True, diff_relevant)
                if diff_relevant:
                    n_good += 1
    except KeyboardInterrupt:
        print("\nInterrupted.")
    finally:
        # Always persist progress, including after an interrupt.
        print(f"Saving dataset to {output}...", end=" ", flush=True)
        dataset.to_json(output)
        print("Done")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and hand them to main().
    cli = argparse.ArgumentParser(description="Manual selection of dataset")
    cli.add_argument("dataset", type=str, help="Path to the dataset file")
    cli.add_argument(
        "-o",
        "--output",
        required=True,
        type=str,
        help="The path to the resulting dataset",
    )
    cli.add_argument(
        "--overwrite",
        action="store_true",
        help="Re-evaluate existing selections",
    )
    mode_help = (
        f"Validation mode: '{ValidationMode.COMMENT.value}' to only check if comments suggest changes, "
        f"'{ValidationMode.REFINEMENT.value}' to check both comment suggestions and implementation. "
        f"Default is '{ValidationMode.COMMENT.value}'"
    )
    cli.add_argument(
        "-m",
        "--mode",
        type=ValidationMode,
        default=ValidationMode.COMMENT,
        action=EnumChoicesAction,
        help=mode_help,
    )
    cli.add_argument(
        "--check-diff-relevance",
        action="store_true",
        help="Check if each diff is related to the comment before asking if it implements the change",
    )
    parsed = cli.parse_args()
    main(
        dataset_path=parsed.dataset,
        output=parsed.output,
        overwrite=parsed.overwrite,
        validation_mode=parsed.mode,
        check_diff_relevance=parsed.check_diff_relevance,
    )
|