diff --git a/dataset.py b/dataset.py
index bb55456..8496b8b 100644
--- a/dataset.py
+++ b/dataset.py
@@ -1,6 +1,10 @@
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional
-import json
+from enum import Enum
+from typing import Any, Dict, List, Optional
+import argparse
+import json
+import os
+from utils import prompt_yes_no
 
 # fmt: off
 @dataclass
@@ -45,6 +49,51 @@ class DatasetEntry:
     diffs_after: Dict[str, str]  # filename -> diff, changes after the comment
 
 
+@dataclass
+class CommentGenEntry:
+    """Reduced entry written by ``Dataset.to_json`` for ``OutputType.COMMENT_GEN``."""
+
+    id: int
+    files: Dict[str, str]  # filename -> file content
+    diffs: Dict[str, str]  # filename -> diff, diffs between the opening of the PR and the comment
+
+    @staticmethod
+    def from_entry(entry: DatasetEntry, id: int) -> "CommentGenEntry":
+        """Build the export record for ``entry``: pre-PR file contents and ``diffs_before``."""
+        return CommentGenEntry(
+            id=id,
+            files={fname: fdata.content_before_pr for fname, fdata in entry.files.items()},
+            diffs=entry.diffs_before,
+        )
+
+
+@dataclass
+class CodeRefinementEntry:
+    """Reduced entry written by ``Dataset.to_json`` for ``OutputType.CODE_REFINEMENT``."""
+
+    id: int
+    files: Dict[str, str]  # filename -> file content
+    diffs: Dict[str, str]  # filename -> diff, diffs between the opening of the PR and the comment
+    comments: List[Comment]
+
+    @staticmethod
+    def from_entry(entry: DatasetEntry, id: int) -> "CodeRefinementEntry":
+        """Build the export record for ``entry``: pre-PR file contents, ``diffs_before`` and comments."""
+        return CodeRefinementEntry(
+            id=id,
+            files={fname: fdata.content_before_pr for fname, fdata in entry.files.items()},
+            diffs=entry.diffs_before,
+            comments=entry.comments,
+        )
+
+
+class OutputType(Enum):
+    """Shapes in which ``Dataset.to_json`` can write the dataset."""
+
+    FULL = "full"
+    CODE_REFINEMENT = "code_refinement"
+    COMMENT_GEN = "comment_gen"
+
 # fmt: on
 @dataclass
 class Dataset:
@@ -53,10 +102,42 @@ class Dataset:
     def __len__(self) -> int:
         return sum(1 for entry in self.entries if entry.metadata.successful)
 
-    def to_json(self, filename: str):
-        """Serialize the dataset to a JSON file"""
+    def to_json(self, filename: str, type_: OutputType = OutputType.FULL) -> None:
+        """Serialize the dataset to a JSON file.
+
+        Args:
+            filename: Path of the JSON file to write (opened with mode "w", UTF-8).
+            type_: ``FULL`` dumps the whole dataset. ``COMMENT_GEN`` and
+                ``CODE_REFINEMENT`` dump a JSON list of ``CommentGenEntry`` /
+                ``CodeRefinementEntry`` records, keeping only the entries whose
+                ``metadata.selection`` has ``comment_suggests_change`` /
+                ``diff_after_address_change`` set; each record's ``id`` is its
+                0-based position in that filtered list.
+        """
+        # FULL (the default) dumps the dataset object itself.
+        to_dump: Any = self
+        if type_ == OutputType.COMMENT_GEN:
+            selected = [
+                entry
+                for entry in self.entries
+                if entry.metadata.selection and entry.metadata.selection.comment_suggests_change
+            ]
+            to_dump = [
+                CommentGenEntry.from_entry(entry, idx) for idx, entry in enumerate(selected)
+            ]
+        elif type_ == OutputType.CODE_REFINEMENT:
+            selected = [
+                entry
+                for entry in self.entries
+                if entry.metadata.selection and entry.metadata.selection.diff_after_address_change
+            ]
+            to_dump = [
+                CodeRefinementEntry.from_entry(entry, idx) for idx, entry in enumerate(selected)
+            ]
+
         with open(filename, "w", encoding="utf-8") as f:
-            json.dump(self, f, default=lambda o: o.__dict__, indent=4)
+            # Objects json cannot encode natively are serialized via their __dict__.
+            json.dump(to_dump, f, default=lambda o: o.__dict__, indent=4)
 
     @staticmethod
     def from_json(filename: str, keep_still_in_progress: bool = False) -> "Dataset":
@@ -93,3 +174,42 @@ class Dataset:
             entries.append(entry)
 
         return Dataset(entries=entries)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Dataset class")
+    parser.add_argument(
+        "-f",
+        "--filename",
+        type=str,
+        required=True,
+        help="Path to the JSON file to load",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        default="output.json",
+        help="Path to the output JSON file",
+    )
+    parser.add_argument(
+        "-t",
+        "--output_type",
+        choices=[mode.value for mode in OutputType],
+        default=OutputType.FULL.value,
+        help="Type of output to generate",
+    )
+    args = parser.parse_args()
+
+    dataset = Dataset.from_json(args.filename)
+    print(f"Loaded {len(dataset)} entries from {args.filename}")
+    if os.path.exists(args.output):
+        overwrite = prompt_yes_no(
+            f"Output file {args.output} already exists. Do you want to overwrite it?"
+        )
+        if not overwrite:
+            print("Exiting without saving.")
+            raise SystemExit(0)
+    print(f"Saving dataset to {args.output}...", end=" ", flush=True)
+    dataset.to_json(args.output, OutputType(args.output_type))
+    print("Done")
diff --git a/utils.py b/utils.py
index f9ef6e3..8760d6b 100644
--- a/utils.py
+++ b/utils.py
@@ -164,3 +164,22 @@ def run_git_cmd(cmd: list[str], repo_path: str) -> subprocess.CompletedProcess:
         capture_output=True,
         text=True,
     )
+
+
+def prompt_yes_no(prompt: str) -> bool:
+    """Ask a yes/no question on stdin, re-prompting until the answer is valid.
+
+    Returns True for "y"/"yes" and False for "n"/"no" (case-insensitive,
+    surrounding whitespace ignored). If stdin reaches EOF the answer is
+    treated as "no" instead of letting EOFError escape.
+    """
+    while True:
+        try:
+            ans = input(f"{prompt} [y/n]: ").strip().lower()
+        except EOFError:
+            return False
+        if ans in {"y", "yes"}:
+            return True
+        if ans in {"n", "no"}:
+            return False
+        print("Please enter 'y' or 'n'.")