# Uploaded by offiongbassey — commit 6ad1b35 "Create app.py" (verified).
import os
import io
import logging
import tempfile
import threading
import subprocess
import numpy as np
import soundfile as sf
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse
from faster_whisper import WhisperModel
# ─────────────────────────────────────────────
# Logging — timestamped records at INFO and above
# ─────────────────────────────────────────────
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Module-scoped logger shared by every helper and endpoint below.
logger = logging.getLogger(name=__name__)
# ─────────────────────────────────────────────
# App — the ASGI application (served by uvicorn in __main__)
# ─────────────────────────────────────────────
app = FastAPI(
    title="Klas English Transcription API",
)
# ─────────────────────────────────────────────
# Config — every value can be overridden via environment variables
# ─────────────────────────────────────────────
# faster-whisper model name passed to WhisperModel.
ASR_MODEL_SIZE = os.environ.get("ASR_MODEL_SIZE", "small.en")
# Inference device; HF Spaces hardware is CPU by default.
DEVICE = os.environ.get("DEVICE", "cpu")
# Compute precision handed to WhisperModel; int8 works well on CPU.
COMPUTE_TYPE = os.environ.get("COMPUTE_TYPE", "int8")
# ─────────────────────────────────────────────
# Globals — written once by the background loader thread
# ─────────────────────────────────────────────
asr_model = None      # WhisperModel instance after a successful load
model_loaded = False  # set to True only once asr_model is ready
# ─────────────────────────────────────────────
# Model loading — runs in a background thread so the
# container passes HF's health check quickly
# ─────────────────────────────────────────────
def _load_model():
    """Build the faster-whisper model and publish it through module globals.

    On success ``asr_model`` holds the model and ``model_loaded`` becomes True.
    On any failure the exception is logged with its traceback and both globals
    keep their previous values. /ping then keeps reporting "initializing"
    because it only looks at ``model_loaded``.
    """
    global asr_model, model_loaded
    try:
        logger.info(f"Loading faster-whisper ({ASR_MODEL_SIZE}) on {DEVICE}/{COMPUTE_TYPE}")
        whisper = WhisperModel(
            ASR_MODEL_SIZE,
            device=DEVICE,
            compute_type=COMPUTE_TYPE,
        )
    except Exception as exc:
        logger.error(f"Model load failed: {exc}", exc_info=True)
        return
    # Publish the model before flipping the flag so readers that see
    # model_loaded == True always find a usable asr_model.
    asr_model = whisper
    model_loaded = True
    logger.info("Model ready ✅")
@app.on_event("startup")
async def startup():
    """Start loading the model without holding up application startup.

    The loader runs in a daemon thread, so this hook returns immediately and
    the thread never prevents interpreter shutdown.
    """
    # NOTE(review): on_event is deprecated in recent FastAPI in favour of
    # lifespan handlers; kept as-is for compatibility.
    loader = threading.Thread(target=_load_model, daemon=True)
    loader.start()
# ─────────────────────────────────────────────
# Health
# ─────────────────────────────────────────────
@app.get("/ping")
def ping():
    """Report whether the ASR model has finished loading.

    Returns ``{"status": "ready"}`` once ``model_loaded`` is True, otherwise
    ``{"status": "initializing"}``.
    """
    if model_loaded:
        state = "ready"
    else:
        state = "initializing"
    return {"status": state}
# ─────────────────────────────────────────────
# Audio helpers
# ─────────────────────────────────────────────
def _resample_with_ffmpeg(arr: np.ndarray, sr: int, target_sr: int = 16000) -> np.ndarray:
    """Resample mono float32 ``arr`` from ``sr`` Hz to ``target_sr`` Hz using ffmpeg."""
    # A private temp directory cleans up both files (and closes nothing left
    # open) even if ffmpeg or soundfile raises.
    with tempfile.TemporaryDirectory() as tmpdir:
        path_in = os.path.join(tmpdir, "in.wav")
        path_out = os.path.join(tmpdir, "out.wav")
        # Write/read 32-bit float WAV on both sides. soundfile's default WAV
        # subtype is PCM_16, which quantises the signal and cannot represent
        # samples beyond ±1.0 (normalisation only happens after resampling).
        sf.write(path_in, arr, sr, subtype="FLOAT")
        cmd = [
            "ffmpeg", "-y", "-i", path_in,
            "-ar", str(target_sr), "-ac", "1", "-c:a", "pcm_f32le",
            path_out,
        ]
        try:
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError as err:
            # Surface ffmpeg's own diagnostics; CalledProcessError's message omits them.
            logger.error(
                "ffmpeg resampling failed: %s",
                (err.stderr or b"").decode("utf-8", errors="replace").strip(),
            )
            raise
        resampled, _ = sf.read(path_out, dtype="float32", always_2d=False)
    return resampled


