From dd52e4300046b8c8bd758c033eee1f9c42518fb7 Mon Sep 17 00:00:00 2001 From: Karma Riuk Date: Tue, 10 Jun 2025 20:42:38 +0200 Subject: [PATCH] added way to put paraphrases from external csv --- dataset.py | 69 ++++++++++++++++++++++++++++++++++++++-- optional-requirments.txt | 1 + 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/dataset.py b/dataset.py index 3510812..69a07f1 100644 --- a/dataset.py +++ b/dataset.py @@ -1,7 +1,12 @@ from dataclasses import dataclass, field from enum import Enum +import sys, re from typing import Any, Dict, List, Optional, Union import json, argparse, os, uuid +import pandas as pd +from sacrebleu import sentence_bleu as bleu + +from pandas import DataFrame from utils import EnumChoicesAction @@ -60,6 +65,8 @@ class Metadata: return f"{self.id}_{state.value}.tar.gz" return f"{self.repo.replace('/', '_')}_{self.pr_number}_{state.value}.tar.gz" +identical_paraphrase = 0 + @dataclass class DatasetEntry: metadata: Metadata @@ -68,6 +75,27 @@ class DatasetEntry: comments: List[Comment] diffs_after: Dict[str, str] # filename -> diff, changes after the comment + def add_paraphrases(self, paraphrases: list[str]): + global identical_paraphrase + comment = self.comments[0] + for paraphrase in paraphrases: + score = bleu(comment.body, [paraphrase]).score + if paraphrase == comment.body: + identical_paraphrase += 1 + continue + if score > 90: + print(f"OG Comment (id: {self.metadata.id}):") + print(comment.body) + print() + print(f"Paraphrase that is too similar ({score = }):") + print(paraphrase) + add = prompt_yes_no("Do you still want to add this paraphrase to the list?") + if not add: + continue + else: + identical_paraphrase += 1 + comment.paraphrases.append(paraphrase) + @dataclass class CommentGenEntry: @@ -83,7 +111,6 @@ class CommentGenEntry: diffs=entry.diffs_before, ) - @dataclass class CodeRefinementEntry: id: str @@ -172,9 +199,9 @@ class Dataset: @staticmethod def from_json(filename: str, keep_still_in_progress: bool = False) -> "Dataset": with open(filename, "r", encoding="utf-8") as f: - print(f"Loading dataset from {filename}...", end=" ", flush=True) + print(f"Loading dataset from {filename}...", end=" ", flush=True, file=sys.stderr) data = json.load(f) - print("Done") + print("Done", file=sys.stderr) entries = [] for entry_data in data["entries"]: @@ -217,6 +244,29 @@ class Dataset: ref_map[entry.metadata.id] = entry return ref_map + def add_paraphrases(self, paraphrases_df: DataFrame): + ref_map = self.build_reference_map() + paraphrases_df[["id", "paraphrases"]].apply( + lambda row: process_row(row, ref_map), + axis=1, + ) + + +def sanitize_paraphrases(paraphrases_block: str) -> list[str]: + return [ + re.sub(r'^Paraphrase#\d+: ', '', line).strip() for line in paraphrases_block.splitlines() + ] + + +def process_row(row, ref_map: dict[str, DatasetEntry]): + try: + ref_map[row["id"]].add_paraphrases(sanitize_paraphrases(row["paraphrases"])) + except KeyError: + print( + f"Failed to find id {row['id']} in ref_map", + file=sys.stderr, + ) + if __name__ == "__main__": from utils import prompt_yes_no @@ -234,6 +284,12 @@ if __name__ == "__main__": default="output.json", help="Path to the output JSON file", ) + parser.add_argument( + "-p", + "--paraphrases", + type=str, + help="Path to generated paraphrases. It must be a csv that has the column 'paraphrases'. The content of that column must be a multi-line string, where each line has the form 'Paraphrase#N: '", + ) parser.add_argument( "-t", "--output_type", @@ -250,6 +306,13 @@ if __name__ == "__main__": args = parser.parse_args() dataset = Dataset.from_json(args.filename) + + paraphrases: Optional[DataFrame] = None + if args.paraphrases is not None: + paraphrases = pd.read_csv(args.paraphrases) + dataset.add_paraphrases(paraphrases) + print(f"# identical paraphrases {identical_paraphrase}") + print(f"Loaded {len(dataset.entries)} entries from {args.filename}") if os.path.exists(args.output): overwrite = prompt_yes_no( diff --git a/optional-requirments.txt b/optional-requirments.txt index 817453c..6f77cd5 100644 --- a/optional-requirments.txt +++ b/optional-requirments.txt @@ -1,2 +1,3 @@ requests_cache click +sacrebleu