mirror of https://github.com/karma-riuk/crab.git, synced 2025-07-05 05:28:13 +02:00

added way to put paraphrases from external csv

dataset.py (69 lines changed)
@@ -1,7 +1,12 @@
 from dataclasses import dataclass, field
 from enum import Enum
+import sys, re
 from typing import Any, Dict, List, Optional, Union
 import json, argparse, os, uuid
+import pandas as pd
+from sacrebleu import sentence_bleu as bleu
+
+from pandas import DataFrame
 
 from utils import EnumChoicesAction
 
@@ -60,6 +65,8 @@ class Metadata:
             return f"{self.id}_{state.value}.tar.gz"
         return f"{self.repo.replace('/', '_')}_{self.pr_number}_{state.value}.tar.gz"
 
+identical_paraphrase = 0
+
 @dataclass
 class DatasetEntry:
     metadata: Metadata
@@ -68,6 +75,27 @@ class DatasetEntry:
     comments: List[Comment]
     diffs_after: Dict[str, str]  # filename -> diff, changes after the comment
 
+    def add_paraphrases(self, paraphrases: list[str]):
+        global identical_paraphrase
+        comment = self.comments[0]
+        for paraphrase in paraphrases:
+            score = bleu(comment.body, [paraphrase]).score
+            if paraphrase == comment.body:
+                identical_paraphrase += 1
+                continue
+            if score > 90:
+                print(f"OG Comment (id: {self.metadata.id}):")
+                print(comment.body)
+                print()
+                print(f"Paraphrase that is too similar ({score = }):")
+                print(paraphrase)
+                add = prompt_yes_no("Do you still want to add this paraphrase to the list?")
+                if not add:
+                    continue
+                else:
+                    identical_paraphrase += 1
+            comment.paraphrases.append(paraphrase)
+
 
 @dataclass
 class CommentGenEntry:
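Note: sentence_bleu here is sacrebleu's sentence-level BLEU (imported above as bleu); it takes a hypothesis string plus a list of reference strings and returns a score on a 0-100 scale, so the score > 90 branch flags near-verbatim paraphrases. A minimal sketch of that gate, using invented comment text rather than real dataset entries:

    from sacrebleu import sentence_bleu as bleu

    original = "Please rename this variable to something more descriptive."
    candidates = [
        "Please rename this variable to something more descriptive.",  # identical
        "Consider giving this variable a clearer name.",  # genuine rewording
    ]

    for candidate in candidates:
        score = bleu(original, [candidate]).score  # 0-100, higher = more similar
        print(f"{score:6.1f}  {candidate}")
    # identical text scores 100.0 (the commit catches it earlier via the == check);
    # a genuine rewording typically lands well below the 90 cutoff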
@@ -83,7 +111,6 @@ class CommentGenEntry:
             diffs=entry.diffs_before,
         )
 
-
 @dataclass
 class CodeRefinementEntry:
     id: str
@@ -172,9 +199,9 @@ class Dataset:
     @staticmethod
     def from_json(filename: str, keep_still_in_progress: bool = False) -> "Dataset":
         with open(filename, "r", encoding="utf-8") as f:
-            print(f"Loading dataset from {filename}...", end=" ", flush=True)
+            print(f"Loading dataset from {filename}...", end=" ", flush=True, file=sys.stderr)
             data = json.load(f)
-            print("Done")
+            print("Done", file=sys.stderr)
 
             entries = []
             for entry_data in data["entries"]:
@@ -217,6 +244,29 @@ class Dataset:
             ref_map[entry.metadata.id] = entry
         return ref_map
 
+    def add_paraphrases(self, paraphrases_df: DataFrame):
+        ref_map = self.build_reference_map()
+        paraphrases_df[["id", "paraphrases"]].apply(
+            lambda row: process_row(row, ref_map),
+            axis=1,
+        )
+
+
+def sanitize_paraphrases(paraphrases_block: str) -> list[str]:
+    return [
+        re.sub(r'^Paraphrase#\d+: ', '', line).strip() for line in paraphrases_block.splitlines()
+    ]
+
+
+def process_row(row, ref_map: dict[str, DatasetEntry]):
+    try:
+        ref_map[row["id"]].add_paraphrases(sanitize_paraphrases(row["paraphrases"]))
+    except KeyError:
+        print(
+            f"Failed to find id {row['id']} in ref_map",
+            file=sys.stderr,
+        )
+
 
 if __name__ == "__main__":
     from utils import prompt_yes_no
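To make the expected CSV cell format concrete: each row's paraphrases cell is a multi-line string, and sanitize_paraphrases just strips the leading 'Paraphrase#N: ' tag from every line. A small self-contained sketch with made-up review comments:

    import re

    def sanitize_paraphrases(paraphrases_block: str) -> list[str]:
        # same regex as in the commit: drop the 'Paraphrase#N: ' prefix, line by line
        return [
            re.sub(r'^Paraphrase#\d+: ', '', line).strip()
            for line in paraphrases_block.splitlines()
        ]

    block = (
        "Paraphrase#1: Could you extract this logic into a helper method?\n"
        "Paraphrase#2: Consider moving this logic into its own helper."
    )
    print(sanitize_paraphrases(block))
    # ['Could you extract this logic into a helper method?',
    #  'Consider moving this logic into its own helper.']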
@@ -234,6 +284,12 @@ if __name__ == "__main__":
         default="output.json",
         help="Path to the output JSON file",
     )
+    parser.add_argument(
+        "-p",
+        "--paraphrases",
+        type=str,
+        help="Path to generated paraphrases. It must be a csv that has the column 'paraphrases'. The content of that column must be a multi-line string, where each line has the form 'Paraphrase#N: <comment_paraphrase>'",
+    )
     parser.add_argument(
         "-t",
         "--output_type",
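As a rough illustration of a CSV satisfying this help text (the id and comments below are placeholders, not real dataset entries), such a file could be produced with pandas:

    import pandas as pd

    # 'id' must match an entry id from the dataset JSON; 'paraphrases' holds
    # one 'Paraphrase#N: ...' line per generated paraphrase
    df = pd.DataFrame(
        {
            "id": ["example-entry-id"],
            "paraphrases": [
                "Paraphrase#1: Could you add a null check here?\n"
                "Paraphrase#2: This call should probably guard against None."
            ],
        }
    )
    df.to_csv("paraphrases.csv", index=False)  # path is illustrative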
@@ -250,6 +306,13 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     dataset = Dataset.from_json(args.filename)
+
+    paraphrases: Optional[DataFrame] = None
+    if args.paraphrases is not None:
+        paraphrases = pd.read_csv(args.paraphrases)
+        dataset.add_paraphrases(paraphrases)
+        print(f"# identical paraphrases {identical_paraphrase}")
+
     print(f"Loaded {len(dataset.entries)} entries from {args.filename}")
     if os.path.exists(args.output):
         overwrite = prompt_yes_no(
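The same flow the __main__ block wires up can also be sketched programmatically; the file names below are illustrative, and this assumes the module is importable as dataset:

    import pandas as pd
    from dataset import Dataset

    dataset = Dataset.from_json("dataset.json")      # dataset JSON produced earlier
    paraphrases = pd.read_csv("paraphrases.csv")     # CSV with 'id' and 'paraphrases' columns
    dataset.add_paraphrases(paraphrases)             # attach paraphrases to matching comments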
requirements.txt

@@ -1,2 +1,3 @@
 requests_cache
 click
+sacrebleu
|
Reference in New Issue
Block a user