added a way to extract the information to then

generate paraphrases
This commit is contained in:
Karma Riuk
2025-05-27 10:48:17 +02:00
parent 63b69e40b8
commit 900003bac7

View File

@ -1,18 +1,63 @@
from dataclasses import dataclass
from enum import Enum
import os, json, tarfile, argparse
from dataset import Dataset, ArchiveState, OutputType
from typing import Optional
from dataset import Dataset, ArchiveState
from utils import EnumChoicesAction
class OutputType(Enum):
CODE_REFINEMENT = "code_refinement"
COMMENT_GEN = "comment_gen"
FOR_PARAPHRASES = "paraphrases"
@dataclass
class CommentGenSubmission:
path: str
line_from: int
line_to: Optional[int]
body: str
def extract_comment_for_paraphrases(dataset_path: str, output_path: str):
dataset = Dataset.from_json(dataset_path)
results: dict[str, dict] = {}
for entry in dataset.entries:
sel = entry.metadata.selection
if sel and sel.comment_suggests_change:
comment = entry.comments[0].__dict__
del comment["paraphrases"]
results[entry.metadata.id] = {
"comment": comment,
"files": {fname: fdata.content_before_pr for fname, fdata in entry.files.items()},
"diffs_before": entry.diffs_before,
}
# Write out the exact predictions reference JSON
with open(output_path, "w", encoding="utf-8") as out_file:
json.dump(results, out_file, default=lambda o: o.__dict__, indent=4)
print(f"Saved {len(results)} entries to {output_path}")
def extract_comment_predictions(dataset_path: str, output_path: str):
dataset = Dataset.from_json(dataset_path)
results = {}
results: dict[str, CommentGenSubmission] = {}
for entry in dataset.entries:
sel = entry.metadata.selection
if sel and sel.comment_suggests_change:
results[entry.metadata.id] = entry.comments[0].body
results[entry.metadata.id] = CommentGenSubmission(
path=entry.comments[0].file,
line_from=entry.comments[0].from_,
line_to=entry.comments[0].to,
body=entry.comments[0].body,
)
# Write out the exact predictions reference JSON
with open(output_path, "w", encoding="utf-8") as out_file:
json.dump(results, out_file, indent=4)
json.dump(results, out_file, default=lambda o: o.__dict__, indent=4)
print(f"Saved {len(results)} entries to {output_path}")
@ -82,8 +127,9 @@ if __name__ == "__main__":
parser.add_argument(
"-t",
"--output-type",
choices=[mode.value for mode in OutputType if mode is not OutputType.FULL],
default=OutputType.COMMENT_GEN.value,
type=OutputType,
default=OutputType.COMMENT_GEN,
action=EnumChoicesAction,
help="Type of output to generate",
)
args = parser.parse_args()
@ -97,3 +143,5 @@ if __name__ == "__main__":
elif output_type is OutputType.CODE_REFINEMENT:
assert args.archives is not None
extract_refinement_predictions(args.dataset, args.archives, args.output)
elif output_type is OutputType.FOR_PARAPHRASES:
extract_comment_for_paraphrases(args.dataset, args.output)