mirror of
https://github.com/karma-riuk/crab.git
synced 2025-07-04 21:28:12 +02:00
added way to put paraphrases from external csv
This commit is contained in:
69
dataset.py
69
dataset.py
@ -1,7 +1,12 @@
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import sys, re
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import json, argparse, os, uuid
|
||||
import pandas as pd
|
||||
from sacrebleu import sentence_bleu as bleu
|
||||
|
||||
from pandas import DataFrame
|
||||
|
||||
from utils import EnumChoicesAction
|
||||
|
||||
@ -60,6 +65,8 @@ class Metadata:
|
||||
return f"{self.id}_{state.value}.tar.gz"
|
||||
return f"{self.repo.replace('/', '_')}_{self.pr_number}_{state.value}.tar.gz"
|
||||
|
||||
identical_paraphrase = 0
|
||||
|
||||
@dataclass
|
||||
class DatasetEntry:
|
||||
metadata: Metadata
|
||||
@ -68,6 +75,27 @@ class DatasetEntry:
|
||||
comments: List[Comment]
|
||||
diffs_after: Dict[str, str] # filename -> diff, changes after the comment
|
||||
|
||||
def add_paraphrases(self, paraphrases: list[str]):
|
||||
global identical_paraphrase
|
||||
comment = self.comments[0]
|
||||
for paraphrase in paraphrases:
|
||||
score = bleu(comment.body, [paraphrase]).score
|
||||
if paraphrase == comment.body:
|
||||
identical_paraphrase += 1
|
||||
continue
|
||||
if score > 90:
|
||||
print(f"OG Comment (id: {self.metadata.id}):")
|
||||
print(comment.body)
|
||||
print()
|
||||
print(f"Paraphrase that is too similar ({score = }):")
|
||||
print(paraphrase)
|
||||
add = prompt_yes_no("Do you still want to add this paraphrase to the list?")
|
||||
if not add:
|
||||
continue
|
||||
else:
|
||||
identical_paraphrase += 1
|
||||
comment.paraphrases.append(paraphrase)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommentGenEntry:
|
||||
@ -83,7 +111,6 @@ class CommentGenEntry:
|
||||
diffs=entry.diffs_before,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeRefinementEntry:
|
||||
id: str
|
||||
@ -172,9 +199,9 @@ class Dataset:
|
||||
@staticmethod
|
||||
def from_json(filename: str, keep_still_in_progress: bool = False) -> "Dataset":
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
print(f"Loading dataset from {filename}...", end=" ", flush=True)
|
||||
print(f"Loading dataset from {filename}...", end=" ", flush=True, file=sys.stderr)
|
||||
data = json.load(f)
|
||||
print("Done")
|
||||
print("Done", file=sys.stderr)
|
||||
|
||||
entries = []
|
||||
for entry_data in data["entries"]:
|
||||
@ -217,6 +244,29 @@ class Dataset:
|
||||
ref_map[entry.metadata.id] = entry
|
||||
return ref_map
|
||||
|
||||
def add_paraphrases(self, paraphrases_df: DataFrame):
|
||||
ref_map = self.build_reference_map()
|
||||
paraphrases_df[["id", "paraphrases"]].apply(
|
||||
lambda row: process_row(row, ref_map),
|
||||
axis=1,
|
||||
)
|
||||
|
||||
|
||||
def sanitize_paraphrases(paraphrases_block: str) -> list[str]:
|
||||
return [
|
||||
re.sub(r'^Paraphrase#\d+: ', '', line).strip() for line in paraphrases_block.splitlines()
|
||||
]
|
||||
|
||||
|
||||
def process_row(row, ref_map: dict[str, DatasetEntry]):
|
||||
try:
|
||||
ref_map[row["id"]].add_paraphrases(sanitize_paraphrases(row["paraphrases"]))
|
||||
except KeyError:
|
||||
print(
|
||||
f"Failed to find id {row['id']} in ref_map",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from utils import prompt_yes_no
|
||||
@ -234,6 +284,12 @@ if __name__ == "__main__":
|
||||
default="output.json",
|
||||
help="Path to the output JSON file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--paraphrases",
|
||||
type=str,
|
||||
help="Path to generated paraphrases. It must be a csv that has the column 'paraphrases'. The content of that column must be a multi-line string, where each line has the form 'Paraphrase#N: <comment_paraphrase>'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--output_type",
|
||||
@ -250,6 +306,13 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = Dataset.from_json(args.filename)
|
||||
|
||||
paraphrases: Optional[DataFrame] = None
|
||||
if args.paraphrases is not None:
|
||||
paraphrases = pd.read_csv(args.paraphrases)
|
||||
dataset.add_paraphrases(paraphrases)
|
||||
print(f"# identical paraphrases {identical_paraphrase}")
|
||||
|
||||
print(f"Loaded {len(dataset.entries)} entries from {args.filename}")
|
||||
if os.path.exists(args.output):
|
||||
overwrite = prompt_yes_no(
|
||||
|
@ -1,2 +1,3 @@
|
||||
requests_cache
|
||||
click
|
||||
sacrebleu
|
||||
|
Reference in New Issue
Block a user