# Klas English Transcription API — FastAPI app deployed on Hugging Face Spaces.
# (The "Spaces: Sleeping" lines here were page-status residue from the Spaces
# UI, not part of the source; replaced with this header.)
import io
import logging
import os
import subprocess
import tempfile
import threading

import numpy as np
import soundfile as sf
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse
from faster_whisper import WhisperModel
# ─────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────
# App
# ─────────────────────────────────────────────
app = FastAPI(title="Klas English Transcription API")
# ─────────────────────────────────────────────
# Config — all overridable via environment
# ─────────────────────────────────────────────
ASR_MODEL_SIZE = os.getenv("ASR_MODEL_SIZE", "small.en")
DEVICE = os.getenv("DEVICE", "cpu")               # HF Spaces is CPU by default
COMPUTE_TYPE = os.getenv("COMPUTE_TYPE", "int8")  # int8 works well on CPU

# ─────────────────────────────────────────────
# Globals
# ─────────────────────────────────────────────
asr_model = None      # WhisperModel instance, set by the loader thread
model_loaded = False  # readiness flag flipped once loading succeeds
# ─────────────────────────────────────────────
# Model loading — background thread so the
# container passes HF's health check quickly
# ─────────────────────────────────────────────
def _load_model():
    """Instantiate the Whisper model and flip the module readiness flag."""
    global asr_model, model_loaded
    try:
        logger.info(f"Loading faster-whisper ({ASR_MODEL_SIZE}) on {DEVICE}/{COMPUTE_TYPE}")
        asr_model = WhisperModel(
            ASR_MODEL_SIZE,
            device=DEVICE,
            compute_type=COMPUTE_TYPE,
        )
    except Exception as e:
        # Deliberately swallowed: the API stays up and keeps reporting
        # "initializing" to health probes if the load fails.
        logger.error(f"Model load failed: {e}", exc_info=True)
    else:
        model_loaded = True
        logger.info("Model ready ✅")
@app.on_event("startup")
async def startup():
    """Kick off model loading in a background daemon thread.

    Loading faster-whisper synchronously could exceed the platform's
    health-check window, so the slow work runs off the event loop and
    readiness is reported via the health endpoint.
    """
    # NOTE(review): decorator lines were lost in the scraped source; this
    # handler is plainly meant to run at startup — registration restored.
    threading.Thread(target=_load_model, daemon=True).start()
# ─────────────────────────────────────────────
# Health
# ─────────────────────────────────────────────
@app.get("/")
def ping():
    """Health probe: report whether the ASR model has finished loading."""
    # NOTE(review): route decorator restored — "/" assumed since HF Spaces
    # health-checks the root path; confirm against the actual client.
    return {"status": "ready" if model_loaded else "initializing"}
| # ───────────────────────────────────────────── | |
| # Audio helper | |
| # ───────────────────────────────────────────── | |
| def _load_audio(raw: bytes) -> np.ndarray: | |
| """Read audio bytes → float32 mono 16 kHz numpy array.""" | |
| try: | |
| arr, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False) | |
| except Exception: | |
| # Fallback: treat as raw float32 PCM at 16 kHz | |
| arr = np.frombuffer(raw, dtype=np.float32) | |
| sr = 16000 | |
| # Stereo → mono | |
| if arr.ndim > 1: | |
| arr = arr.mean(axis=1) | |
| # Resample to 16 kHz if needed | |
| if sr != 16000: | |
| logger.info(f"Resampling {sr} Hz → 16000 Hz") | |
| tmp_in = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) | |
| tmp_out = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) | |
| try: | |
| sf.write(tmp_in.name, arr, sr) | |
| subprocess.run( | |
| ["ffmpeg", "-y", "-i", tmp_in.name, | |
| "-ar", "16000", "-ac", "1", tmp_out.name], | |
| check=True, capture_output=True, | |
| ) | |
| arr, _ = sf.read(tmp_out.name, dtype="float32") | |
| finally: | |
| os.unlink(tmp_in.name) | |
| os.unlink(tmp_out.name) | |
| # Normalize | |
| peak = max(abs(arr.max()), abs(arr.min()), 1e-9) | |
| if peak > 1.0: | |
| arr = arr / peak | |
| return arr | |
# ─────────────────────────────────────────────
# Transcription helper
# ─────────────────────────────────────────────
def _transcribe(audio_arr: np.ndarray) -> str:
    """Run the ASR model on a mono float32 16 kHz array and tidy the text."""
    segments, _ = asr_model.transcribe(
        audio_arr,
        language="en",
        beam_size=5,
        vad_filter=True,
        word_timestamps=False,
    )
    joined = " ".join(seg.text for seg in segments).strip()
    if not joined:
        return joined
    # Fix ALL-CAPS transcriptions (some audio conditions trigger this)
    if joined == joined.upper():
        joined = joined.lower()
    # Ensure the transcript starts with a capital letter.
    if len(joined) > 1:
        return joined[0].upper() + joined[1:]
    return joined.upper()
# ─────────────────────────────────────────────
# POST /transcribe
# Accepts : multipart/form-data { audio: <file> }
# Returns : JSON { transcript, duration_sec, language }
# ─────────────────────────────────────────────
@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(...),
    language: str = Form("en"),  # kept for future multi-lang support
):
    """Transcribe an uploaded audio file to English text.

    Responses:
        503 — model still loading; 400 — empty upload;
        422 — undecodable audio or no speech; 500 — model failure.
    """
    # NOTE(review): decorator lines were lost in the scraped source; the
    # banner above documents this as POST /transcribe — route restored.
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model still loading, try again shortly")
    raw = await audio.read()
    if not raw:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    try:
        arr = _load_audio(raw)
    except Exception as e:
        logger.error(f"Audio load error: {e}", exc_info=True)
        # Chain the cause so logs show the original decode failure.
        raise HTTPException(status_code=422, detail=f"Could not read audio: {e}") from e
    # _load_audio guarantees 16 kHz output, so samples/16000 is seconds.
    duration_sec = round(len(arr) / 16000, 3)
    try:
        transcript = _transcribe(arr)
    except Exception as e:
        logger.error(f"Transcription error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    if not transcript:
        raise HTTPException(status_code=422, detail="No speech detected in audio")
    return JSONResponse({
        "transcript": transcript,
        "duration_sec": duration_sec,
        "language": "en",
    })
if __name__ == "__main__":
    # Local/dev entry point; Hugging Face Spaces expects port 7860.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)