ported backend to python

This commit is contained in:
Karma Riuk
2025-05-13 13:27:38 +02:00
parent e5bd1d3a08
commit 3a4bfd611b
22 changed files with 330 additions and 658 deletions

View File

@ -1,125 +0,0 @@
/*
* Calculates BLEU score between a reference and candidate sentence.
* Reference and candidate should be token arrays (e.g. split by whitespace).
* We compute modified n-gram precisions for n=1..4, geometric mean, with smoothing (optional).
* We include the brevity penalty.
*/
/**
* Extracts n-grams from a sequence of tokens.
* @param {string[]} tokens - Array of tokens.
* @param {number} n - Size of the n-gram.
* @returns {Object} Map from n-gram string to its count.
*/
function getNGramCounts(tokens, n) {
const counts = Object.create(null);
for (let i = 0; i + n <= tokens.length; i++) {
const gram = tokens.slice(i, i + n).join(" ");
counts[gram] = (counts[gram] || 0) + 1;
}
return counts;
}
/**
* Computes modified precision for a given n.
* @param {string[]} reference - Reference token array.
* @param {string[]} candidate - Candidate token array.
* @param {number} n - n-gram order.
* @returns {number} Modified n-gram precision.
*/
function modifiedPrecision(reference, candidate, n) {
const refCounts = getNGramCounts(reference, n);
const candCounts = getNGramCounts(candidate, n);
let matchCount = 0;
let totalCount = 0;
for (const gram in candCounts) {
const countCand = candCounts[gram];
const countRef = refCounts[gram] || 0;
matchCount += Math.min(countCand, countRef);
totalCount += countCand;
}
// Avoid division by zero
return totalCount === 0 ? 0 : matchCount / totalCount;
}
/**
* Computes brevity penalty.
* @param {number} refLength - Length of reference sentence.
* @param {number} candLength - Length of candidate sentence.
* @returns {number} Brevity penalty.
*/
function brevityPenalty(refLength, candLength) {
if (candLength > refLength) {
return 1;
}
if (candLength === 0) {
return 0;
}
return Math.exp(1 - refLength / candLength);
}
/**
* Computes BLEU score.
* @param {string} refSentence - Reference sentence.
* @param {string} candSentence - Candidate sentence.
* @param {number} maxN - Maximum n-gram order (default=4).
* @param {boolean} smooth - Whether to apply smoothing (default=false).
* @returns {number} BLEU score between 0 and 1.
*/
export function bleu(refSentence, candSentence, maxN = 4, smooth = false) {
const reference = refSentence.trim().split(/\s+/);
const candidate = candSentence.trim().split(/\s+/);
const refLen = reference.length;
const candLen = candidate.length;
// count how many times we've hit a zero count so far
const precisions = [];
for (let n = 1; n <= maxN; n++) {
let p = modifiedPrecision(reference, candidate, n);
if (p === 0 && smooth) {
p = 1 / Math.pow(candLen, n);
}
precisions.push(p);
}
// Compute geometric mean of precisions
// if any precision is zero (and no smoothing), BLEU=0
if (precisions.some((p) => p === 0)) {
return 0;
}
const logPrecisionSum =
precisions.map((p) => Math.log(p)).reduce((a, b) => a + b, 0) / maxN;
const geoMean = Math.exp(logPrecisionSum);
const bp = brevityPenalty(refLen, candLen);
return bp * geoMean;
}
// if __name__ == "__main__"
if (process.argv[1] === import.meta.filename) {
const test_pairs = [
["the cat is on the mat", "the cat is on the mat"],
["the cat is on the mat", "the the the the the the the"],
["the cat is on the mat", "the cat on the mat"],
["the cat is on the mat", "the cat is on the"],
["the cat is on the mat", "foo bar baz qux"],
[
"The quick brown fox jumps over the lazy dog",
"The quick brown dog jumps over the lazy fox",
],
[
"This could be `static` to prevent any funkiness, i.e. attempting to use class state during the constructor or similar.",
"This could be `static` to prevent any funkiness, i.e. attempting to use class state during the constructor or similar.",
],
];
for (const [reference, candidate] of test_pairs) {
const score = bleu(reference, candidate, 4);
console.log(`reference: ${reference}`);
console.log(`candidate: ${candidate}`);
console.log(`BLEU score: ${score.toFixed(4)}`);
}
}

View File

@ -1,6 +0,0 @@
export class InvalidJsonFormatError extends Error {
constructor(message = 'JSON must be an object mapping strings to strings') {
super(message);
this.name = 'InvalidJsonFormatError';
}
}

