added way to put paraphrases from external csv

This commit is contained in:
Karma Riuk
2025-06-10 20:42:38 +02:00
parent 1754f93018
commit dd52e43000
2 changed files with 67 additions and 3 deletions

View File

@ -1,7 +1,12 @@
from dataclasses import dataclass, field
from enum import Enum
import sys, re
from typing import Any, Dict, List, Optional, Union
import json, argparse, os, uuid
import pandas as pd
from sacrebleu import sentence_bleu as bleu
from pandas import DataFrame
from utils import EnumChoicesAction
@ -60,6 +65,8 @@ class Metadata:
return f"{self.id}_{state.value}.tar.gz"
return f"{self.repo.replace('/', '_')}_{self.pr_number}_{state.value}.tar.gz"
identical_paraphrase = 0
@dataclass
class DatasetEntry:
metadata: Metadata
@ -68,6 +75,27 @@ class DatasetEntry:
comments: List[Comment]
diffs_after: Dict[str, str] # filename -> diff, changes after the comment
def add_paraphrases(self, paraphrases: list[str]):
global identical_paraphrase
comment = self.comments[0]
for paraphrase in paraphrases:
score = bleu(comment.body, [paraphrase]).score
if paraphrase == comment.body:
identical_paraphrase += 1
continue
if score > 90:
print(f"OG Comment (id: {self.metadata.id}):")
print(comment.body)
print()
print(f"Paraphrase that is too similar ({score = }):")
print(paraphrase)
add = prompt_yes_no("Do you still want to add this paraphrase to the list?")
if not add:
continue
else:
identical_paraphrase += 1
comment.paraphrases.append(paraphrase)
@dataclass
class CommentGenEntry:
@ -83,7 +111,6 @@ class CommentGenEntry:
diffs=entry.diffs_before,
)
@dataclass
class CodeRefinementEntry:
id: str
@ -172,9 +199,9 @@ class Dataset:
@staticmethod
def from_json(filename: str, keep_still_in_progress: bool = False) -> "Dataset":
with open(filename, "r", encoding="utf-8") as f:
print(f"Loading dataset from {filename}...", end=" ", flush=True)
print(f"Loading dataset from {filename}...", end=" ", flush=True, file=sys.stderr)
data = json.load(f)
print("Done")
print("Done", file=sys.stderr)
entries = []
for entry_data in data["entries"]:
@ -217,6 +244,29 @@ class Dataset:
ref_map[entry.metadata.id] = entry
return ref_map
def add_paraphrases(self, paraphrases_df: DataFrame):
ref_map = self.build_reference_map()
paraphrases_df[["id", "paraphrases"]].apply(
lambda row: process_row(row, ref_map),
axis=1,
)
def sanitize_paraphrases(paraphrases_block: str) -> list[str]:
return [
re.sub(r'^Paraphrase#\d+: ', '', line).strip() for line in paraphrases_block.splitlines()
]
def process_row(row, ref_map: dict[str, DatasetEntry]):
try:
ref_map[row["id"]].add_paraphrases(sanitize_paraphrases(row["paraphrases"]))
except KeyError:
print(
f"Failed to find id {row['id']} in ref_map",
file=sys.stderr,
)
if __name__ == "__main__":
from utils import prompt_yes_no
@ -234,6 +284,12 @@ if __name__ == "__main__":
default="output.json",
help="Path to the output JSON file",
)
parser.add_argument(
"-p",
"--paraphrases",
type=str,
help="Path to generated paraphrases. It must be a csv that has the column 'paraphrases'. The content of that column must be a multi-line string, where each line has the form 'Paraphrase#N: <comment_paraphrase>'",
)
parser.add_argument(
"-t",
"--output_type",
@ -250,6 +306,13 @@ if __name__ == "__main__":
args = parser.parse_args()
dataset = Dataset.from_json(args.filename)
paraphrases: Optional[DataFrame] = None
if args.paraphrases is not None:
paraphrases = pd.read_csv(args.paraphrases)
dataset.add_paraphrases(paraphrases)
print(f"# identical paraphrases {identical_paraphrase}")
print(f"Loaded {len(dataset.entries)} entries from {args.filename}")
if os.path.exists(args.output):
overwrite = prompt_yes_no(

View File

@ -1,2 +1,3 @@
requests_cache
click
sacrebleu