Spaces:
Sleeping
Sleeping
"""
Minimal FastAPI WebSocket server for streaming audio chunks with ACK and a simple per-session ring buffer,
plus a REST endpoint to request slice transcription using a background worker (stub ASR).

Run locally:
    python -m uvicorn ws_server:app --host 0.0.0.0 --port 5001

Endpoints:
    WS   /ws/stream/{session_id}
    GET  /buffer_window/{session_id}
    POST /transcribe_slice
    GET  /job/{job_id}

Protocol (JSON frames):
    -> {"type":"audio.chunk","session_id":"...","seq":int,"t0_ms":int,"t1_ms":int,"mime":"audio/webm;codecs=opus","b64":"..."}
    <- {"type":"audio.ack","session_id":"...","seq":int,"backlog_ms":int}
"""
| from __future__ import annotations | |
| import base64 | |
| import time | |
| from collections import deque, defaultdict | |
| from dataclasses import dataclass | |
| from typing import Deque, Dict, List, Optional | |
| import asyncio | |
| import uuid | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi import Body | |
| from pydantic import BaseModel | |
@dataclass
class AudioChunk:
    """One timestamped piece of audio received over the WebSocket.

    The ``@dataclass`` decorator is required: the stream handler builds
    instances with keyword arguments (``AudioChunk(t0_ms=..., t1_ms=...,
    mime=..., data=...)``), which a plain class with bare annotations would
    reject with ``TypeError``.
    """

    t0_ms: int  # chunk start timestamp in milliseconds, as sent by the client
    t1_ms: int  # chunk end timestamp in milliseconds, as sent by the client
    mime: str  # media type string that accompanied the chunk
    data: bytes  # base64-decoded payload; empty when nothing could be decoded
class SessionBuffer:
    """Time-bounded FIFO of audio chunks for one session, backed by a deque."""

    def __init__(self, retention_ms: int = 10 * 60 * 1000):  # 10 minutes
        self.retention_ms = retention_ms
        self.q: Deque[AudioChunk] = deque()

    def push(self, chunk: AudioChunk):
        """Append ``chunk`` and drop chunks that fall out of the retention window.

        The window is measured backwards from the end time of the chunk that
        was just pushed.
        """
        self.q.append(chunk)
        cutoff_ms = chunk.t1_ms - self.retention_ms
        self._evict(cutoff_ms)

    def _evict(self, threshold_ms: int):
        """Pop chunks from the front while they end before ``threshold_ms``."""
        queue = self.q
        while queue:
            if queue[0].t1_ms >= threshold_ms:
                break
            queue.popleft()

    def backlog_ms(self) -> int:
        """Return the span from the oldest chunk's start to the newest chunk's end."""
        if self.q:
            return self.q[-1].t1_ms - self.q[0].t0_ms
        return 0

    def get_range(self, start_ms: int, end_ms: int) -> List[AudioChunk]:
        """Return buffered chunks that overlap the open interval (start_ms, end_ms)."""
        overlapping = []
        for chunk in self.q:
            # A chunk overlaps when it ends after the start and begins before the end.
            if chunk.t1_ms > start_ms and chunk.t0_ms < end_ms:
                overlapping.append(chunk)
        return overlapping

    def window(self) -> Dict[str, int]:
        """Return the buffered time range as ``head_ms``/``tail_ms`` (zeros when empty)."""
        head_ms = tail_ms = 0
        if self.q:
            head_ms, tail_ms = self.q[0].t0_ms, self.q[-1].t1_ms
        return {"head_ms": head_ms, "tail_ms": tail_ms}
# ASGI application object (the module docstring runs it as ``ws_server:app``).
app = FastAPI(title="SyncMaster WS Server")

# Fully permissive CORS: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)

# In-memory audio buffers keyed by session id; the stream handler creates an
# entry the first time a chunk arrives for a session.
sessions: Dict[str, SessionBuffer] = {}
class ConnectionManager:
    """Keeps the WebSockets attached to each session and fans messages out to them."""

    def __init__(self):
        # session id -> sockets currently attached to that session
        self.clients: Dict[str, List[WebSocket]] = defaultdict(list)

    async def connect(self, session_id: str, websocket: WebSocket):
        """Accept the WebSocket and register it under ``session_id``."""
        await websocket.accept()
        self.clients[session_id].append(websocket)

    def disconnect(self, session_id: str, websocket: WebSocket):
        """Unregister ``websocket``; forget the session once no sockets remain."""
        # .get() avoids creating an empty entry in the defaultdict.
        sockets = self.clients.get(session_id)
        if sockets and websocket in sockets:
            sockets.remove(websocket)
        if not sockets:
            self.clients.pop(session_id, None)

    async def send_json(self, session_id: str, message: dict):
        """Send ``message`` to every socket of the session, best effort."""
        # Iterate over a snapshot because failed sockets are removed mid-loop.
        recipients = tuple(self.clients.get(session_id, ()))
        for socket in recipients:
            try:
                await socket.send_json(message)
            except Exception:
                # drop broken connections
                self.disconnect(session_id, socket)


