From 49b2606a53512fd52c26a4a9ec0e5edb732cc74a Mon Sep 17 00:00:00 2001 From: Karma Riuk Date: Sun, 18 May 2025 17:43:24 +0200 Subject: [PATCH] created the observers and handling them --- public/js/index.js | 98 ++++++++++++++++++++++----------------- src/routes/answers.py | 69 +++++++++++++++++++++------ src/utils/observer.py | 70 +++++++++++++++++++++++++++- src/utils/process_data.py | 8 +++- 4 files changed, 188 insertions(+), 57 deletions(-) diff --git a/public/js/index.js b/public/js/index.js index b3e144d..ebc82b8 100644 --- a/public/js/index.js +++ b/public/js/index.js @@ -27,6 +27,50 @@ document.getElementById("download-dataset").onclick = () => { window.location = url; }; +function populateCommentTable(results) { + commentResultsContainer.classList.remove("hidden"); + + const tbody = commentResultsContainer.querySelector("table tbody"); + tbody.innerHTML = ""; + + Object.entries(results).forEach(([id, info]) => { + const row = tbody.insertRow(); // create a new row + const idCell = row.insertCell(); // cell 1: id + const commentCell = row.insertCell(); + const scoreCell = row.insertCell(); + const span = document.createElement("span"); + + idCell.textContent = id; + span.className = "comment-cell"; + span.textContent = info["proposed_comment"]; + commentCell.appendChild(span); + scoreCell.textContent = info["max_bleu_score"].toFixed(2); + }); +} + +function populateRefinementTable(results) { + refinementResultsContainer.classList.remove("hidden"); + + const tbody = refinementResultsContainer.querySelector("table tbody"); + tbody.innerHTML = ""; + + Object.entries(results).forEach(([id, info]) => { + const row = tbody.insertRow(); // create a new row + const idCell = row.insertCell(); // cell 1: id + const compiledCell = row.insertCell(); + const testedCell = row.insertCell(); + + idCell.textContent = id; + compiledCell.textContent = info["compilation"] || false ? tick : cross; + testedCell.textContent = info["test"] || false ? tick : cross; + }); +} + +function handleRefinementAnswer(json) { + uploadStatusEl.style.color = "green"; + uploadStatusEl.textContent = json["id"]; +} + // Upload logic document.getElementById("upload-btn").onclick = async () => { uploadStatusEl.classList.add("hidden"); @@ -58,44 +102,10 @@ document.getElementById("upload-btn").onclick = async () => { return; } - results = json; - progressContainer.classList.add("hidden"); - commentResultsContainer.classList.add("hidden"); refinementResultsContainer.classList.add("hidden"); - const resultsContainer = - type === "comment" - ? commentResultsContainer - : refinementResultsContainer; - - resultsContainer.classList.remove("hidden"); - - const tbody = resultsContainer.querySelector("table tbody"); - tbody.innerHTML = ""; - - Object.entries(results).forEach(([id, info]) => { - const row = tbody.insertRow(); // create a new row - const idCell = row.insertCell(); // cell 1: id - idCell.textContent = id; - - if (type == "comment") { - const commentCell = row.insertCell(); - const scoreCell = row.insertCell(); - - const span = document.createElement("span"); - span.className = "comment-cell"; - span.textContent = info["proposed_comment"]; - commentCell.appendChild(span); - scoreCell.textContent = info["max_bleu_score"].toFixed(2); - } else { - const compiledCell = row.insertCell(); - const testedCell = row.insertCell(); - - compiledCell.textContent = - info["compilation"] || false ? tick : cross; - testedCell.textContent = info["test"] || false ? tick : cross; - } - }); + if (type === "comment") populateCommentTable(json); + else handleRefinementAnswer(json); }; [...document.getElementsByClassName("download-results")].forEach((e) => { @@ -126,11 +136,14 @@ socket.on("progress", (data) => { } }); -socket.on("started-processing", (data) => { +socket.on("started-processing", () => { setProgress(0); - uploadStatusEl.classList.remove("hidden"); - uploadStatusEl.style.color = "green"; - uploadStatusEl.textContent = data["id"]; +}); + +socket.on("complete", (results) => { + progressContainer.classList.add("hidden"); + refinementResultsContainer.classList.remove("hidden"); + populateRefinementTable(results); }); socket.on("successful-upload", () => { @@ -194,11 +207,12 @@ document.getElementById("request-status").onclick = async () => { if (!res.ok) { statusStatusEl.classList.remove("hidden"); statusStatusEl.style.color = "red"; - statusStatusEl.textContent = - json.error + (json.message ? ": " + json.message : ""); + statusStatusEl.textContent = json.message ? json.message : json.error; return; } statusStatusEl.classList.remove("hidden"); statusStatusEl.style.color = "green"; statusStatusEl.textContent = json["status"]; + + if (json.status == "complete") populateRefinementTable(json.results); }; diff --git a/src/routes/answers.py b/src/routes/answers.py index 01ebfb6..6145e8f 100644 --- a/src/routes/answers.py +++ b/src/routes/answers.py @@ -1,8 +1,10 @@ # routes/answers.py -from flask import Blueprint, request, jsonify, current_app +from threading import Thread +from flask import Blueprint, request, jsonify, current_app, url_for from utils.errors import InvalidJsonFormatError from utils.process_data import evaluate_comments, evaluate_refinement -from utils.observer import request2status +from utils.observer import SocketObserver, Status, Subject, request2status +import functools import json, uuid router = Blueprint('answers', __name__, url_prefix='/answers') @@ -68,6 +70,9 @@ def submit_comments(): return jsonify(results) +socket2observer = {} + + @router.route('/submit/refinement', methods=['POST']) def submit_refinement(): file = request.files.get('file') @@ -79,22 +84,60 @@ def submit_refinement(): except InvalidJsonFormatError as e: return jsonify({'error': 'Invalid JSON format', 'message': str(e)}), 400 - process_id = str(uuid.uuid4()) - request2status[process_id] = "processing" - socketio = current_app.extensions['socketio'] sid = request.headers.get('X-Socket-Id') + socket_emit = functools.partial(socketio.emit, room=sid) + + process_id = str(uuid.uuid4()) + subject = Subject(process_id, evaluate_refinement) + request2status[process_id] = subject + if sid: - socketio.emit('successful-upload', room=sid) - socketio.emit('started-processing', {"id": process_id}, room=sid) + socket_emit('successful-upload') + socket_emit('started-processing') + obs = SocketObserver(socket_emit) + socket2observer[sid] = obs + subject.registerObserver(obs) - results = evaluate_refinement( - validated, lambda p: socketio.emit('progress', {'percent': p}, room=sid) + t = Thread(target=subject.launch_task, args=(validated,), daemon=True) + t.start() + url = url_for(f".status", id=process_id, _external=True) + return jsonify( + { + "id": process_id, + "status_url": url, + "help_msg": "Check the status of this process at /answers/status/. Once the evaluation is complete, a call to this URL will return the results.", + } ) - return jsonify(results) - @router.route('/status/') -def request_status(id): - return jsonify({"status": request2status.get(id, "doens't exist")}) +def status(id): + if id not in request2status: + raise ValueError(f"Id {id} doesn't exist") + + subject = request2status[id] + if subject.status == Status.COMPLETE: + return jsonify({"status": "complete", "results": subject.results}) + elif subject.status == Status.PROCESSING: + socketio = current_app.extensions['socketio'] + sid = request.headers.get('X-Socket-Id') + socket_emit = functools.partial(socketio.emit, room=sid) + + request2status[id] = subject + if sid: + if sid in socket2observer: + raise AttributeError( + "You are already seeing the real-time progress of that request, please don't spam" + ) + + obs = SocketObserver(socket_emit) + socket2observer[sid] = obs + obs.updatePercentage(subject.percent) + subject.registerObserver(obs) + # if no socket, return current status + return jsonify({"status": "processing", "percent": subject.percent}) + elif subject.status == Status.CREATED: + return jsonify({"status": "created"}) + + raise Exception("This code should be unreachable") diff --git a/src/utils/observer.py b/src/utils/observer.py index 2c2ea7c..8d88c27 100644 --- a/src/utils/observer.py +++ b/src/utils/observer.py @@ -1 +1,69 @@ -request2status = {} +from abc import ABC, abstractmethod +from enum import Enum +from typing import Callable, Optional, Set, Any + + +class Status(Enum): + CREATED = "created" + PROCESSING = "processing" + COMPLETE = "complete" + + +class Observer(ABC): + @abstractmethod + def updatePercentage(self, percentage: float): + ... + + @abstractmethod + def updateComplete(self, results: dict): + ... + + +class SocketObserver(Observer): + def __init__(self, socket_emit: Callable[[str, Any], None]) -> None: + super().__init__() + self.socket_emit = socket_emit + + def updatePercentage(self, percentage: float): + self.socket_emit("progress", {'percent': percentage}) + + def updateComplete(self, results: dict): + self.socket_emit("complete", results) + + +class Subject: + # TODO: maybe have a process or thread pool here to implement the queue + def __init__(self, id: str, task: Callable) -> None: + self.id = id + self.observers: Set[Observer] = set() + self.status: Status = Status.CREATED + self.results: Optional[dict] = None + self.task = task + self.percent: float = -1 + + def registerObserver(self, observer: Observer) -> None: + self.observers.add(observer) + + def unregisterObserver(self, observer: Observer): + self.observers.remove(observer) + + def notifyPercentage(self, percentage: float): + self.percent = percentage + for observer in self.observers: + observer.updatePercentage(percentage) + + def notifyComplete(self, results: dict): + self.status = Status.COMPLETE + for observer in self.observers: + observer.updateComplete(results) + self.results = results + # TODO: maybe save results to disk here? + + def launch_task(self, *args, **kwargs): + self.status = Status.PROCESSING + self.task( + *args, **kwargs, percent_cb=self.notifyPercentage, complete_cb=self.notifyComplete + ) + + +request2status: dict[str, Subject] = {} diff --git a/src/utils/process_data.py b/src/utils/process_data.py index 6ce8a46..1495bb8 100644 --- a/src/utils/process_data.py +++ b/src/utils/process_data.py @@ -1,4 +1,5 @@ import sys +from typing_extensions import Callable from utils.handlers import get_build_handler from .paths import get_project_path from sacrebleu import sentence_bleu as bleu @@ -35,7 +36,11 @@ def evaluate_comments(answers: dict[str, str], percent_cb): return results -def evaluate_refinement(answers: dict[str, dict[str, str]], percent_cb): +def evaluate_refinement( + answers: dict[str, dict[str, str]], + percent_cb: Callable[[float], None] = lambda _: None, + complete_cb: Callable[[dict], None] = lambda _: None, +): n_answers = len(answers) n_steps = 4 # creating build handler + injecting the files in the repo + compilation + testing total_number_of_steps = n_answers * n_steps @@ -92,4 +97,5 @@ def evaluate_refinement(answers: dict[str, dict[str, str]], percent_cb): print(f"[INFO] Done with {id}...") + complete_cb(results) return results