Spaces:
No application file
No application file
import asyncio
import hashlib
import logging
import os
import subprocess
import tempfile
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Any

import aiohttp
import jwt
import librosa
import uvicorn
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
# --- 1. CONFIGURATION ---
class GlobalConfig:
    """Central, environment-driven configuration for the ensemble service."""

    # Set these in HF Space Secrets
    HF_TOKEN: str = os.getenv("HF_TOKEN", "")
    API_SECRET: str = os.getenv("API_SECRET_KEY", "default_secret_change_me_in_production")

    # Hugging Face Inference API endpoints and their ensemble weights ("w").
    MODELS: Dict[str, Dict[str, Any]] = {
        "emotion2vec": {"url": "https://api-inference.huggingface.co/models/emotion2vec/emotion2vec_plus_base", "w": 0.50},
        "meralion": {"url": "https://api-inference.huggingface.co/models/MERaLiON/MERaLiON-SER-v1", "w": 0.25},
        "wav2vec2": {"url": "https://api-inference.huggingface.co/models/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", "w": 0.15},
        "hubert": {"url": "https://api-inference.huggingface.co/models/superb/hubert-large-superb-er", "w": 0.07},
        "gigam": {"url": "https://api-inference.huggingface.co/models/salute-developers/GigaAM-emo", "w": 0.03},
    }

    # Standardized internal labels: each key is our canonical emotion,
    # each value lists substrings matched against upstream label names.
    MAPPING: Dict[str, List[str]] = {
        "angry": ["ang", "fear"],  # Merging high-arousal negative
        "happy": ["hap", "joy", "surp"],
        "sad": ["sad"],
        "neutral": ["neu", "calm"],
    }


cfg = GlobalConfig()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("EmotionAPI")
# --- 2. AUTHENTICATION ---
security = HTTPBearer()


def create_access_token(data: dict):
    """Return an HS256-signed JWT carrying *data* plus a 60-minute expiry claim."""
    payload = dict(data)
    payload["exp"] = datetime.now(timezone.utc) + timedelta(minutes=60)
    return jwt.encode(payload, cfg.API_SECRET, algorithm="HS256")