# Shared manager instance used by the WebSocket handler and the background worker.
manager = ConnectionManager()
@app.websocket("/ws/stream/{session_id}")
async def ws_stream(websocket: WebSocket, session_id: str):
    """Receive JSON frames for one session and answer each of them.

    Handled frame types:
      * ``ping``        -> replies ``pong`` with the server time in ms.
      * ``audio.chunk`` -> buffers the chunk and replies ``audio.ack`` with the
        echoed ``seq`` and the session's current ``backlog_ms``.
      * anything else   -> replies with an ``error`` frame.

    The socket is always removed from the connection manager when the handler
    exits, whether the client disconnected or an unexpected error ended the loop.
    """
    await manager.connect(session_id, websocket)
    try:
        while True:
            msg = await websocket.receive_json()
            # A frame that is valid JSON but not an object has no "type".
            mtype = msg.get("type") if isinstance(msg, dict) else None

            if mtype == "ping":
                await websocket.send_json({"type": "pong", "ts_ms": int(time.time() * 1000)})
                continue

            if mtype == "audio.chunk":
                # trust path param; payload session_id optional
                try:
                    seq = int(msg.get("seq", 0))
                    t0_ms = int(msg.get("t0_ms", 0))
                    t1_ms = int(msg.get("t1_ms", 0))
                except (TypeError, ValueError, OverflowError):
                    # Reject the frame instead of letting a bad field kill the stream.
                    await websocket.send_json(
                        {"type": "error", "message": "invalid audio.chunk: seq, t0_ms and t1_ms must be integers"}
                    )
                    continue
                mime = msg.get("mime", "audio/webm;codecs=opus")
                b64 = msg.get("b64", "")
                try:
                    data = base64.b64decode(b64) if b64 else b""
                except (TypeError, ValueError):
                    # Best effort: an undecodable payload is buffered as empty
                    # bytes so the chunk's timing is still recorded and ACKed.
                    data = b""
                buf = sessions.setdefault(session_id, SessionBuffer())
                buf.push(AudioChunk(t0_ms=t0_ms, t1_ms=t1_ms, mime=mime, data=data))
                await websocket.send_json(
                    {
                        "type": "audio.ack",
                        "session_id": session_id,
                        "seq": seq,
                        "backlog_ms": buf.backlog_ms(),
                    }
                )
                continue

            await websocket.send_json({"type": "error", "message": f"unknown type: {mtype}"})
    except WebSocketDisconnect:
        # Normal client departure; cleanup happens in ``finally``.
        pass
    finally:
        manager.disconnect(session_id, websocket)
# NOTE(review): no route decorator is visible for this handler and the module
# docstring does not list its path; "/health" is assumed from the function
# name — confirm against deployed clients/probes.
@app.get("/health")
async def health():
    """Liveness probe: always reports ``{"status": "ok"}``."""
    return {"status": "ok"}
@app.get("/buffer_window/{session_id}")
async def buffer_window(session_id: str):
    """Report the buffered time range and backlog for ``session_id``.

    Returns ``head_ms``, ``tail_ms`` and ``backlog_ms``; all three are 0 when
    the session has no buffer yet.
    """
    buf = sessions.get(session_id)
    if buf is None:
        return {"head_ms": 0, "tail_ms": 0, "backlog_ms": 0}
    return {**buf.window(), "backlog_ms": buf.backlog_ms()}
class TranscribeRequest(BaseModel):
    """Body accepted by the slice-transcription endpoint."""

    # Target session; when omitted the server falls back to the only existing
    # session, if exactly one exists.
    session_id: Optional[str] = None

    # Caller-chosen identifier echoed back in progress and result messages.
    slice_id: str

    # Explicit slice bounds in milliseconds. If either one is missing, the
    # worker uses the trailing ``offset_ms`` of the buffer instead.
    start_ms: Optional[int] = None
    end_ms: Optional[int] = None

    # A or B; copied into the result as ``quality_tier``.
    requested_tier: str = "A"

    # optional: last N ms if start/end not provided (worker defaults to 30000)
    offset_ms: Optional[int] = None
