diff --git a/public/js/index.js b/public/js/index.js index 9140d0b..09ec2db 100644 --- a/public/js/index.js +++ b/public/js/index.js @@ -66,11 +66,6 @@ function populateRefinementTable(results) { }); } -function handleRefinementAnswer(json) { - uploadStatusEl.style.color = "green"; - uploadStatusEl.textContent = json["id"]; -} - // Upload logic document.getElementById("upload-btn").onclick = async () => { uploadStatusEl.classList.add("hidden"); @@ -104,8 +99,9 @@ document.getElementById("upload-btn").onclick = async () => { commentResultsContainer.classList.add("hidden"); refinementResultsContainer.classList.add("hidden"); - if (type === "comment") populateCommentTable(json); - else handleRefinementAnswer(json); + + uploadStatusEl.style.color = "green"; + uploadStatusEl.textContent = json["id"]; }; [...document.getElementsByClassName("download-results")].forEach((e) => { @@ -136,10 +132,19 @@ socket.on("started-processing", () => { setProgress(0); }); -socket.on("complete", (results) => { +socket.on("complete", (data) => { + commentResultsContainer.classList.add("hidden"); + refinementResultsContainer.classList.add("hidden"); progressContainer.classList.add("hidden"); - refinementResultsContainer.classList.remove("hidden"); - populateRefinementTable(results); + if (data.type == "comment") { + commentResultsContainer.classList.remove("hidden"); + populateCommentTable(data.results); + } else if (data.type == "refinement") { + refinementResultsContainer.classList.remove("hidden"); + populateRefinementTable(data.results); + } else { + console.error(`Unknown type ${data.type}`); + } }); socket.on("successful-upload", () => { @@ -167,5 +172,11 @@ document.getElementById("request-status").onclick = async () => { statusStatusEl.style.color = "green"; statusStatusEl.textContent = json["status"]; - if (json.status == "complete") populateRefinementTable(json.results); + if (json.status == "complete") { + commentResultsContainer.classList.add("hidden"); + refinementResultsContainer.classList.add("hidden"); + if (json.type == "comment") populateCommentTable(json.results); + else if (json.type == "comment") populateRefinementTable(json.results); + else console.error(`Unknown type ${data.type}`); + } }; diff --git a/src/routes/answers.py b/src/routes/answers.py index 88de154..cc28167 100644 --- a/src/routes/answers.py +++ b/src/routes/answers.py @@ -1,5 +1,5 @@ # routes/answers.py -from threading import Thread +from typing import Callable from flask import Blueprint, request, jsonify, current_app, url_for from utils.errors import InvalidJsonFormatError from utils.process_data import evaluate_comments, evaluate_refinement @@ -47,40 +47,13 @@ def validate_json_format_for_code_refinement(data: str) -> dict[str, dict[str, s raise InvalidJsonFormatError() -@router.route('/submit/comment', methods=['POST']) -def submit_comments(): +def handler(type_: str, validate_json: Callable, evaluate_submission: Callable): file = request.files.get('file') if file is None or file.filename is None or file.filename.split('.')[-1] not in ALLOWED_EXT: return jsonify({'error': 'Only JSON files are allowed'}), 400 data = file.read().decode() try: - validated = validate_json_format_for_comment_gen(data) - except InvalidJsonFormatError as e: - return jsonify({'error': 'Invalid JSON format', 'message': str(e)}), 400 - - socketio = current_app.extensions['socketio'] - sid = request.headers.get('X-Socket-Id') - if sid: - socketio.emit('successful-upload', room=sid) - socketio.emit('started-processing', room=sid) - - results = evaluate_comments( - validated, lambda p: socketio.emit('progress', {'percent': p}, room=sid) - ) - return jsonify(results) - - -socket2observer = {} - - -@router.route('/submit/refinement', methods=['POST']) -def submit_refinement(): - file = request.files.get('file') - if file is None or file.filename is None or file.filename.split('.')[-1] not in ALLOWED_EXT: - return jsonify({'error': 'Only JSON files are allowed'}), 400 - data = file.read().decode() - try: - validated = validate_json_format_for_code_refinement(data) + validated = validate_json(data) except InvalidJsonFormatError as e: return jsonify({'error': 'Invalid JSON format', 'message': str(e)}), 400 @@ -89,7 +62,7 @@ def submit_refinement(): socket_emit = functools.partial(socketio.emit, room=sid) process_id = str(uuid.uuid4()) - subject = Subject(process_id, evaluate_refinement) + subject = Subject(process_id, type_, evaluate_submission) request2status[process_id] = subject if sid: @@ -110,6 +83,18 @@ def submit_refinement(): ) +@router.route('/submit/', methods=['POST']) +def submit_comments(task): + if task == "comment": + validator = validate_json_format_for_comment_gen + evaluator = evaluate_comments + else: + validator = validate_json_format_for_code_refinement + evaluator = evaluate_refinement + + return handler(task, validator, evaluator) + + @router.route('/status/') def status(id): if id not in request2status: diff --git a/src/utils/observer.py b/src/utils/observer.py index 6529701..bfa4a02 100644 --- a/src/utils/observer.py +++ b/src/utils/observer.py @@ -56,7 +56,7 @@ class Subject: def notifyComplete(self, results: dict): self.status = Status.COMPLETE for observer in self.observers: - observer.updateComplete(results) + observer.updateComplete({"type": self.type, "results": results}) self.results = results # TODO: maybe save results to disk here? diff --git a/src/utils/process_data.py b/src/utils/process_data.py index bf55cfb..452d4d7 100644 --- a/src/utils/process_data.py +++ b/src/utils/process_data.py @@ -12,7 +12,11 @@ REFERENCE_MAP = Dataset.from_json( ARCHIVES_ROOT = str(get_project_path('../data/archives')) -def evaluate_comments(answers: dict[str, str], percent_cb): +def evaluate_comments( + answers: dict[str, str], + percent_cb: Callable[[float], None] = lambda _: None, + complete_cb: Callable[[dict], None] = lambda _: None, +): total = len(answers) results = {} for i, (id_, gen) in enumerate(answers.items(), 1): @@ -33,6 +37,8 @@ def evaluate_comments(answers: dict[str, str], percent_cb): 'proposed_comment': gen, } percent_cb(int(i / total * 100)) + + complete_cb(results) return results