def _load_audio(raw: bytes) -> np.ndarray:
    """Decode uploaded audio bytes into a float32 mono 16 kHz numpy array.

    Bytes that soundfile can parse are decoded directly. Anything else is
    treated as raw native-endian float32 PCM at 16 kHz. The signal is then
    downmixed to mono, resampled to 16 kHz with ffmpeg when needed, and
    peak-normalised only if its peak exceeds 1.0.

    Raises:
        ValueError: the decoded audio has no samples, contains NaN/inf (e.g.
            a compressed format misread by the raw-PCM fallback), or the raw
            fallback byte count is not a multiple of 4.
        subprocess.CalledProcessError: ffmpeg failed while resampling.
        FileNotFoundError: ffmpeg is not installed.
    """
    try:
        arr, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)
    except Exception:
        # Deliberate best-effort: anything soundfile rejects is assumed to be
        # raw float32 PCM at 16 kHz.
        arr = np.frombuffer(raw, dtype=np.float32)
        sr = 16000

    if arr.size == 0:
        raise ValueError("Audio contains no samples")
    if not np.all(np.isfinite(arr)):
        raise ValueError("Audio contains non-finite samples (unsupported or corrupt format?)")

    # Multichannel (frames, channels) → mono by averaging the channels.
    if arr.ndim > 1:
        arr = arr.mean(axis=1)

    if sr != 16000:
        logger.info("Resampling %s Hz → 16000 Hz", sr)
        arr = _resample_with_ffmpeg(arr, sr, 16000)

    # Scale down only when the signal would otherwise exceed full scale;
    # the size guard covers a resample that yields zero frames.
    peak = float(np.max(np.abs(arr))) if arr.size else 0.0
    if peak > 1.0:
        arr = arr / peak
    return arr
# ─────────────────────────────────────────────
# Transcription helper
# ─────────────────────────────────────────────
def _transcribe(audio_arr: np.ndarray) -> str:
    """Run the loaded Whisper model on 16 kHz mono audio and return tidy text.

    Each segment's text is stripped and the non-empty ones are joined with a
    single space. Whisper segment texts may start with a space, so joining
    them raw with " " produced double spaces between segments. An all-caps
    result is lower-cased, and the first character is always upper-cased.
    Returns "" when no segment contains text.
    """
    segments, _info = asr_model.transcribe(
        audio_arr,
        language="en",
        beam_size=5,
        vad_filter=True,
        word_timestamps=False,
    )
    pieces = (seg.text.strip() for seg in segments)
    text = " ".join(piece for piece in pieces if piece)
    if not text:
        return text
    # Fix ALL-CAPS transcriptions (some audio conditions trigger this).
    if text.isupper():
        text = text.lower()
    return text[0].upper() + text[1:]
# ─────────────────────────────────────────────
# POST /transcribe
# Accepts : multipart/form-data { audio: <file>, language?: <str> }
# Returns : JSON { transcript, duration_sec, language }
# ─────────────────────────────────────────────
@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(...),
    language: str = Form("en"),  # kept for future multi-lang support; currently unused
):
    """Transcribe an uploaded audio file to English text.

    Audio decoding (which may run ffmpeg) and Whisper inference are
    synchronous and CPU-bound. They are dispatched to FastAPI's thread pool
    because calling them directly inside this ``async`` handler would block
    the event loop, stalling every other request, including /ping.

    HTTP errors:
        503: the model is still loading.
        400: the upload is empty.
        422: the audio could not be decoded, or no speech was detected.
        500: transcription itself failed.
    """
    from fastapi.concurrency import run_in_threadpool

    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model still loading, try again shortly")

    raw = await audio.read()
    if not raw:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    try:
        arr = await run_in_threadpool(_load_audio, raw)
    except Exception as e:
        logger.error(f"Audio load error: {e}", exc_info=True)
        raise HTTPException(status_code=422, detail=f"Could not read audio: {e}") from e

    # _load_audio always returns 16 kHz samples.
    duration_sec = round(len(arr) / 16000, 3)

    try:
        transcript = await run_in_threadpool(_transcribe, arr)
    except Exception as e:
        logger.error(f"Transcription error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not transcript:
        raise HTTPException(status_code=422, detail="No speech detected in audio")

    return JSONResponse({
        "transcript": transcript,
        "duration_sec": duration_sec,
        "language": "en",
    })
if __name__ == "__main__":
    import uvicorn

    # Bind every interface so the container's port is reachable from outside;
    # 7860 is the port this app serves on.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)