mirror of
https://github.com/karma-riuk/crab.git
synced 2025-07-04 21:28:12 +02:00
we can now generate the datasets to be served to users
This commit is contained in:
122
dataset.py
122
dataset.py
@ -1,6 +1,8 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional
|
from enum import Enum
|
||||||
import json
|
from typing import Any, Dict, List, Optional
|
||||||
|
import json, argparse, os
|
||||||
|
from utils import prompt_yes_no
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -45,6 +47,42 @@ class DatasetEntry:
|
|||||||
diffs_after: Dict[str, str] # filename -> diff, changes after the comment
|
diffs_after: Dict[str, str] # filename -> diff, changes after the comment
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CommentGenEntry:
|
||||||
|
id: int
|
||||||
|
files: Dict[str, str] # filename -> file content
|
||||||
|
diffs: Dict[str, str] # filename -> diff, diffs between the opening of the PR and the comment
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_entry(entry: DatasetEntry, id: int) -> "CommentGenEntry":
|
||||||
|
return CommentGenEntry(
|
||||||
|
id=id,
|
||||||
|
files={fname: fdata.content_before_pr for fname, fdata in entry.files.items()},
|
||||||
|
diffs=entry.diffs_before,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CodeRefinementEntry:
|
||||||
|
id: int
|
||||||
|
files: Dict[str, str] # filename -> file content
|
||||||
|
diffs: Dict[str, str] # filename -> diff, diffs between the opening of the PR and the comment
|
||||||
|
comments: List[Comment]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_entry(entry: DatasetEntry, id: int) -> "CodeRefinementEntry":
|
||||||
|
return CodeRefinementEntry(
|
||||||
|
id=id,
|
||||||
|
files={fname: fdata.content_before_pr for fname, fdata in entry.files.items()},
|
||||||
|
diffs=entry.diffs_before,
|
||||||
|
comments=entry.comments,
|
||||||
|
)
|
||||||
|
|
||||||
|
class OutputType(Enum):
|
||||||
|
FULL = "full"
|
||||||
|
CODE_REFINEMENT = "code_refinement"
|
||||||
|
COMMENT_GEN = "comment_gen"
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
@dataclass
|
@dataclass
|
||||||
class Dataset:
|
class Dataset:
|
||||||
@ -53,10 +91,47 @@ class Dataset:
|
|||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return sum(1 for entry in self.entries if entry.metadata.successful)
|
return sum(1 for entry in self.entries if entry.metadata.successful)
|
||||||
|
|
||||||
def to_json(self, filename: str):
|
def to_json(self, filename: str, type_: OutputType = OutputType.FULL) -> None:
|
||||||
"""Serialize the dataset to a JSON file"""
|
"""Serialize the dataset to a JSON file"""
|
||||||
|
|
||||||
|
entries_to_dump = self.entries
|
||||||
|
|
||||||
|
if type_ == OutputType.COMMENT_GEN:
|
||||||
|
entries_to_dump = [
|
||||||
|
entry
|
||||||
|
for entry in self.entries
|
||||||
|
if entry.metadata.selection and entry.metadata.selection.comment_suggests_change
|
||||||
|
]
|
||||||
|
elif type_ == OutputType.CODE_REFINEMENT:
|
||||||
|
entries_to_dump = [
|
||||||
|
entry
|
||||||
|
for entry in self.entries
|
||||||
|
if entry.metadata.selection and entry.metadata.selection.diff_after_address_change
|
||||||
|
]
|
||||||
|
|
||||||
|
entry_counter = -1
|
||||||
|
to_dump = Dataset(entries=entries_to_dump)
|
||||||
|
|
||||||
|
def transform_entry(entry: DatasetEntry | Dataset | Any) -> dict | list:
|
||||||
|
if not isinstance(entry, (DatasetEntry, Dataset)):
|
||||||
|
return entry.__dict__
|
||||||
|
|
||||||
|
if type_ == OutputType.FULL:
|
||||||
|
return entry.__dict__
|
||||||
|
|
||||||
|
if isinstance(entry, Dataset):
|
||||||
|
return entry.entries
|
||||||
|
|
||||||
|
nonlocal entry_counter
|
||||||
|
entry_counter += 1
|
||||||
|
if type_ == OutputType.COMMENT_GEN:
|
||||||
|
return CommentGenEntry.from_entry(entry, entry_counter).__dict__
|
||||||
|
|
||||||
|
if type_ == OutputType.CODE_REFINEMENT:
|
||||||
|
return CodeRefinementEntry.from_entry(entry, entry_counter).__dict__
|
||||||
|
|
||||||
with open(filename, "w", encoding="utf-8") as f:
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
json.dump(self, f, default=lambda o: o.__dict__, indent=4)
|
json.dump(to_dump, f, default=transform_entry, indent=4)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_json(filename: str, keep_still_in_progress: bool = False) -> "Dataset":
|
def from_json(filename: str, keep_still_in_progress: bool = False) -> "Dataset":
|
||||||
@ -93,3 +168,42 @@ class Dataset:
|
|||||||
entries.append(entry)
|
entries.append(entry)
|
||||||
|
|
||||||
return Dataset(entries=entries)
|
return Dataset(entries=entries)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Dataset class")
|
||||||
|
parser.add_argument(
|
||||||
|
"-f",
|
||||||
|
"--filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the JSON file to load",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-o",
|
||||||
|
"--output",
|
||||||
|
type=str,
|
||||||
|
default="output.json",
|
||||||
|
help="Path to the output JSON file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-t",
|
||||||
|
"--output_type",
|
||||||
|
choices=[mode.value for mode in OutputType],
|
||||||
|
default=OutputType.FULL.value,
|
||||||
|
help="Type of output to generate",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
dataset = Dataset.from_json(args.filename)
|
||||||
|
print(f"Loaded {len(dataset)} entries from {args.filename}")
|
||||||
|
if os.path.exists(args.output):
|
||||||
|
overwrite = prompt_yes_no(
|
||||||
|
f"Output file {args.output} already exists. Do you want to overwrite it?"
|
||||||
|
)
|
||||||
|
if not overwrite:
|
||||||
|
print("Exiting without saving.")
|
||||||
|
exit(0)
|
||||||
|
print(f"Saving dataset to {args.output}...", end=" ", flush=True)
|
||||||
|
dataset.to_json(args.output, OutputType(args.output_type))
|
||||||
|
print("Done")
|
||||||
|
11
utils.py
11
utils.py
@ -164,3 +164,14 @@ def run_git_cmd(cmd: list[str], repo_path: str) -> subprocess.CompletedProcess:
|
|||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_yes_no(prompt: str) -> bool:
|
||||||
|
while True:
|
||||||
|
ans = input(f"{prompt} [y/n]: ").strip().lower()
|
||||||
|
if ans in {"y", "yes"}:
|
||||||
|
return True
|
||||||
|
elif ans in {"n", "no"}:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
print("Please enter 'y' or 'n'.")
|
||||||
|
Reference in New Issue
Block a user