Mirror of https://github.com/karma-riuk/crab-webapp.git (synced 2025-07-05 14:18:12 +02:00)

Commit: ported backend to python
src/routes/__tests__/datasets.test.js (deleted)
@@ -1,92 +0,0 @@
import { jest } from '@jest/globals';
import express from 'express';
import request from 'supertest';
import { join } from 'path';
import { fileURLToPath } from 'url';
import { dirname } from 'path';
import datasetsRouter from '../datasets.js';

const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);

// Mock the paths utility
jest.mock('../../utils/paths.js', () => ({
    getProjectPath: (path) => join(__dirname, '../../..', path)
}));

// Create Express app for testing
const app = express();
app.use('/datasets', datasetsRouter);

describe('Datasets Router', () => {
    // Mock environment variables
    const originalEnv = process.env;
    beforeEach(() => {
        jest.resetModules();
        process.env = { ...originalEnv };
        process.env.DATA_DIR = './test-data';
    });

    afterEach(() => {
        process.env = originalEnv;
    });

    describe('GET /download/:dataset', () => {
        it('should return 400 for invalid dataset name', async () => {
            const response = await request(app)
                .get('/datasets/download/invalid_dataset')
                .expect(400);

            expect(response.body).toEqual({
                error: 'Invalid dataset name'
            });
        });

        it('should download comment_generation without context', async () => {
            const response = await request(app)
                .get('/datasets/download/comment_generation')
                .expect(200);

            expect(response.headers['content-type']).toBe('application/zip');
            expect(response.headers['content-disposition']).toContain('comment_generation_no_context.zip');
        });

        it('should download comment_generation with context', async () => {
            const response = await request(app)
                .get('/datasets/download/comment_generation')
                .query({ withContext: true })
                .expect(200);

            expect(response.headers['content-type']).toBe('application/zip');
            expect(response.headers['content-disposition']).toContain('comment_generation_with_context.zip');
        });

        it('should download code_refinement without context', async () => {
            const response = await request(app)
                .get('/datasets/download/code_refinement')
                .expect(200);

            expect(response.headers['content-type']).toBe('application/zip');
            expect(response.headers['content-disposition']).toContain('code_refinement_no_context.zip');
        });

        it('should download code_refinement with context', async () => {
            const response = await request(app)
                .get('/datasets/download/code_refinement')
                .query({ withContext: true })
                .expect(200);

            expect(response.headers['content-type']).toBe('application/zip');
            expect(response.headers['content-disposition']).toContain('code_refinement_with_context.zip');
        });

        it('should handle JSON boolean for withContext parameter', async () => {
            const response = await request(app)
                .get('/datasets/download/comment_generation')
                .query({ withContext: 'true' })
                .expect(200);

            expect(response.headers['content-disposition']).toContain('comment_generation_with_context.zip');
        });
    });
});
src/routes/answers.js (deleted)
@@ -1,129 +0,0 @@
import { Router } from "express";
import multer from "multer";
import { InvalidJsonFormatError } from "../utils/errors.js";
import { evaluate_comments, evaluate_refinement } from "../utils/process_data.js";

const router = Router();

// Configure multer for file uploads
const upload = multer({
    storage: multer.memoryStorage(),
    limits: {
        fileSize: 200 * 1024 * 1024, // 200MB limit, since the comment gen dataset is 147MB (deflated)
    },
    fileFilter: (_req, file, cb) => {
        // Accept only JSON files
        if (file.mimetype === "application/json") {
            cb(null, true);
        } else {
            cb(new Error("Only JSON files are allowed"));
        }
    },
});

// Helper function to validate JSON format
const validateJsonFormat = (data) => {
    try {
        const parsed = JSON.parse(data);
        // Check if it's an object
        if (
            typeof parsed !== "object" ||
            parsed === null ||
            Array.isArray(parsed)
        ) {
            throw new InvalidJsonFormatError(
                "Submitted json doesn't contain an object",
            );
        }
        // Check if all values are strings
        if (
            !Object.values(parsed).every((value) => typeof value === "string")
        ) {
            throw new InvalidJsonFormatError(
                "Submitted json object must only be str -> str. Namely id -> comment",
            );
        }
        return parsed;
    } catch (error) {
        if (error instanceof InvalidJsonFormatError) {
            throw error;
        }
        throw new InvalidJsonFormatError("Invalid JSON format");
    }
};

router.post("/submit/comments", upload.single("file"), async (req, res) => {
    try {
        if (!req.file) {
            return res.status(400).json({ error: "No file uploaded" });
        }

        const fileContent = req.file.buffer.toString();
        let validatedData;

        try {
            validatedData = validateJsonFormat(fileContent);
        } catch (error) {
            if (error instanceof InvalidJsonFormatError) {
                return res.status(400).json({
                    error: "Invalid JSON format",
                    message: error.message,
                });
            }
            throw error;
        }

        const io = req.app.get("io");
        const header = req.get("X-Socket-Id");
        const socketId = header && header.trim();
        if (socketId && io.sockets.sockets.has(socketId)) {
            io.to(socketId).emit("successful-upload");
            io.to(socketId).emit("started-processing");
        }

        const results = evaluate_comments(validatedData, (percent) => {
            if (!(socketId && io.sockets.sockets.has(socketId))) return;

            io.to(socketId).emit("progress", { percent });
        });
        res.status(200).json(results);
    } catch (error) {
        console.error("Error processing submission:", error);
        res.status(500).json({ error: "Error processing submission" });
    }
});

router.post("/submit/refinement", upload.single("file"), async (req, res) => {
    try {
        if (!req.file) {
            return res.status(400).json({ error: "No file uploaded" });
        }

        const fileContent = req.file.buffer.toString();
        let validatedData;

        try {
            validatedData = validateJsonFormat(fileContent);
        } catch (error) {
            if (error instanceof InvalidJsonFormatError) {
                return res.status(400).json({
                    error: "Invalid JSON format",
                    message: error.message,
                });
            }
            throw error;
        }

        // Notify the client's websocket session, as in the comments route
        const io = req.app.get("io");
        const header = req.get("X-Socket-Id");
        const socketId = header && header.trim();
        if (socketId && io.sockets.sockets.has(socketId)) {
            io.to(socketId).emit("started-processing");
        }

        evaluate_refinement(validatedData, (percent) => {
            if (!(socketId && io.sockets.sockets.has(socketId))) return;

            io.to(socketId).emit("progress", { percent });
        });
        res.status(200).json({
            message: "Answer submitted successfully",
            data: validatedData,
        });
    } catch (error) {
        console.error("Error processing submission:", error);
        res.status(500).json({ error: "Error processing submission" });
    }
});

export default router;
src/routes/answers.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# routes/answers.py
from flask import Blueprint, request, jsonify, current_app
from utils.errors import InvalidJsonFormatError
from utils.process_data import evaluate_comments
import json

router = Blueprint('answers', __name__, url_prefix='/answers')

ALLOWED_EXT = {'json'}


def validate_json_format(data: str) -> dict[str, str]:
    try:
        obj = json.loads(data)
        if not isinstance(obj, dict):
            raise InvalidJsonFormatError("Submitted json doesn't contain an object")
        if not all(isinstance(v, str) for v in obj.values()):
            raise InvalidJsonFormatError(
                "Submitted json object must only be str -> str. Namely id -> comment"
            )
        return obj
    except InvalidJsonFormatError as e:
        raise e
    except Exception:
        raise InvalidJsonFormatError()


@router.route('/submit/comments', methods=['POST'])
def submit_comments():
    file = request.files.get('file')
    if file is None or file.filename is None or file.filename.split('.')[-1] not in ALLOWED_EXT:
        return jsonify({'error': 'Only JSON files are allowed'}), 400
    data = file.read().decode()
    try:
        validated = validate_json_format(data)
    except InvalidJsonFormatError as e:
        return jsonify({'error': 'Invalid JSON format', 'message': str(e)}), 400

    socketio = current_app.extensions['socketio']
    sid = request.headers.get('X-Socket-Id')
    if sid:
        socketio.emit('successful-upload', room=sid)
        socketio.emit('started-processing', room=sid)

    results = evaluate_comments(
        validated, lambda p: socketio.emit('progress', {'percent': p}, room=sid)
    )
    return jsonify(results)


@router.route('/submit/refinement', methods=['POST'])
def submit_refinement():
    file = request.files.get('file')
    # similar to above
    return jsonify({'message': 'Answer submitted successfully'})
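The refinement route is shown truncated here ('# similar to above'), although the committed file is declared as 55 lines. A minimal sketch of what the full handler plausibly looks like, mirroring submit_comments and reusing evaluate_refinement from utils/process_data.py — treat this as an assumption, not the committed code:

# Hypothetical completion of submit_refinement, assuming the same module
# context as routes/answers.py above (imports, router, ALLOWED_EXT, validator).
from utils.process_data import evaluate_refinement

@router.route('/submit/refinement', methods=['POST'])
def submit_refinement():
    file = request.files.get('file')
    if file is None or file.filename is None or file.filename.split('.')[-1] not in ALLOWED_EXT:
        return jsonify({'error': 'Only JSON files are allowed'}), 400
    try:
        validated = validate_json_format(file.read().decode())
    except InvalidJsonFormatError as e:
        return jsonify({'error': 'Invalid JSON format', 'message': str(e)}), 400

    socketio = current_app.extensions['socketio']
    sid = request.headers.get('X-Socket-Id')
    if sid:
        socketio.emit('successful-upload', room=sid)
        socketio.emit('started-processing', room=sid)

    evaluate_refinement(
        validated, lambda p: socketio.emit('progress', {'percent': p}, room=sid)
    )
    return jsonify({'message': 'Answer submitted successfully'})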
src/routes/datasets.js (deleted)
@@ -1,33 +0,0 @@
import { Router } from "express";
import { join } from "path";
import { getProjectPath } from "../utils/paths.js";

const router = Router();

// Environment variables for paths (all relative to project root)
const DATA_DIR = getProjectPath("data");

const DATASETS = ["comment_generation", "code_refinement"];

router.get("/download/:dataset", async (req, res) => {
    const { dataset } = req.params;
    const withContext = req.query.withContext
        ? JSON.parse(req.query.withContext)
        : false;

    if (!DATASETS.includes(dataset)) {
        return res.status(400).json({ error: "Invalid dataset name" });
    }

    const fileName = `${dataset}_${withContext ? "with_context" : "no_context"}.zip`;
    const filePath = join(DATA_DIR, fileName);

    try {
        res.download(filePath);
    } catch (error) {
        console.error("Error serving file:", error);
        res.status(500).json({ error: "Error serving file" });
    }
});

export default router;
src/routes/datasets.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# routes/datasets.py
from flask import Blueprint, send_from_directory, request, jsonify
from utils.paths import get_project_path

router = Blueprint('datasets', __name__, url_prefix='/datasets')

DATASETS = {'comment_generation', 'code_refinement'}
DATA_DIR = get_project_path('../data')


@router.route('/download/<dataset>')
def download(dataset):
    if dataset not in DATASETS:
        return jsonify({'error': 'Invalid dataset name'}), 400
    with_ctx = request.args.get('withContext', 'false').lower() == 'true'
    fname = f"{dataset}_{'with_context' if with_ctx else 'no_context'}.zip"
    return send_from_directory(DATA_DIR, fname, as_attachment=True)
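A quick way to exercise this route from a client (a hypothetical smoke test; the localhost:3000 URL assumes the default port from server.py below):

# Hypothetical smoke test for GET /datasets/download/<dataset>
import requests

resp = requests.get(
    'http://localhost:3000/datasets/download/comment_generation',
    params={'withContext': 'true'},  # parsed case-insensitively by the route
)
resp.raise_for_status()
with open('comment_generation_with_context.zip', 'wb') as out:
    out.write(resp.content)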
src/routes/index.js (deleted)
@@ -1,23 +0,0 @@
import { Router } from 'express';
import datasetRoutes from './datasets.js';
import answerRoutes from './answers.js';

const router = Router();

// Routes
router.get('/', (_req, res) => {
    res.json({ message: 'Welcome to the Express backend!' });
});

// Example route
router.get('/api/hello', (_req, res) => {
    res.json({ message: 'Hello from the backend!' });
});

// Dataset routes
router.use('/datasets', datasetRoutes);

// Answer submission routes
router.use('/answers', answerRoutes);

export default router;
src/routes/index.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# routes/index.py
from flask import Blueprint, jsonify, current_app


router = Blueprint('index', __name__)


@router.route('/')
def welcome():
    return current_app.send_static_file('index.html')


@router.route('/api/hello')
def hello():
    return jsonify({'message': 'Hello from the backend!'})
Express entry point (deleted)
@@ -1,25 +0,0 @@
import express, { json } from "express";
import cors from "cors";
import dotenv from "dotenv";
import routes from "./routes/index.js";
import { createSocketServer } from "./socket.js";

dotenv.config();

const app = express();
const port = process.env.PORT || 3000;

// Middleware
app.use(cors());
app.use(json());

// Use routes
app.use(express.static("public"));
app.use("/", routes);

const server = createSocketServer(app);

// Start server
server.listen(port, () => {
    console.log(`Server is running on port ${port}`);
});
src/server.py (new file, 35 lines)
@@ -0,0 +1,35 @@
# server.py
from flask import Flask
from flask_cors import CORS
from flask_socketio import SocketIO
from routes.index import router as index_router
from routes.answers import router as answers_router
from routes.datasets import router as datasets_router

import os

app = Flask(__name__, static_folder='../public', static_url_path='/')
CORS(app)

# Register routes
app.register_blueprint(index_router)     # serves '/' and '/api/hello'
app.register_blueprint(answers_router)   # mounts at '/answers'
app.register_blueprint(datasets_router)  # mounts at '/datasets'


def init_socketio(app):
    socketio = SocketIO(app, cors_allowed_origins='*')

    @socketio.on('connect')
    def _():
        print('Websocket client connected')

    return socketio


# Init socketio
socketio = init_socketio(app)

if __name__ == '__main__':
    port = int(os.getenv('PORT', 3000))
    socketio.run(app, port=port)
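End-to-end, the submit flow pairs an HTTP upload with a websocket session: the client connects over socket.io, then sends its session id in the X-Socket-Id header so the server can stream progress events back to it. A hypothetical client sketch (assumes the python-socketio client package, the default localhost:3000, and a placeholder answers.json):

# Hypothetical client for POST /answers/submit/comments with live progress
import socketio
import requests

sio = socketio.Client()

@sio.on('progress')
def on_progress(data):
    print(f"progress: {data['percent']}%")

sio.connect('http://localhost:3000')

with open('answers.json', 'rb') as f:
    resp = requests.post(
        'http://localhost:3000/answers/submit/comments',
        headers={'X-Socket-Id': sio.sid},  # ties this upload to the websocket session
        files={'file': ('answers.json', f, 'application/json')},
    )
print(resp.json())
sio.disconnect()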
src/socket.js (deleted)
@@ -1,14 +0,0 @@
import http from "http";
import { Server } from "socket.io";

function onConnect(socket) {
    console.log("Websocket client connected:", socket.id);
}

export function createSocketServer(app) {
    const httpServer = http.createServer(app);
    const io = new Server(httpServer);
    io.on("connection", onConnect);
    app.set("io", io);
    return httpServer;
}
src/utils/bleu.js (deleted)
@@ -1,125 +0,0 @@
/*
 * Calculates BLEU score between a reference and candidate sentence.
 * Reference and candidate should be token arrays (e.g. split by whitespace).
 * We compute modified n-gram precisions for n=1..4, geometric mean, with smoothing (optional).
 * We include the brevity penalty.
 */

/**
 * Extracts n-grams from a sequence of tokens.
 * @param {string[]} tokens - Array of tokens.
 * @param {number} n - Size of the n-gram.
 * @returns {Object} Map from n-gram string to its count.
 */
function getNGramCounts(tokens, n) {
    const counts = Object.create(null);
    for (let i = 0; i + n <= tokens.length; i++) {
        const gram = tokens.slice(i, i + n).join(" ");
        counts[gram] = (counts[gram] || 0) + 1;
    }
    return counts;
}

/**
 * Computes modified precision for a given n.
 * @param {string[]} reference - Reference token array.
 * @param {string[]} candidate - Candidate token array.
 * @param {number} n - n-gram order.
 * @returns {number} Modified n-gram precision.
 */
function modifiedPrecision(reference, candidate, n) {
    const refCounts = getNGramCounts(reference, n);
    const candCounts = getNGramCounts(candidate, n);
    let matchCount = 0;
    let totalCount = 0;

    for (const gram in candCounts) {
        const countCand = candCounts[gram];
        const countRef = refCounts[gram] || 0;
        matchCount += Math.min(countCand, countRef);
        totalCount += countCand;
    }

    // Avoid division by zero
    return totalCount === 0 ? 0 : matchCount / totalCount;
}

/**
 * Computes brevity penalty.
 * @param {number} refLength - Length of reference sentence.
 * @param {number} candLength - Length of candidate sentence.
 * @returns {number} Brevity penalty.
 */
function brevityPenalty(refLength, candLength) {
    if (candLength > refLength) {
        return 1;
    }
    if (candLength === 0) {
        return 0;
    }
    return Math.exp(1 - refLength / candLength);
}

/**
 * Computes BLEU score.
 * @param {string} refSentence - Reference sentence.
 * @param {string} candSentence - Candidate sentence.
 * @param {number} maxN - Maximum n-gram order (default=4).
 * @param {boolean} smooth - Whether to apply smoothing (default=false).
 * @returns {number} BLEU score between 0 and 1.
 */
export function bleu(refSentence, candSentence, maxN = 4, smooth = false) {
    const reference = refSentence.trim().split(/\s+/);
    const candidate = candSentence.trim().split(/\s+/);
    const refLen = reference.length;
    const candLen = candidate.length;

    // Compute modified precisions for n = 1..maxN, smoothing zero counts if requested
    const precisions = [];
    for (let n = 1; n <= maxN; n++) {
        let p = modifiedPrecision(reference, candidate, n);
        if (p === 0 && smooth) {
            p = 1 / Math.pow(candLen, n);
        }
        precisions.push(p);
    }

    // Compute geometric mean of precisions
    // if any precision is zero (and no smoothing), BLEU=0
    if (precisions.some((p) => p === 0)) {
        return 0;
    }

    const logPrecisionSum =
        precisions.map((p) => Math.log(p)).reduce((a, b) => a + b, 0) / maxN;
    const geoMean = Math.exp(logPrecisionSum);

    const bp = brevityPenalty(refLen, candLen);
    return bp * geoMean;
}

// if __name__ == "__main__"
if (process.argv[1] === import.meta.filename) {
    const test_pairs = [
        ["the cat is on the mat", "the cat is on the mat"],
        ["the cat is on the mat", "the the the the the the the"],
        ["the cat is on the mat", "the cat on the mat"],
        ["the cat is on the mat", "the cat is on the"],
        ["the cat is on the mat", "foo bar baz qux"],
        [
            "The quick brown fox jumps over the lazy dog",
            "The quick brown dog jumps over the lazy fox",
        ],
        [
            "This could be `static` to prevent any funkiness, i.e. attempting to use class state during the constructor or similar.",
            "This could be `static` to prevent any funkiness, i.e. attempting to use class state during the constructor or similar.",
        ],
    ];

    for (const [reference, candidate] of test_pairs) {
        const score = bleu(reference, candidate, 4);
        console.log(`reference: ${reference}`);
        console.log(`candidate: ${candidate}`);
        console.log(`BLEU score: ${score.toFixed(4)}`);
    }
}
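For reference, the quantity bleu.js computes (and which the Python port hands off to sacrebleu) is standard sentence-level BLEU. With the modified n-gram precisions $p_n$ from modifiedPrecision, candidate length $c$, and reference length $r$:

\[
\mathrm{BLEU} = \mathrm{BP} \cdot \exp\!\left(\frac{1}{N} \sum_{n=1}^{N} \log p_n\right),
\qquad
\mathrm{BP} =
\begin{cases}
1 & \text{if } c > r \\
e^{1 - r/c} & \text{if } c \le r
\end{cases}
\]

This matches brevityPenalty above (with the extra convention that an empty candidate scores 0) and the geometric mean computed in bleu().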
src/utils/errors.js (deleted)
@@ -1,6 +0,0 @@
export class InvalidJsonFormatError extends Error {
    constructor(message = 'JSON must be an object mapping strings to strings') {
        super(message);
        this.name = 'InvalidJsonFormatError';
    }
}
src/utils/errors.py (new file, 4 lines)
@@ -0,0 +1,4 @@
class InvalidJsonFormatError(Exception):
    def __init__(self, message='JSON must be an object mapping strings to strings'):
        super().__init__(message)
        self.name = 'InvalidJsonFormatError'
src/utils/paths.js (deleted)
@@ -1,11 +0,0 @@
import { fileURLToPath } from 'url';
import { dirname, join } from 'path';

const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);

// Get the project root directory (2 levels up from src/utils)
export const PROJECT_ROOT = join(__dirname, '../..');

// Helper function to create paths relative to project root
export const getProjectPath = (relativePath) => join(PROJECT_ROOT, relativePath);
src/utils/paths.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# utils/paths.py
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parent.parent


def get_project_path(relative_path: str) -> Path:
    return PROJECT_ROOT / relative_path
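One subtle difference from the deleted paths.js: there PROJECT_ROOT was the repository root (two levels up from src/utils), while here it resolves to src/ itself, which is why the Python callers pass '../data' where the JS passed 'data'. A comment-only sketch tracing the resolution (illustrative, not part of the commit):

# Illustrative only: paths.py lives at <repo>/src/utils/paths.py, so
#   Path(__file__).resolve().parent.parent  ->  <repo>/src
#   get_project_path('../data')             ->  <repo>/src/../data, i.e. <repo>/data
# which matches the JS getProjectPath('data') rooted at <repo>.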
src/utils/process_data.js (deleted)
@@ -1,58 +0,0 @@
import fs from "fs";
import { getProjectPath } from "../utils/paths.js";
import { bleu } from "../utils/bleu.js";

function buildReferenceMap(dataset_path) {
    const referenceMap = {};
    const dataset = JSON.parse(fs.readFileSync(dataset_path));
    for (const entry of dataset.entries) {
        const id = entry.metadata.id;
        const comments = entry.comments;
        referenceMap[id] = comments.map((c) => c.body);
    }
    return referenceMap;
}

const REFERENCE_MAP = buildReferenceMap(getProjectPath("data/dataset.json"));

export const evaluate_comments = (answers, percent_cb) => {
    const total = Object.keys(answers).length;
    let i = 0;
    const results = {};
    for (const [id, generated_comment] of Object.entries(answers)) {
        const n_tokens_generated = generated_comment.trim().split(/\s+/).length;
        if (!(id in REFERENCE_MAP)) {
            console.error(`id: "${id}" is not present in the dataset`);
            continue;
        }
        const paraphrases = REFERENCE_MAP[id];

        let maxScore = 0;
        const scores = [];
        for (const paraphrase of paraphrases) {
            const n_tokens_paraphrase = paraphrase.trim().split(/\s+/).length;
            const max_n = Math.min(n_tokens_generated, n_tokens_paraphrase, 4);
            const score = bleu(paraphrase, generated_comment, max_n);
            scores.push(score);
            maxScore = Math.max(score, maxScore);
        }
        results[id] = {
            max_bleu_score: maxScore,
            bleu_scores: scores,
            proposed_comment: generated_comment,
        };
        percent_cb(Math.floor((++i / total) * 100));
    }
    return results;
};

export const evaluate_refinement = (answers, percent_cb) => {
    const total = Object.keys(answers).length;
    let i = 0;
    for (const [key, value] of Object.entries(answers)) {
        console.log(`Processing ${key}: ${value}...`);
        // await new Promise((res) => setTimeout(res, 1000));
        console.log("Done");
        percent_cb(Math.floor((++i / total) * 100));
    }
};
src/utils/process_data.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# utils/process_data.py
import json
import sys
from .paths import get_project_path
from sacrebleu import sentence_bleu as bleu


def build_reference_map(dataset_path: str) -> dict[str, list[str]]:
    ref_map = {}
    with open(dataset_path) as f:
        data = json.load(f)
    for entry in data['entries']:
        id_ = entry['metadata']['id']
        comments = entry['comments']
        ref_map[id_] = [c['body'] for c in comments]
    return ref_map


REFERENCE_MAP = build_reference_map(str(get_project_path('../data/dataset.json')))


def evaluate_comments(answers: dict[str, str], percent_cb):
    total = len(answers)
    results = {}
    for i, (id_, gen) in enumerate(answers.items(), 1):
        if id_ not in REFERENCE_MAP:
            print(f'id: "{id_}" is not present in the dataset', file=sys.stderr)
            continue
        paraphrases = REFERENCE_MAP[id_]
        max_score = 0.0
        scores = []
        for p in paraphrases:
            score = bleu(gen, [p]).score
            scores.append(score)
            max_score = max(max_score, score)
        results[id_] = {'max_bleu_score': max_score, 'bleu_scores': scores, 'proposed_comment': gen}
        percent_cb(int(i / total * 100))
    return results


def evaluate_refinement(answers: dict[str, str], percent_cb):
    total = len(answers)
    for i, (key, value) in enumerate(answers.items(), 1):
        print(f"Processing {key}: {value}...")
        # time.sleep(1)
        print("Done")
        percent_cb(int(i / total * 100))
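A porting note on scale: the deleted bleu.js returned scores in [0, 1], while sacrebleu's sentence_bleu returns a percentage in [0, 100], so max_bleu_score changes scale across this commit. If the old scale were wanted, a one-line normalization would do (an assumption, not in the commit):

# Hypothetical: normalize sacrebleu's 0-100 score to bleu.js's 0-1 scale
score = bleu(gen, [p]).score / 100.0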