diff --git a/src/routes/answers.py b/src/routes/answers.py index cc28167..6a1c20c 100644 --- a/src/routes/answers.py +++ b/src/routes/answers.py @@ -68,8 +68,7 @@ def handler(type_: str, validate_json: Callable, evaluate_submission: Callable): if sid: socket_emit('successful-upload') socket_emit('started-processing') - obs = SocketObserver(socket_emit) - socket2observer[sid] = obs + obs = SocketObserver(sid, socket_emit) subject.registerObserver(obs) subject.launch_task(validated) @@ -102,7 +101,7 @@ def status(id): subject = request2status[id] if subject.status == Status.COMPLETE: - return jsonify({"status": "complete", "results": subject.results}) + return jsonify({"status": "complete", "type": subject.type, "results": subject.results}) elif subject.status == Status.PROCESSING: socketio = current_app.extensions['socketio'] sid = request.headers.get('X-Socket-Id') @@ -110,7 +109,7 @@ def status(id): request2status[id] = subject if sid: - if sid in socket2observer: + if sid in SocketObserver.socket2obs: return ( jsonify( { @@ -121,8 +120,7 @@ def status(id): 400, ) - obs = SocketObserver(socket_emit) - socket2observer[sid] = obs + obs = SocketObserver(sid, socket_emit) obs.updatePercentage(subject.percent) subject.registerObserver(obs) return jsonify({"status": "processing", "percent": subject.percent}) diff --git a/src/utils/observer.py b/src/utils/observer.py index bfa4a02..06ad8bb 100644 --- a/src/utils/observer.py +++ b/src/utils/observer.py @@ -21,21 +21,27 @@ class Observer(ABC): class SocketObserver(Observer): - def __init__(self, socket_emit: Callable[[str, Any], None]) -> None: + socket2obs: dict[str, "SocketObserver"] = {} + + def __init__(self, sid: str, socket_emit: Callable[[str, Any], None]) -> None: super().__init__() + self.sid = sid self.socket_emit = socket_emit + SocketObserver.socket2obs[self.sid] = self def updatePercentage(self, percentage: float): self.socket_emit("progress", {'percent': percentage}) def updateComplete(self, results: dict): self.socket_emit("complete", results) + SocketObserver.socket2obs.pop(self.sid) class Subject: # TODO: maybe have a process or thread pool here to implement the queue - def __init__(self, id: str, task: Callable) -> None: + def __init__(self, id: str, type_: str, task: Callable) -> None: self.id = id + self.type = type_ self.observers: Set[Observer] = set() self.status: Status = Status.CREATED self.results: Optional[dict] = None