# In-memory registry of transcription jobs keyed by job id. Each entry holds at
# least "status" and "req"; the worker adds "result" or "error" when it finishes.
jobs: Dict[str, Dict] = dict()
| def _make_stub_result(session_id: str, slice_id: str, start_ms: int, end_ms: int, tier: str) -> Dict: | |
| dur = max(0, end_ms - start_ms) | |
| # Stub transcript | |
| text = f"Stub transcript for session {session_id} from {start_ms} to {end_ms}." | |
| # Create a couple of segments and words deterministically | |
| segs = [ | |
| {"start_ms": start_ms, "end_ms": min(end_ms, start_ms + 700), "text": "Hello", "confidence": 0.91}, | |
| {"start_ms": min(end_ms, start_ms + 700), "end_ms": end_ms, "text": "world", "confidence": 0.88}, | |
| ] | |
| words = [ | |
| {"start_ms": start_ms + 10, "end_ms": start_ms + 120, "word": "lecture", "confidence": 0.88}, | |
| {"start_ms": start_ms + 130, "end_ms": start_ms + 220, "word": "assistant", "confidence": 0.86}, | |
| ] | |
| return { | |
| "slice_id": slice_id, | |
| "session_id": session_id, | |
| "start_ms": start_ms, | |
| "end_ms": end_ms, | |
| "duration_ms": dur, | |
| "transcript": text, | |
| "segments": segs, | |
| "words": words, | |
| "status_text": f"✅ Transcript ready — {dur//1000}s", | |
| "notes": "stub", | |
| "quality_tier": tier, | |
| } | |
async def _worker_run(job_id: str):
    """Process one queued transcription job and publish its progress.

    Looks the job up in ``jobs``, resolves the session and the time range,
    pushes ``transcribe.accepted`` / ``transcribe.progress`` /
    ``transcribe.result`` messages to the session's WebSocket clients, and
    stores the outcome on the job (``status`` plus ``result`` or ``error``).

    This coroutine runs as a fire-and-forget task, so nobody awaits it. Any
    unexpected exception is therefore recorded on the job as an error instead
    of leaving the job stuck in ``"processing"`` forever.
    """
    job = jobs.get(job_id)
    if not job:
        return
    job["status"] = "processing"
    try:
        req: TranscribeRequest = job["req"]
        session_id = req.session_id or _pick_single_session_id()
        if not session_id:
            job["status"] = "error"
            job["error"] = "no session available"
            return
        buf = sessions.get(session_id)
        if not buf:
            job["status"] = "error"
            job["error"] = "session buffer missing"
            return

        # Determine range
        if req.start_ms is None or req.end_ms is None:
            # No explicit bounds: take the trailing ``offset_ms`` of the buffer
            # (30000 ms when the offset is unset or 0), clipped to its head.
            w = buf.window()
            tail = w["tail_ms"]
            off = int(req.offset_ms or 30000)
            start_ms = max(w["head_ms"], tail - off)
            end_ms = tail
        else:
            start_ms = int(req.start_ms)
            end_ms = int(req.end_ms)

        # Simulate progress
        await manager.send_json(session_id, {"type": "transcribe.accepted", "slice_id": req.slice_id, "queue_pos": 0})
        await asyncio.sleep(0.1)
        await manager.send_json(session_id, {"type": "transcribe.progress", "slice_id": req.slice_id, "pct": 30})
        await asyncio.sleep(0.1)
        await manager.send_json(session_id, {"type": "transcribe.progress", "slice_id": req.slice_id, "pct": 70})

        # Build stub result (replace with actual ASR integration)
        result = _make_stub_result(session_id, req.slice_id, start_ms, end_ms, req.requested_tier)
        job["result"] = result
        job["status"] = "done"
        await manager.send_json(session_id, {"type": "transcribe.result", **result})
    except Exception as exc:
        # Boundary of the background task: surface the failure to pollers.
        job["status"] = "error"
        job["error"] = f"{type(exc).__name__}: {exc}"
def _pick_single_session_id() -> Optional[str]:
    """Return the id of the only known session, or ``None`` if there is not exactly one."""
    if len(sessions) != 1:
        return None
    # Exactly one key: unpack it directly from the dict.
    (only_id,) = sessions
    return only_id
@app.post("/transcribe_slice")
async def transcribe_slice(req: TranscribeRequest = Body(...)):
    """Queue a transcription job for a slice and return its id immediately.

    Returns ``{"job_id": <uuid4 string>, "eta_ms": 1500}``. The work itself is
    done by a background task, and the job's state can be read from ``jobs``
    under the returned id.
    """
    # Fill default session if not provided and only one exists
    if not req.session_id:
        sid = _pick_single_session_id()
        if sid:
            req.session_id = sid

    job_id = str(uuid.uuid4())
    job = {"status": "queued", "req": req}
    jobs[job_id] = job

    # The event loop keeps only a weak reference to tasks, so hold a strong one
    # on the job while the worker runs; drop it again once the task finishes.
    task = asyncio.create_task(_worker_run(job_id))
    job["task"] = task
    task.add_done_callback(lambda _finished: job.pop("task", None))

    return {"job_id": job_id, "eta_ms": 1500}
@app.get("/job/{job_id}")
async def get_job(job_id: str):
    """Return the state of a transcription job.

    The response always contains ``status``. It also carries ``result`` when
    the job is ``"done"`` and ``error`` when it is ``"error"``. An unknown id
    yields ``{"status": "not_found"}``.
    """
    job = jobs.get(job_id)
    if not job:
        return {"status": "not_found"}
    status = job.get("status")
    resp = {"status": status}
    if status == "done":
        resp["result"] = job.get("result")
    elif status == "error":
        resp["error"] = job.get("error")
    return resp