we can now generate the datasets to be served to users

This commit is contained in:
Karma Riuk
2025-04-29 14:41:11 +02:00
parent bde9d45c10
commit b3877733cb
2 changed files with 129 additions and 4 deletions

View File

@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional
-import json
+from enum import Enum
+from typing import Any, Dict, List, Optional
+import json, argparse, os
+from utils import prompt_yes_no
 
 # fmt: off
 @dataclass
@@ -45,6 +47,42 @@ class DatasetEntry:
     diffs_after: Dict[str, str]  # filename -> diff, changes after the comment
 
+
+@dataclass
+class CommentGenEntry:
+    id: int
+    files: Dict[str, str]  # filename -> file content
+    diffs: Dict[str, str]  # filename -> diff, diffs between the opening of the PR and the comment
+
+    @staticmethod
+    def from_entry(entry: DatasetEntry, id: int) -> "CommentGenEntry":
+        return CommentGenEntry(
+            id=id,
+            files={fname: fdata.content_before_pr for fname, fdata in entry.files.items()},
+            diffs=entry.diffs_before,
+        )
+
+
+@dataclass
+class CodeRefinementEntry:
+    id: int
+    files: Dict[str, str]  # filename -> file content
+    diffs: Dict[str, str]  # filename -> diff, diffs between the opening of the PR and the comment
+    comments: List[Comment]
+
+    @staticmethod
+    def from_entry(entry: DatasetEntry, id: int) -> "CodeRefinementEntry":
+        return CodeRefinementEntry(
+            id=id,
+            files={fname: fdata.content_before_pr for fname, fdata in entry.files.items()},
+            diffs=entry.diffs_before,
+            comments=entry.comments,
+        )
+
+
+class OutputType(Enum):
+    FULL = "full"
+    CODE_REFINEMENT = "code_refinement"
+    COMMENT_GEN = "comment_gen"
+
 # fmt: on
 @dataclass
 class Dataset:
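Note on the two new entry types (commentary, not part of the diff): CommentGenEntry keeps only the file contents as they were when the PR was opened plus the diffs up to the comment, while CodeRefinementEntry additionally keeps the reviewer comments that are meant to be addressed. A minimal usage sketch, assuming a Dataset has already been loaded and its entries expose content_before_pr as the diff above suggests:

# Hedged sketch, not project code: build the lighter task-specific views from a full entry.
# Assumes: dataset = Dataset.from_json("dataset.json")
full_entry = dataset.entries[0]                        # a full DatasetEntry
cg = CommentGenEntry.from_entry(full_entry, id=0)      # files + diffs_before only
cr = CodeRefinementEntry.from_entry(full_entry, id=0)  # same, plus the review comments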
@@ -53,10 +91,47 @@ class Dataset:
     def __len__(self) -> int:
         return sum(1 for entry in self.entries if entry.metadata.successful)
 
-    def to_json(self, filename: str):
+    def to_json(self, filename: str, type_: OutputType = OutputType.FULL) -> None:
         """Serialize the dataset to a JSON file"""
+        entries_to_dump = self.entries
+        if type_ == OutputType.COMMENT_GEN:
+            entries_to_dump = [
+                entry
+                for entry in self.entries
+                if entry.metadata.selection and entry.metadata.selection.comment_suggests_change
+            ]
+        elif type_ == OutputType.CODE_REFINEMENT:
+            entries_to_dump = [
+                entry
+                for entry in self.entries
+                if entry.metadata.selection and entry.metadata.selection.diff_after_address_change
+            ]
+
+        entry_counter = -1
+        to_dump = Dataset(entries=entries_to_dump)
+
+        def transform_entry(entry: DatasetEntry | Dataset | Any) -> dict | list:
+            if not isinstance(entry, (DatasetEntry, Dataset)):
+                return entry.__dict__
+
+            if type_ == OutputType.FULL:
+                return entry.__dict__
+
+            if isinstance(entry, Dataset):
+                return entry.entries
+
+            nonlocal entry_counter
+            entry_counter += 1
+            if type_ == OutputType.COMMENT_GEN:
+                return CommentGenEntry.from_entry(entry, entry_counter).__dict__
+            if type_ == OutputType.CODE_REFINEMENT:
+                return CodeRefinementEntry.from_entry(entry, entry_counter).__dict__
+
         with open(filename, "w", encoding="utf-8") as f:
-            json.dump(self, f, default=lambda o: o.__dict__, indent=4)
+            json.dump(to_dump, f, default=transform_entry, indent=4)
 
     @staticmethod
     def from_json(filename: str, keep_still_in_progress: bool = False) -> "Dataset":
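The serialization here leans on json.dump's default hook: it is called for every object the encoder cannot handle natively, so the wrapping Dataset collapses to its entry list, each DatasetEntry becomes the task-specific dict, and anything else falls back to __dict__. A self-contained sketch of the same mechanism (illustrative only, unrelated class name):

import json
from dataclasses import dataclass

@dataclass
class Point:
    x: int
    y: int

# `default` is only invoked for objects json cannot serialize on its own,
# so nested objects are converted lazily as the encoder walks the structure.
print(json.dumps({"origin": Point(0, 0)}, default=lambda o: o.__dict__))
# -> {"origin": {"x": 0, "y": 0}}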
@@ -93,3 +168,42 @@ class Dataset:
             entries.append(entry)
 
         return Dataset(entries=entries)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Dataset class")
+    parser.add_argument(
+        "-f",
+        "--filename",
+        type=str,
+        required=True,
+        help="Path to the JSON file to load",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        default="output.json",
+        help="Path to the output JSON file",
+    )
+    parser.add_argument(
+        "-t",
+        "--output_type",
+        choices=[mode.value for mode in OutputType],
+        default=OutputType.FULL.value,
+        help="Type of output to generate",
+    )
+
+    args = parser.parse_args()
+
+    dataset = Dataset.from_json(args.filename)
+    print(f"Loaded {len(dataset)} entries from {args.filename}")
+    if os.path.exists(args.output):
+        overwrite = prompt_yes_no(
+            f"Output file {args.output} already exists. Do you want to overwrite it?"
+        )
+        if not overwrite:
+            print("Exiting without saving.")
+            exit(0)
+
+    print(f"Saving dataset to {args.output}...", end=" ", flush=True)
+    dataset.to_json(args.output, OutputType(args.output_type))
+    print("Done")

View File

@@ -164,3 +164,14 @@ def run_git_cmd(cmd: list[str], repo_path: str) -> subprocess.CompletedProcess:
         capture_output=True,
         text=True,
     )
+
+
+def prompt_yes_no(prompt: str) -> bool:
+    while True:
+        ans = input(f"{prompt} [y/n]: ").strip().lower()
+        if ans in {"y", "yes"}:
+            return True
+        elif ans in {"n", "no"}:
+            return False
+        else:
+            print("Please enter 'y' or 'n'.")