| """ |
| AathraOS Junction Signal API β OPTIMISED BUILD |
| ================================================= |
| Key changes vs v1: |
| β’ Inference at 320Γ180 (~4Γ faster than 640Γ360) |
| β’ imgsz=320 passed directly to YOLO (no CPU upscale overhead) |
| β’ Producer-Consumer frame queue per lane: reader thread β inference thread |
| β cv2.VideoCapture reads at full speed; analyzer drains queue independently |
| β’ Frames piggybacked on WebSocket broadcast (no separate REST poll needed) |
| β’ asyncio.get_running_loop() + run_coroutine_threadsafe (safe on Py3.10+) |
| β’ base64 imported once at module level |
| β’ JPEG quality lowered to 55 for wire speed (still plenty for HUD display) |
| β’ Model loaded with half=False, device='cpu' explicit β avoids silent fallback |
| β’ Lock-free frame store: only a Python assignment (atomic in CPython) |
| """ |
|
|
| import os |
| import cv2 |
| import time |
| import base64 |
| import asyncio |
| import threading |
| from collections import deque |
| from typing import Optional, Dict, List |
| from datetime import datetime, timezone |
|
|
| os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics" |
| os.environ["YOLO_VERBOSE"] = "False" |
| os.environ["YOLO_UPDATE_CHECK"] = "False" |
|
|
| from fastapi import FastAPI, UploadFile, File, WebSocket, WebSocketDisconnect |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse |
| import uvicorn |
| import numpy as np |
| from ultralytics import YOLO |
| import shutil |
|
|
| |
| INFER_W, INFER_H = 320, 192 |
| JPEG_QUALITY = 55 |
| TARGET_INFER_FPS = 6 |
| FRAME_QUEUE_MAX = 2 |
| CYCLE_INTERVAL = 5.0 |
| EMERGENCY_HOLD = 30 |
| BASE_TIME = 10 |
| DENSITY_FACTOR = 0.8 |
| MIN_GREEN = 8 |
| MAX_GREEN = 60 |
|
|
| YOLO_CONF = 0.30 |
| YOLO_IOU = 0.45 |
| YOLO_IMGSZ = 320 |
|
|
| PCU_WEIGHTS = {0: 0.0, 1: 0.5, 2: 1.0, 3: 0.5, 5: 3.0, 7: 3.0} |
| EMERGENCY_CLASSES = {"ambulance", "fire truck", "firetruck", "emergency"} |
| LANES = ["north", "south", "east", "west"] |
|
|
| |
| app = FastAPI(title="AathraOS Junction Signal API β Optimised") |
| app.add_middleware( |
| CORSMiddleware, |
| allow_origins=["*"], |
| allow_credentials=True, |
| allow_methods=["*"], |
| allow_headers=["*"], |
| ) |
|
|
| |
| _model: Optional[YOLO] = None |
| _model_lock = threading.Lock() |
|
|
| def get_model() -> YOLO: |
| global _model |
| with _model_lock: |
| if _model is None: |
| path = os.path.join(os.path.dirname(__file__), "yolov8n.pt") |
| _model = YOLO(path) |
| |
| dummy = np.zeros((INFER_H, INFER_W, 3), dtype=np.uint8) |
| _model.predict(dummy, imgsz=YOLO_IMGSZ, verbose=False, |
| conf=YOLO_CONF, iou=YOLO_IOU) |
| print("[YOLO] Model loaded and warmed up.") |
| return _model |
|
|
|
|
| |
| class LaneState: |
| """All mutable fields are written by exactly one thread except frame_b64.""" |
| __slots__ = ( |
| "name", "vehicle_count", "pcu_score", "breakdown", |
| "emergency_detected", "is_processing", "frame_b64", |
| "infer_fps", "display_fps", "frames_captured", "frames_inferred", |
| "raw_queue", "stop_event", "reader_thread", "infer_thread", |
| ) |
|
|
| def __init__(self, name: str): |
| self.name = name |
| self.vehicle_count = 0 |
| self.pcu_score = 0.0 |
| self.breakdown: Dict[str, int] = {} |
| self.emergency_detected= False |
| self.is_processing = False |
| self.frame_b64: Optional[str] = None |
| self.infer_fps = 0.0 |
| self.display_fps = 0.0 |
| self.frames_captured = 0 |
| self.frames_inferred = 0 |
| self.raw_queue: deque = deque(maxlen=FRAME_QUEUE_MAX) |
| self.stop_event = threading.Event() |
| self.reader_thread: Optional[threading.Thread] = None |
| self.infer_thread: Optional[threading.Thread] = None |
|
|
| lane_states: Dict[str, LaneState] = {ln: LaneState(ln) for ln in LANES} |
|
|
|
|
| |
| class SignalState: |
| __slots__ = ("signals", "active_lane", "green_duration", "emergency_mode", |
| "emergency_lane", "emergency_until", "cycle_count", "last_update") |
|
|
| def __init__(self): |
| self.signals: Dict[str, str] = {l: "RED" for l in LANES} |
| self.active_lane: Optional[str] = None |
| self.green_duration = 0.0 |
| self.emergency_mode = False |
| self.emergency_lane: Optional[str] = None |
| self.emergency_until= 0.0 |
| self.cycle_count = 0 |
| self.last_update = datetime.now(timezone.utc).isoformat() |
|
|
| signal_state = SignalState() |
| signal_lock = threading.Lock() |
|
|
| |
| ws_clients: List[WebSocket] = [] |
| ws_lock = threading.Lock() |
| _event_loop: Optional[asyncio.AbstractEventLoop] = None |
|
|
|
|
| |
| _ENCODE_PARAMS = [cv2.IMWRITE_JPEG_QUALITY, JPEG_QUALITY] |
|
|
| def encode_frame(frame: np.ndarray) -> str: |
| """Fast JPEG β base64 string, reuses encode params list.""" |
| _, buf = cv2.imencode(".jpg", frame, _ENCODE_PARAMS) |
| return base64.b64encode(buf).decode("ascii") |
|
|
|
|
| def analyze_frame(frame: np.ndarray, m: YOLO): |
| """ |
| Run YOLO inference on a pre-resized frame. |
| Returns (pcu, vcount, breakdown, emergency, annotated_frame). |
| """ |
| results = m.predict( |
| frame, |
| imgsz=YOLO_IMGSZ, |
| conf=YOLO_CONF, |
| iou=YOLO_IOU, |
| verbose=False, |
| stream=False, |
| ) |
|
|
| pcu, vcount = 0.0, 0 |
| breakdown: Dict[str, int] = {} |
| emergency = False |
| annotated = frame.copy() |
|
|
| if results and results[0].boxes is not None: |
| boxes = results[0].boxes |
| cls_arr = boxes.cls.cpu().numpy().astype(int) |
| conf_arr = boxes.conf.cpu().numpy() |
| xyxy_arr = boxes.xyxy.cpu().numpy().astype(int) |
|
|
| for i, (cls_id, conf, xyxy) in enumerate(zip(cls_arr, conf_arr, xyxy_arr)): |
| cls_name = m.names.get(int(cls_id), "").lower() |
|
|
| if any(e in cls_name for e in EMERGENCY_CLASSES): |
| emergency = True |
|
|
| weight = PCU_WEIGHTS.get(int(cls_id), 0.0) |
| pcu += weight |
| if int(cls_id) != 0: |
| vcount += 1 |
| label = cls_name or str(cls_id) |
| breakdown[label] = breakdown.get(label, 0) + 1 |
|
|
| x1, y1, x2, y2 = xyxy |
| color = (30, 30, 255) if emergency else (20, 220, 100) |
| cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 1) |
| cv2.putText( |
| annotated, |
| f"{cls_name} {conf:.2f}", |
| (x1, max(y1 - 4, 0)), |
| cv2.FONT_HERSHEY_SIMPLEX, 0.38, color, 1, |
| ) |
|
|
| return pcu, vcount, breakdown, emergency, annotated |
|
|
|
|
| |
| def _reader_loop(state: LaneState, video_path: str): |
| """ |
| Reads frames from video as fast as possible and pushes into the deque. |
| The deque has maxlen=FRAME_QUEUE_MAX so old frames are auto-dropped, |
| guaranteeing near-real-time content for the inference thread. |
| """ |
| cap = cv2.VideoCapture(video_path) |
| if not cap.isOpened(): |
| print(f"[{state.name}] Reader: cannot open {video_path}") |
| return |
|
|
| |
| cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) |
|
|
| t0 = time.perf_counter() |
| captured = 0 |
|
|
| while not state.stop_event.is_set(): |
| ret, frame = cap.read() |
| if not ret: |
| cap.set(cv2.CAP_PROP_POS_FRAMES, 0) |
| continue |
|
|
| |
| small = cv2.resize(frame, (INFER_W, INFER_H), interpolation=cv2.INTER_LINEAR) |
| state.raw_queue.append(small) |
|
|
| captured += 1 |
| state.frames_captured = captured |
| elapsed = time.perf_counter() - t0 |
| state.display_fps = round(captured / elapsed, 1) if elapsed > 0 else 0 |
|
|
| cap.release() |
| print(f"[{state.name}] Reader stopped.") |
|
|
|
|
| |
| def _infer_loop(state: LaneState): |
| """ |
| Drains frames from the queue, runs YOLO, updates state. |
| Sleeps to cap inference rate at TARGET_INFER_FPS, avoiding runaway CPU. |
| """ |
| m = get_model() |
| min_interval = 1.0 / TARGET_INFER_FPS |
| t0 = time.perf_counter() |
| inferred = 0 |
|
|
| state.is_processing = True |
| try: |
| while not state.stop_event.is_set(): |
| t_start = time.perf_counter() |
|
|
| |
| if not state.raw_queue: |
| time.sleep(0.01) |
| continue |
|
|
| frame = state.raw_queue[-1] |
|
|
| pcu, vcount, breakdown, emergency, annotated = analyze_frame(frame, m) |
|
|
| |
| state.pcu_score = pcu |
| state.vehicle_count = vcount |
| state.breakdown = breakdown |
| state.emergency_detected= emergency |
| state.frame_b64 = encode_frame(annotated) |
|
|
| inferred += 1 |
| state.frames_inferred = inferred |
| elapsed = time.perf_counter() - t0 |
| state.infer_fps = round(inferred / elapsed, 1) if elapsed > 0 else 0 |
|
|
| |
| spent = time.perf_counter() - t_start |
| sleep_for = min_interval - spent |
| if sleep_for > 0: |
| time.sleep(sleep_for) |
|
|
| finally: |
| state.is_processing = False |
| state.frame_b64 = None |
| print(f"[{state.name}] Infer stopped.") |
|
|
|
|
| |
| def signal_decision_loop(loop: asyncio.AbstractEventLoop): |
| while True: |
| time.sleep(CYCLE_INTERVAL) |
| now = time.time() |
|
|
| with signal_lock: |
| |
| em_lane = next( |
| (ln for ln, ls in lane_states.items() if ls.emergency_detected), |
| None, |
| ) |
| if em_lane: |
| signal_state.emergency_mode = True |
| signal_state.emergency_lane = em_lane |
| signal_state.emergency_until = now + EMERGENCY_HOLD |
| elif now < signal_state.emergency_until: |
| em_lane = signal_state.emergency_lane |
| else: |
| signal_state.emergency_mode = False |
| signal_state.emergency_lane = None |
|
|
| if em_lane: |
| for ln in LANES: |
| signal_state.signals[ln] = "GREEN" if ln == em_lane else "RED" |
| signal_state.active_lane = em_lane |
| signal_state.green_duration = float(EMERGENCY_HOLD) |
| else: |
| scores = {ln: lane_states[ln].pcu_score for ln in LANES} |
| best = max(scores, key=scores.get) |
| gt = min(MAX_GREEN, max(MIN_GREEN, BASE_TIME + scores[best] * DENSITY_FACTOR)) |
| for ln in LANES: |
| signal_state.signals[ln] = "GREEN" if ln == best else "RED" |
| signal_state.active_lane = best |
| signal_state.green_duration = round(gt, 1) |
|
|
| signal_state.cycle_count += 1 |
| signal_state.last_update = datetime.now(timezone.utc).isoformat() |
|
|
| |
| payload = _build_payload(include_frames=True) |
| asyncio.run_coroutine_threadsafe(_broadcast(payload), loop) |
|
|
|
|
| |
| async def _broadcast(payload: dict): |
| with ws_lock: |
| clients = list(ws_clients) |
| dead = [] |
| for ws in clients: |
| try: |
| await ws.send_json(payload) |
| except Exception: |
| dead.append(ws) |
| if dead: |
| with ws_lock: |
| for ws in dead: |
| if ws in ws_clients: |
| ws_clients.remove(ws) |
|
|
|
|
| |
| def _build_payload(include_frames: bool = False) -> dict: |
| with signal_lock: |
| sigs = dict(signal_state.signals) |
| active = signal_state.active_lane |
| em = signal_state.emergency_mode |
| em_lane = signal_state.emergency_lane |
| green_dur = signal_state.green_duration |
| cycle = signal_state.cycle_count |
| last_upd = signal_state.last_update |
|
|
| lanes_data = {} |
| for ln, ls in lane_states.items(): |
| entry = { |
| "vehicle_count": ls.vehicle_count, |
| "pcu_score": round(ls.pcu_score, 2), |
| "breakdown": dict(ls.breakdown), |
| "emergency_detected": ls.emergency_detected, |
| "is_processing": ls.is_processing, |
| "fps": ls.infer_fps, |
| "display_fps": ls.display_fps, |
| } |
| if include_frames: |
| entry["frame_b64"] = ls.frame_b64 |
| lanes_data[ln] = entry |
|
|
| return { |
| "signals": sigs, |
| "active_lane": active, |
| "emergency_mode": em, |
| "emergency_lane": em_lane, |
| "green_duration": green_dur, |
| "cycle_count": cycle, |
| "last_update": last_upd, |
| "lanes": lanes_data, |
| } |
|
|
|
|
| |
| @app.on_event("startup") |
| async def startup_event(): |
| global _event_loop |
| _event_loop = asyncio.get_running_loop() |
|
|
| |
| def _preload(): |
| get_model() |
| print("[startup] Model ready.") |
| threading.Thread(target=_preload, daemon=True).start() |
|
|
| |
| threading.Thread( |
| target=signal_decision_loop, args=(_event_loop,), daemon=True |
| ).start() |
|
|
| print("AathraOS Junction Signal Engine (optimised) started.") |
|
|
|
|
| |
| @app.get("/") |
| async def root(): |
| return {"status": "ok", "service": "AathraOS Junction Signal API (optimised)"} |
|
|
|
|
| @app.post("/junction/upload/{lane}") |
| async def upload_lane_feed(lane: str, file: UploadFile = File(...)): |
| if lane not in LANES: |
| return JSONResponse(400, {"error": f"Lane must be one of {LANES}"}) |
|
|
| os.makedirs("data/junction", exist_ok=True) |
| save_path = f"data/junction/{lane}_{file.filename}" |
|
|
| with open(save_path, "wb") as f: |
| shutil.copyfileobj(file.file, f) |
|
|
| ls = lane_states[lane] |
|
|
| |
| ls.stop_event.set() |
| for t in (ls.reader_thread, ls.infer_thread): |
| if t and t.is_alive(): |
| t.join(timeout=3.0) |
|
|
| ls.stop_event.clear() |
| ls.raw_queue.clear() |
| ls.frames_captured = 0 |
| ls.frames_inferred = 0 |
|
|
| |
| rt = threading.Thread(target=_reader_loop, args=(ls, save_path), daemon=True, name=f"reader-{lane}") |
| it = threading.Thread(target=_infer_loop, args=(ls,), daemon=True, name=f"infer-{lane}") |
| ls.reader_thread = rt |
| ls.infer_thread = it |
| rt.start() |
| it.start() |
|
|
| return {"message": f"Lane {lane} started β reader + infer threads active for {file.filename}"} |
|
|
|
|
| @app.post("/junction/stop/{lane}") |
| async def stop_lane(lane: str): |
| if lane not in LANES: |
| return JSONResponse(400, {"error": f"Lane must be one of {LANES}"}) |
| lane_states[lane].stop_event.set() |
| return {"message": f"Lane {lane} stop signal sent."} |
|
|
|
|
| @app.get("/junction/status") |
| async def get_status(): |
| return _build_payload(include_frames=False) |
|
|
|
|
| @app.get("/junction/frame/{lane}") |
| async def get_frame(lane: str): |
| if lane not in LANES: |
| return JSONResponse(400, {"error": "Unknown lane"}) |
| return {"lane": lane, "frame_b64": lane_states[lane].frame_b64} |
|
|
|
|
| @app.get("/junction/frames") |
| async def get_all_frames(): |
| """Lightweight frame-only endpoint for REST fallback.""" |
| return {ln: lane_states[ln].frame_b64 for ln in LANES} |
|
|
|
|
| @app.websocket("/ws/junction") |
| async def websocket_endpoint(ws: WebSocket): |
| await ws.accept() |
| with ws_lock: |
| ws_clients.append(ws) |
|
|
| |
| await ws.send_json(_build_payload(include_frames=True)) |
|
|
| try: |
| while True: |
| |
| await asyncio.sleep(1.0 / TARGET_INFER_FPS) |
| payload = _build_payload(include_frames=True) |
| await ws.send_json(payload) |
| except (WebSocketDisconnect, Exception): |
| pass |
| finally: |
| with ws_lock: |
| if ws in ws_clients: |
| ws_clients.remove(ws) |
|
|
|
|
| if __name__ == "__main__": |
| port = int(os.environ.get("PORT", 8001)) |
| uvicorn.run("junction_api:app", host="0.0.0.0", port=port, reload=False, workers=1) |
|
|
|
|