From 900003bac7ae5d1446ee764528c435ef125ec255 Mon Sep 17 00:00:00 2001 From: Karma Riuk Date: Tue, 27 May 2025 10:48:17 +0200 Subject: [PATCH] added a way to extract the information to then generate paraphrases --- extract_correct_predictions.py | 60 ++++++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/extract_correct_predictions.py b/extract_correct_predictions.py index 4aec911..d1c6eee 100644 --- a/extract_correct_predictions.py +++ b/extract_correct_predictions.py @@ -1,18 +1,63 @@ +from dataclasses import dataclass +from enum import Enum import os, json, tarfile, argparse -from dataset import Dataset, ArchiveState, OutputType +from typing import Optional +from dataset import Dataset, ArchiveState +from utils import EnumChoicesAction + + +class OutputType(Enum): + CODE_REFINEMENT = "code_refinement" + COMMENT_GEN = "comment_gen" + FOR_PARAPHRASES = "paraphrases" + + +@dataclass +class CommentGenSubmission: + path: str + line_from: int + line_to: Optional[int] + body: str + + +def extract_comment_for_paraphrases(dataset_path: str, output_path: str): + dataset = Dataset.from_json(dataset_path) + results: dict[str, dict] = {} + + for entry in dataset.entries: + sel = entry.metadata.selection + if sel and sel.comment_suggests_change: + comment = entry.comments[0].__dict__ + del comment["paraphrases"] + results[entry.metadata.id] = { + "comment": comment, + "files": {fname: fdata.content_before_pr for fname, fdata in entry.files.items()}, + "diffs_before": entry.diffs_before, + } + + # Write out the exact predictions reference JSON + with open(output_path, "w", encoding="utf-8") as out_file: + json.dump(results, out_file, default=lambda o: o.__dict__, indent=4) + + print(f"Saved {len(results)} entries to {output_path}") def extract_comment_predictions(dataset_path: str, output_path: str): dataset = Dataset.from_json(dataset_path) - results = {} + results: dict[str, CommentGenSubmission] = {} for entry in dataset.entries: sel = entry.metadata.selection if sel and sel.comment_suggests_change: - results[entry.metadata.id] = entry.comments[0].body + results[entry.metadata.id] = CommentGenSubmission( + path=entry.comments[0].file, + line_from=entry.comments[0].from_, + line_to=entry.comments[0].to, + body=entry.comments[0].body, + ) # Write out the exact predictions reference JSON with open(output_path, "w", encoding="utf-8") as out_file: - json.dump(results, out_file, indent=4) + json.dump(results, out_file, default=lambda o: o.__dict__, indent=4) print(f"Saved {len(results)} entries to {output_path}") @@ -82,8 +127,9 @@ if __name__ == "__main__": parser.add_argument( "-t", "--output-type", - choices=[mode.value for mode in OutputType if mode is not OutputType.FULL], - default=OutputType.COMMENT_GEN.value, + type=OutputType, + default=OutputType.COMMENT_GEN, + action=EnumChoicesAction, help="Type of output to generate", ) args = parser.parse_args() @@ -97,3 +143,5 @@ if __name__ == "__main__": elif output_type is OutputType.CODE_REFINEMENT: assert args.archives is not None extract_refinement_predictions(args.dataset, args.archives, args.output) + elif output_type is OutputType.FOR_PARAPHRASES: + extract_comment_for_paraphrases(args.dataset, args.output)