diff --git a/src/routes/answers.py b/src/routes/answers.py index 6a1c20c..d1f5458 100644 --- a/src/routes/answers.py +++ b/src/routes/answers.py @@ -7,6 +7,8 @@ from utils.observer import SocketObserver, Status, Subject, request2status import functools import json, uuid +from utils.queue_manager import QueueManager + router = Blueprint('answers', __name__, url_prefix='/answers') ALLOWED_EXT = {'json'} @@ -47,6 +49,9 @@ def validate_json_format_for_code_refinement(data: str) -> dict[str, dict[str, s raise InvalidJsonFormatError() +QUEUE_MANAGER = QueueManager(1) + + def handler(type_: str, validate_json: Callable, evaluate_submission: Callable): file = request.files.get('file') if file is None or file.filename is None or file.filename.split('.')[-1] not in ALLOWED_EXT: @@ -71,7 +76,7 @@ def handler(type_: str, validate_json: Callable, evaluate_submission: Callable): obs = SocketObserver(sid, socket_emit) subject.registerObserver(obs) - subject.launch_task(validated) + QUEUE_MANAGER.submit(subject, validated) url = url_for(f".status", id=process_id, _external=True) return jsonify( { @@ -124,6 +129,8 @@ def status(id): obs.updatePercentage(subject.percent) subject.registerObserver(obs) return jsonify({"status": "processing", "percent": subject.percent}) + elif subject.status == Status.WAITING: + return jsonify({"status": "waiting", "queue_position": QUEUE_MANAGER.get_position(id)}) elif subject.status == Status.CREATED: return jsonify({"status": "created"}) diff --git a/src/utils/observer.py b/src/utils/observer.py index 8d393f3..4d7e7e2 100644 --- a/src/utils/observer.py +++ b/src/utils/observer.py @@ -6,6 +6,7 @@ from typing import Callable, Optional, Set, Any class Status(Enum): CREATED = "created" + WAITING = "waiting" PROCESSING = "processing" COMPLETE = "complete" @@ -65,19 +66,5 @@ class Subject: self.results = results # TODO: maybe save results to disk here? - def launch_task(self, *args, **kwargs): - self.status = Status.PROCESSING - t = Thread( - target=self.task, - args=args, - kwargs={ - **kwargs, - "percent_cb": self.notifyPercentage, - "complete_cb": self.notifyComplete, - }, - daemon=True, - ) - t.start() - request2status: dict[str, Subject] = {} diff --git a/src/utils/queue_manager.py b/src/utils/queue_manager.py new file mode 100644 index 0000000..63852ee --- /dev/null +++ b/src/utils/queue_manager.py @@ -0,0 +1,46 @@ +from concurrent.futures import ThreadPoolExecutor +from collections import deque +from utils.observer import Subject, Status + + +class QueueManager: + """ + Manages a queue of Subjects, handling status transitions and allowing position queries: + CREATED -> WAITING -> PROCESSING -> COMPLETE + """ + + def __init__(self, max_workers: int = 5) -> None: + self.executor = ThreadPoolExecutor(max_workers=max_workers) + self.wait_queue: deque[str] = deque() + + def submit(self, subject: Subject, *args, **kwargs) -> None: + subject.status = Status.WAITING + # Add to waiting queue + self.wait_queue.append(subject.id) + # Schedule the task on the executor + self.executor.submit(self._run, subject, *args, **kwargs) + + def get_position(self, subject_id: str) -> int: + """ + Returns 1-based position in waiting queue, or 0 if not waiting. + """ + try: + # index returns 0-based, so +1 + return self.wait_queue.index(subject_id) + 1 + except ValueError: + return 0 + + def _run(self, subject: Subject, *args, **kwargs) -> None: + # Remove from waiting queue as it's now processing + try: + self.wait_queue.remove(subject.id) + except ValueError: + pass + subject.status = Status.PROCESSING + # Execute the user-defined task synchronously in this worker thread + subject.task( + *args, + percent_cb=subject.notifyPercentage, + complete_cb=subject.notifyComplete, + **kwargs, + )