4
src/utils/errors.py Normal file
View File

@ -0,0 +1,4 @@
class InvalidJsonFormatError(Exception):
def __init__(self, message='JSON must be an object mapping strings to strings'):
super().__init__(message)
self.name = 'InvalidJsonFormatError'

View File

@ -1,11 +0,0 @@
import { fileURLToPath } from 'url';
import { dirname, join } from 'path';
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
// Get the project root directory (2 levels up from src/utils)
export const PROJECT_ROOT = join(__dirname, '../..');
// Helper function to create paths relative to project root
export const getProjectPath = (relativePath) => join(PROJECT_ROOT, relativePath);

8
src/utils/paths.py Normal file
View File

@ -0,0 +1,8 @@
# utils/paths.py
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parent.parent
def get_project_path(relative_path: str) -> Path:
return PROJECT_ROOT / relative_path

View File

@ -1,58 +0,0 @@
import fs from "fs";
import { getProjectPath } from "../utils/paths.js";
import { bleu } from "../utils/bleu.js";
function buildReferenceMap(dataset_path) {
const referenceMap = {};
const dataset = JSON.parse(fs.readFileSync(dataset_path));
for (const entry of dataset.entries) {
const id = entry.metadata.id;
const comments = entry.comments;
referenceMap[id] = comments.map((c) => c.body);
}
return referenceMap;
}
const REFERENCE_MAP = buildReferenceMap(getProjectPath("data/dataset.json"));
export const evaluate_comments = (answers, percent_cb) => {
const total = Object.keys(answers).length;
let i = 0;
const results = {};
for (const [id, generated_comment] of Object.entries(answers)) {
const n_tokens_generated = generated_comment.trim().split(/\s+/).length;
if (!(id in REFERENCE_MAP)) {
console.error(`id: "${id}" is not present in the dataset`);
continue;
}
const paraphrases = REFERENCE_MAP[id];
let maxScore = 0;
const scores = [];
for (const paraphrase of paraphrases) {
const n_tokens_paraphrase = paraphrase.trim().split(/\s+/).length;
const max_n = Math.min(n_tokens_generated, n_tokens_paraphrase, 4);
const score = bleu(paraphrase, generated_comment, max_n);
scores.push(score);
maxScore = Math.max(score, maxScore);
}
results[id] = {
max_bleu_score: maxScore,
bleu_scores: scores,
proposed_comment: generated_comment,
};
percent_cb(Math.floor((++i / total) * 100));
}
return results;
};
export const evaluate_refinement = (answers, percent_cb) => {
const total = Object.keys(answers).length;
let i = 0;
for (const [key, value] of Object.entries(answers)) {
console.log(`Processing ${key}: ${value}...`);
// await new Promise((res) => setTimeout(res, 1000));
console.log("Done");
percent_cb(Math.floor((++i / total) * 100));
}
};

46
src/utils/process_data.py Normal file
View File

@ -0,0 +1,46 @@
# utils/process_data.py
import json
import sys
from .paths import get_project_path
from sacrebleu import sentence_bleu as bleu
def build_reference_map(dataset_path: str) -> dict[str, list[str]]:
ref_map = {}
data = json.loads(open(dataset_path).read())
for entry in data['entries']:
id_ = entry['metadata']['id']
comments = entry['comments']
ref_map[id_] = [c['body'] for c in comments]
return ref_map
REFERENCE_MAP = build_reference_map(str(get_project_path('../data/dataset.json')))
def evaluate_comments(answers: dict[str, str], percent_cb):
total = len(answers)
results = {}
for i, (id_, gen) in enumerate(answers.items(), 1):
if id_ not in REFERENCE_MAP:
print(f'id: "{id_}" is not present in the dataset', file=sys.stderr)
continue
paraphrases = REFERENCE_MAP[id_]
max_score = 0.0
scores = []
for p in paraphrases:
score = bleu(gen, [p]).score
scores.append(score)
max_score = max(max_score, score)
results[id_] = {'max_bleu_score': max_score, 'bleu_scores': scores, 'proposed_comment': gen}
percent_cb(int(i / total * 100))
return results
def evaluate_refinement(answers: dict[str, str], percent_cb):
total = len(answers)
for i, (key, value) in enumerate(answers.items(), 1):
print(f"Processing {key}: {value}...")
# time.sleep(1)
print("Done")
percent_cb(int(i / total * 100))