added way to put paraphrases from external csv

This commit is contained in:
Karma Riuk
2025-06-10 20:42:38 +02:00
parent 1754f93018
commit dd52e43000
2 changed files with 67 additions and 3 deletions

View File

@ -1,7 +1,12 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
import sys, re
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import json, argparse, os, uuid import json, argparse, os, uuid
import pandas as pd
from sacrebleu import sentence_bleu as bleu
from pandas import DataFrame
from utils import EnumChoicesAction from utils import EnumChoicesAction
@ -60,6 +65,8 @@ class Metadata:
return f"{self.id}_{state.value}.tar.gz" return f"{self.id}_{state.value}.tar.gz"
return f"{self.repo.replace('/', '_')}_{self.pr_number}_{state.value}.tar.gz" return f"{self.repo.replace('/', '_')}_{self.pr_number}_{state.value}.tar.gz"
identical_paraphrase = 0
@dataclass @dataclass
class DatasetEntry: class DatasetEntry:
metadata: Metadata metadata: Metadata
@ -68,6 +75,27 @@ class DatasetEntry:
comments: List[Comment] comments: List[Comment]
diffs_after: Dict[str, str] # filename -> diff, changes after the comment diffs_after: Dict[str, str] # filename -> diff, changes after the comment
def add_paraphrases(self, paraphrases: list[str]):
global identical_paraphrase
comment = self.comments[0]
for paraphrase in paraphrases:
score = bleu(comment.body, [paraphrase]).score
if paraphrase == comment.body:
identical_paraphrase += 1
continue
if score > 90:
print(f"OG Comment (id: {self.metadata.id}):")
print(comment.body)
print()
print(f"Paraphrase that is too similar ({score = }):")
print(paraphrase)
add = prompt_yes_no("Do you still want to add this paraphrase to the list?")
if not add:
continue
else:
identical_paraphrase += 1
comment.paraphrases.append(paraphrase)
@dataclass @dataclass
class CommentGenEntry: class CommentGenEntry:
@ -83,7 +111,6 @@ class CommentGenEntry:
diffs=entry.diffs_before, diffs=entry.diffs_before,
) )
@dataclass @dataclass
class CodeRefinementEntry: class CodeRefinementEntry:
id: str id: str
@ -172,9 +199,9 @@ class Dataset:
@staticmethod @staticmethod
def from_json(filename: str, keep_still_in_progress: bool = False) -> "Dataset": def from_json(filename: str, keep_still_in_progress: bool = False) -> "Dataset":
with open(filename, "r", encoding="utf-8") as f: with open(filename, "r", encoding="utf-8") as f:
print(f"Loading dataset from {filename}...", end=" ", flush=True) print(f"Loading dataset from {filename}...", end=" ", flush=True, file=sys.stderr)
data = json.load(f) data = json.load(f)
print("Done") print("Done", file=sys.stderr)
entries = [] entries = []
for entry_data in data["entries"]: for entry_data in data["entries"]:
@ -217,6 +244,29 @@ class Dataset:
ref_map[entry.metadata.id] = entry ref_map[entry.metadata.id] = entry
return ref_map return ref_map
def add_paraphrases(self, paraphrases_df: DataFrame):
ref_map = self.build_reference_map()
paraphrases_df[["id", "paraphrases"]].apply(
lambda row: process_row(row, ref_map),
axis=1,
)
def sanitize_paraphrases(paraphrases_block: str) -> list[str]:
return [
re.sub(r'^Paraphrase#\d+: ', '', line).strip() for line in paraphrases_block.splitlines()
]
def process_row(row, ref_map: dict[str, DatasetEntry]):
try:
ref_map[row["id"]].add_paraphrases(sanitize_paraphrases(row["paraphrases"]))
except KeyError:
print(
f"Failed to find id {row['id']} in ref_map",
file=sys.stderr,
)
if __name__ == "__main__": if __name__ == "__main__":
from utils import prompt_yes_no from utils import prompt_yes_no
@ -234,6 +284,12 @@ if __name__ == "__main__":
default="output.json", default="output.json",
help="Path to the output JSON file", help="Path to the output JSON file",
) )
parser.add_argument(
"-p",
"--paraphrases",
type=str,
help="Path to generated paraphrases. It must be a csv that has the column 'paraphrases'. The content of that column must be a multi-line string, where each line has the form 'Paraphrase#N: <comment_paraphrase>'",
)
parser.add_argument( parser.add_argument(
"-t", "-t",
"--output_type", "--output_type",
@ -250,6 +306,13 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
dataset = Dataset.from_json(args.filename) dataset = Dataset.from_json(args.filename)
paraphrases: Optional[DataFrame] = None
if args.paraphrases is not None:
paraphrases = pd.read_csv(args.paraphrases)
dataset.add_paraphrases(paraphrases)
print(f"# identical paraphrases {identical_paraphrase}")
print(f"Loaded {len(dataset.entries)} entries from {args.filename}") print(f"Loaded {len(dataset.entries)} entries from {args.filename}")
if os.path.exists(args.output): if os.path.exists(args.output):
overwrite = prompt_yes_no( overwrite = prompt_yes_no(

View File

@ -1,2 +1,3 @@
requests_cache requests_cache
click click
sacrebleu