diff --git a/dataset.py b/dataset.py index 8496b8b..efffe09 100644 --- a/dataset.py +++ b/dataset.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional -import json, argparse, os +import json, argparse, os, uuid from utils import prompt_yes_no # fmt: off @@ -27,6 +27,7 @@ class Selection: @dataclass class Metadata: + id: str repo: str # the name of the repo, with style XXX/YYY pr_number: int pr_title: str @@ -49,14 +50,14 @@ class DatasetEntry: @dataclass class CommentGenEntry: - id: int + id: str files: Dict[str, str] # filename -> file content diffs: Dict[str, str] # filename -> diff, diffs between the opening of the PR and the comment @staticmethod - def from_entry(entry: DatasetEntry, id: int) -> "CommentGenEntry": + def from_entry(entry: DatasetEntry) -> "CommentGenEntry": return CommentGenEntry( - id=id, + id=entry.metadata.id, files={fname: fdata.content_before_pr for fname, fdata in entry.files.items()}, diffs=entry.diffs_before, ) @@ -64,15 +65,15 @@ class CommentGenEntry: @dataclass class CodeRefinementEntry: - id: int + id: str files: Dict[str, str] # filename -> file content diffs: Dict[str, str] # filename -> diff, diffs between the opening of the PR and the comment comments: List[Comment] @staticmethod - def from_entry(entry: DatasetEntry, id: int) -> "CodeRefinementEntry": + def from_entry(entry: DatasetEntry) -> "CodeRefinementEntry": return CodeRefinementEntry( - id=id, + id=entry.metadata.id, files={fname: fdata.content_before_pr for fname, fdata in entry.files.items()}, diffs=entry.diffs_before, comments=entry.comments, @@ -109,7 +110,6 @@ class Dataset: if entry.metadata.selection and entry.metadata.selection.diff_after_address_change ] - entry_counter = -1 to_dump = Dataset(entries=entries_to_dump) def transform_entry(entry: DatasetEntry | Dataset | Any) -> dict | list: @@ -122,13 +122,11 @@ class Dataset: if isinstance(entry, Dataset): return entry.entries - nonlocal entry_counter - entry_counter += 1 if type_ == OutputType.COMMENT_GEN: - return CommentGenEntry.from_entry(entry, entry_counter).__dict__ + return CommentGenEntry.from_entry(entry).__dict__ if type_ == OutputType.CODE_REFINEMENT: - return CodeRefinementEntry.from_entry(entry, entry_counter).__dict__ + return CodeRefinementEntry.from_entry(entry).__dict__ with open(filename, "w", encoding="utf-8") as f: json.dump(to_dump, f, default=transform_entry, indent=4) @@ -146,6 +144,8 @@ class Dataset: selection_data = metadata_data["selection"] if "selection" in metadata_data else None selection = Selection(**selection_data) if selection_data else None metadata_data["selection"] = selection + if "id" not in metadata_data: + metadata_data["id"] = uuid.uuid4().hex metadata = Metadata(**metadata_data) if (