async def verify_jwt(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: validate the bearer JWT and return its decoded payload.

    Raises:
        HTTPException: 401 when the token is malformed, has a bad signature,
            or has expired.
    """
    try:
        return jwt.decode(credentials.credentials, cfg.API_SECRET, algorithms=["HS256"])
    except jwt.PyJWTError:
        # Narrowed from a bare `except Exception`: only JWT validation
        # problems should read as auth failures; real bugs must surface.
        raise HTTPException(status_code=401, detail="Invalid/Expired Token")
# --- 3. CORE LOGIC ---
async def process_audio(file: UploadFile):
    """Persist the upload, transcode it to 16 kHz mono WAV, and return its contents.

    Args:
        file: the incoming multipart upload.

    Returns:
        (audio_bytes, duration): raw WAV bytes and clip length in seconds.

    Raises:
        ValueError: empty upload or ffmpeg transcode failure (subclass of
            Exception, so existing callers catching Exception still work).
    """
    # Keep the original extension so ffmpeg can sniff the container; the
    # original code crashed when filename was None and produced odd suffixes
    # for extension-less names.
    suffix = os.path.splitext(file.filename or "")[1] or ".bin"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_in:
        content = await file.read()
        tmp_in.write(content)
        input_path = tmp_in.name
    output_path = input_path + ".wav"
    try:
        if not content:
            raise ValueError("Uploaded file is empty")
        # Standardize to 16kHz Mono WAV
        proc = subprocess.run(
            ["ffmpeg", "-i", input_path, "-ar", "16000", "-ac", "1", "-y", output_path],
            capture_output=True, text=True
        )
        if proc.returncode != 0:
            raise ValueError(f"FFmpeg error: {proc.stderr}")
        with open(output_path, "rb") as f:
            audio_bytes = f.read()
        duration = librosa.get_duration(path=output_path)
        return audio_bytes, duration
    finally:
        # Always remove both temp files, even when transcoding failed.
        for p in (input_path, output_path):
            if os.path.exists(p):
                os.unlink(p)
async def query_hf(session, name, url, data):
    """Call one HF Inference endpoint; retry while the model is warming up.

    Returns the parsed JSON payload on HTTP 200, or None on any failure so
    the caller can drop this model from the ensemble instead of crashing.
    """
    headers = {"Authorization": f"Bearer {cfg.HF_TOKEN}"}
    for _ in range(3):  # Simple retry if model is loading
        try:
            async with session.post(url, headers=headers, data=data) as resp:
                if resp.status == 200:
                    # Parse JSON only after the status check: the original
                    # parsed first, and a non-JSON error page raised out of
                    # asyncio.gather and 500'd the whole request.
                    return await resp.json()
                if resp.status == 503:  # Model loading
                    await asyncio.sleep(5)
                    continue
                # 4xx/5xx other than 503 won't improve on retry.
                logger.warning("Model %s returned HTTP %s", name, resp.status)
                return None
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            logger.warning("Model %s request failed: %s", name, e)
            return None
    return None
def ensemble_logic(responses: dict, models: dict = None, mapping: dict = None):
    """Combine per-model predictions into one weighted emotion distribution.

    Args:
        responses: model name -> list of {"label": str, "score": float} dicts.
        models: name -> {"w": weight}; defaults to cfg.MODELS (injectable for
            testing; defaults preserve the original behavior).
        mapping: canonical label -> substring keywords; defaults to cfg.MAPPING.

    Returns:
        {"primary": top label or "unknown", "confidence": its rounded score,
         "distribution": all rounded scores, highest first}.
    """
    models = cfg.MODELS if models is None else models
    mapping = cfg.MAPPING if mapping is None else mapping
    # BUG FIX: `defaultdict` was used here without ever being imported,
    # so every call raised NameError at runtime.
    final_scores = defaultdict(float)
    for name, preds in responses.items():
        if not isinstance(preds, list):
            continue  # skip None / error payloads from upstream
        weight = models.get(name, {}).get("w", 0.0)
        for p in preds:
            label = str(p.get("label", "")).lower()
            # Map upstream labels to our standard set via substring match.
            mapped = "neutral"
            for std, keywords in mapping.items():
                if any(k in label for k in keywords):
                    mapped = std
                    break
            final_scores[mapped] += p.get("score", 0.0) * weight
    sorted_res = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)
    return {
        "primary": sorted_res[0][0] if sorted_res else "unknown",
        "confidence": round(sorted_res[0][1], 3) if sorted_res else 0,
        "distribution": {k: round(v, 3) for k, v in sorted_res}
    }
# --- 4. API ENDPOINTS ---
app = FastAPI(title="Emotion Ensemble API")
# NOTE(review): wildcard CORS ("*" origins/methods/headers) is acceptable for
# a public demo, but should be locked down to known origins before this API
# fronts real user data.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# BUG FIX: the original function had no route decorator, so it was never
# registered with FastAPI and unreachable over HTTP. Path is a reasonable
# guess — TODO confirm against clients.
@app.get("/health")
def health():
    """Liveness probe; also reports whether an HF API token is configured."""
    return {"status": "online", "hf_configured": bool(cfg.HF_TOKEN)}
# BUG FIX: missing route decorator — the endpoint was never registered.
# Path/method are a reasonable guess — TODO confirm against clients.
@app.get("/token")
def get_token(user: str = "hf_user"):
    """Issue a short-lived (60 min) JWT for *user*.

    NOTE(review): token minting is unauthenticated — anyone who can reach the
    API can obtain a token. Fine for a demo; gate it for production.
    """
    return {"token": create_access_token({"sub": user})}
# BUG FIX: missing route decorator — the main endpoint was never registered.
# Path is a reasonable guess — TODO confirm against clients.
@app.post("/analyze")
async def analyze(file: UploadFile = File(...), auth=Depends(verify_jwt)):
    """Full pipeline: transcode the upload, fan out to all models, ensemble.

    Raises:
        HTTPException: 400 on audio processing failure, 503 when no upstream
            model responded.
    """
    start_time = time.time()
    # 1. Process Audio
    try:
        audio_bytes, duration = await process_audio(file)
    except Exception as e:
        raise HTTPException(400, f"Audio processing failed: {str(e)}")
    # 2. Run Parallel Inference. return_exceptions=True so one model's
    # exception cannot abort the gather and 500 the request — the original's
    # "all models failed" fallback below was unreachable for raised errors.
    async with aiohttp.ClientSession() as session:
        tasks = {name: query_hf(session, name, m["url"], audio_bytes)
                 for name, m in cfg.MODELS.items()}
        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
    raw_responses = {
        name: (None if isinstance(res, BaseException) else res)
        for name, res in zip(tasks.keys(), results)
    }
    # 3. Ensemble & Format
    successful_models = {k: v for k, v in raw_responses.items() if v is not None}
    if not successful_models:
        raise HTTPException(503, "All upstream models failed.")
    analysis = ensemble_logic(successful_models)
    return {
        "emotion": analysis["primary"],
        "confidence": analysis["confidence"],
        "scores": analysis["distribution"],
        "meta": {
            "duration_sec": round(duration, 2),
            "latency_sec": round(time.time() - start_time, 2),
            "models_responding": len(successful_models)
        }
    }
if __name__ == "__main__":
    # Port 7860 is the Hugging Face Spaces default; override via $PORT.
    serve_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)