PsalmsJava's picture
Updated Again
84f84c3
import asyncio
import hashlib
import logging
import os
import subprocess
import tempfile
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Any

import aiohttp
import jwt
import librosa
import uvicorn
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
# --- 1. CONFIGURATION ---
class GlobalConfig:
    """Static service settings; environment variables are read once, at import time."""

    # Secrets: configure these in the HF Space "Secrets" settings.
    HF_TOKEN = os.getenv("HF_TOKEN", "")
    # NOTE(review): the fallback below is a public string; a deployment that does not
    # set API_SECRET_KEY will accept JWTs signed by anyone who reads this file.
    API_SECRET = os.getenv("API_SECRET_KEY", "default_secret_change_me_in_production")

    # Upstream HF Inference API endpoints with their ensemble weights ("w"; they sum to 1.0).
    MODELS = dict(
        emotion2vec=dict(url="https://api-inference.huggingface.co/models/emotion2vec/emotion2vec_plus_base", w=0.50),
        meralion=dict(url="https://api-inference.huggingface.co/models/MERaLiON/MERaLiON-SER-v1", w=0.25),
        wav2vec2=dict(url="https://api-inference.huggingface.co/models/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", w=0.15),
        hubert=dict(url="https://api-inference.huggingface.co/models/superb/hubert-large-superb-er", w=0.07),
        gigam=dict(url="https://api-inference.huggingface.co/models/salute-developers/GigaAM-emo", w=0.03),
    )

    # Standard label -> lowercase substrings that identify it inside a model's raw label.
    # Entries are tried in this order and the first match wins.
    MAPPING = dict(
        angry=["ang", "fear"],  # high-arousal negative emotions are merged into "angry"
        happy=["hap", "joy", "surp"],
        sad=["sad"],
        neutral=["neu", "calm"],
    )
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("EmotionAPI")
# Single settings instance consulted by the rest of the module.
cfg = GlobalConfig()
# --- 2. AUTHENTICATION ---
# Bearer-token extractor; verify_jwt receives its credentials via Depends(security).
security = HTTPBearer()
def create_access_token(data: dict, expires_minutes: int = 60):
    """Sign ``data`` as an HS256 JWT that expires ``expires_minutes`` from now.

    Args:
        data: Claims to embed (e.g. ``{"sub": user}``). The caller's dict is not
            mutated; an ``exp`` claim is added to a copy (overriding any ``exp`` in ``data``).
        expires_minutes: Token lifetime in minutes. Defaults to 60.

    Returns:
        The encoded token as produced by ``jwt.encode``, signed with ``cfg.API_SECRET``.
    """
    expire = datetime.now(timezone.utc) + timedelta(minutes=expires_minutes)
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, cfg.API_SECRET, algorithm="HS256")
async def verify_jwt(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency that validates the bearer JWT and returns its claims.

    The token must be HS256-signed with ``cfg.API_SECRET``; ``jwt.decode`` also
    rejects it once its ``exp`` claim has passed.

    Returns:
        The decoded payload dict.

    Raises:
        HTTPException(401): for any token problem (malformed, bad signature,
            expired, wrong algorithm). It carries ``WWW-Authenticate: Bearer``
            as expected for bearer-token challenges.
    """
    try:
        return jwt.decode(credentials.credentials, cfg.API_SECRET, algorithms=["HS256"])
    except jwt.InvalidTokenError as exc:
        # Only token validation failures map to 401. Unexpected errors (real bugs)
        # are no longer disguised as auth failures and surface as 500s.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid/Expired Token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
# --- 3. CORE LOGIC ---
def _safe_suffix(filename):
    """Return the extension of ``filename`` (e.g. ".mp3") if it is plain ASCII alphanumerics, else ""."""
    ext = os.path.splitext(os.path.basename(filename or ""))[1]
    body = ext[1:]
    # Keep only a simple extension, so a crafted upload name cannot inject path
    # separators or odd characters into the temp-file path. The extension is only
    # a hint; ffmpeg detects most formats from the file content.
    return ext if body.isascii() and body.isalnum() else ""


def _transcode_to_wav(input_path, output_path, timeout):
    """Blocking helper: convert input_path to 16 kHz mono WAV and return (wav_bytes, duration_sec)."""
    proc = subprocess.run(
        # -nostdin: never read the server's stdin. -y: overwrite output_path if present.
        ["ffmpeg", "-nostdin", "-y", "-i", input_path, "-ar", "16000", "-ac", "1", output_path],
        capture_output=True,
        text=True,
        errors="replace",  # ffmpeg may echo non-UTF-8 metadata bytes
        timeout=timeout,
    )
    if proc.returncode != 0:
        # ffmpeg's stderr opens with a long banner; the actual error is at the end.
        raise RuntimeError(f"FFmpeg error: {proc.stderr[-1000:]}")
    with open(output_path, "rb") as f:
        audio_bytes = f.read()
    duration = librosa.get_duration(path=output_path)
    return audio_bytes, duration


async def process_audio(file: UploadFile, ffmpeg_timeout: float = 120.0):
    """Standardize an uploaded audio file to 16 kHz mono WAV.

    Args:
        file: The upload. Its filename only supplies a temp-file extension hint.
        ffmpeg_timeout: Seconds ffmpeg may run before it is killed
            (``subprocess.TimeoutExpired`` is raised).

    Returns:
        ``(wav_bytes, duration_sec)`` for the converted audio.

    Raises:
        RuntimeError: if ffmpeg exits non-zero (message ends with its stderr tail).
        Other errors from file I/O, launching ffmpeg, or librosa propagate unchanged.
    """
    content = await file.read()
    # mkstemp names the file before anything is written, so the finally clause
    # can always remove it, even if the write itself fails.
    fd, input_path = tempfile.mkstemp(suffix=_safe_suffix(file.filename))
    output_path = input_path + ".wav"
    try:
        with os.fdopen(fd, "wb") as tmp_in:
            tmp_in.write(content)
        # ffmpeg and librosa block; run them in the default executor so the event
        # loop keeps serving other requests during the conversion.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, _transcode_to_wav, input_path, output_path, ffmpeg_timeout
        )
    finally:
        for p in (input_path, output_path):
            if os.path.exists(p):
                os.unlink(p)
async def query_hf(session, name, url, data, retries: int = 3, retry_delay: float = 5.0):
    """Call one HF Inference API model and return its decoded JSON response.

    HTTP 503 (model still loading) is retried up to ``retries`` attempts in total,
    waiting ``retry_delay`` seconds between attempts but not after the last one.

    Args:
        session: Shared ``aiohttp.ClientSession``.
        name: Model key, used in log messages.
        url: Inference endpoint to POST to.
        data: Raw request body (the analyze endpoint passes the WAV bytes).
        retries: Maximum number of attempts.
        retry_delay: Seconds to wait after a 503 before the next attempt.

    Returns:
        The JSON body of a 200 response, or ``None`` on any failure: a status other
        than 200/503, a network or timeout error, an undecodable body, or still
        loading after ``retries`` attempts. These cases are logged rather than
        raised, so one failing model cannot abort the whole ensemble.
    """
    headers = {"Authorization": f"Bearer {cfg.HF_TOKEN}"}
    for attempt in range(1, retries + 1):
        try:
            async with session.post(url, headers=headers, data=data) as resp:
                if resp.status == 200:
                    return await resp.json()
                if resp.status != 503:
                    # Any status other than 503 is treated as final; retrying
                    # immediately would just repeat the same failure.
                    detail = await resp.text()
                    logger.warning("Model %s returned HTTP %s: %s", name, resp.status, detail[:200])
                    return None
        except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as exc:
            # ContentTypeError (non-JSON body) is an aiohttp.ClientError;
            # malformed JSON raises a ValueError subclass.
            logger.warning("Model %s request failed: %r", name, exc)
            return None
        if attempt < retries:
            await asyncio.sleep(retry_delay)
    logger.warning("Model %s still loading after %d attempts", name, retries)
    return None
def ensemble_logic(responses: dict, models: dict = None, mapping: dict = None):
    """Combine per-model predictions into one weighted-average emotion estimate.

    Args:
        responses: ``{model_name: predictions}``. Predictions are expected to be a
            list of ``{"label": ..., "score": ...}`` dicts. Non-list values (e.g. an
            error dict) and malformed entries are skipped.
        models: ``{model_name: {"w": weight, ...}}``. Defaults to ``cfg.MODELS``.
        mapping: ``{standard_label: [substring, ...]}``, tried in order against the
            lowercased raw label. Defaults to ``cfg.MAPPING``. Labels that match
            nothing are counted as "neutral".

    Returns:
        ``{"primary": top label or "unknown", "confidence": its score (3 dp) or 0,
        "distribution": {label: score} highest first}``. Scores are divided by the
        total weight of the models that actually contributed. A partial ensemble
        is therefore still a true weighted average and is not deflated by the
        models that failed.
    """
    if models is None:
        models = cfg.MODELS
    if mapping is None:
        mapping = cfg.MAPPING

    raw_scores = defaultdict(float)
    total_weight = 0.0
    for name, preds in responses.items():
        if not isinstance(preds, list):
            continue
        weight = models[name]["w"]
        contributed = False
        for p in preds:
            if not isinstance(p, dict) or "label" not in p or "score" not in p:
                continue  # one malformed entry must not sink the whole ensemble
            label = str(p["label"]).lower()
            # Map labels to our standard set; the first matching entry wins.
            mapped = next(
                (std for std, keywords in mapping.items() if any(k in label for k in keywords)),
                "neutral",
            )
            raw_scores[mapped] += p["score"] * weight
            contributed = True
        if contributed:
            total_weight += weight

    if total_weight > 0:
        final_scores = {k: v / total_weight for k, v in raw_scores.items()}
    else:
        final_scores = dict(raw_scores)

    sorted_res = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)
    return {
        "primary": sorted_res[0][0] if sorted_res else "unknown",
        "confidence": round(sorted_res[0][1], 3) if sorted_res else 0,
        "distribution": {k: round(v, 3) for k, v in sorted_res},
    }
# --- 4. API ENDPOINTS ---
app = FastAPI(title="Emotion Ensemble API")
# Fully open CORS: browsers on any origin may call any route with any method or
# header. allow_credentials is left at its default, so cookies are not involved;
# requests authenticate with the bearer token instead.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@app.get("/health")
def health():
    """Liveness probe; also reports whether an HF token is set (never the token itself)."""
    has_token = bool(cfg.HF_TOKEN)
    return dict(status="online", hf_configured=has_token)
@app.get("/token")
def get_token(user: str = "hf_user"):
    """Issue a signed access token whose subject ("sub") is ``user``.

    NOTE(review): this route itself requires no authentication. Any caller can
    mint a token for any subject, so the JWT check on /analyze does not actually
    restrict access. Confirm this is intended before relying on it.
    """
    claims = {"sub": user}
    signed = create_access_token(claims)
    return {"token": signed}
@app.post("/analyze")
async def analyze(file: UploadFile = File(...), auth=Depends(verify_jwt)):
    """Classify the emotion in an uploaded audio clip using the model ensemble.

    Requires a valid bearer JWT (via ``verify_jwt``).

    Responds 400 if the audio cannot be processed, and 503 if no upstream model
    returned a prediction list. Otherwise it returns the ensemble result plus
    metadata: clip duration, latency, and the number of models that contributed.
    """
    # perf_counter is monotonic, so the reported latency ignores wall-clock jumps.
    start_time = time.perf_counter()

    # 1. Process Audio
    try:
        audio_bytes, duration = await process_audio(file)
    except Exception as e:
        logger.warning("Audio processing failed: %s", e)
        raise HTTPException(400, f"Audio processing failed: {str(e)}") from e

    # 2. Run Parallel Inference. With return_exceptions=True, a network or
    # decoding error in one model only drops that model instead of failing
    # the whole request with a 500.
    names = list(cfg.MODELS)
    async with aiohttp.ClientSession() as session:
        results = await asyncio.gather(
            *(query_hf(session, name, cfg.MODELS[name]["url"], audio_bytes) for name in names),
            return_exceptions=True,
        )

    # 3. Ensemble & Format. ensemble_logic only uses list responses, so only
    # those count as responding models.
    successful_models = {}
    for name, result in zip(names, results):
        if isinstance(result, BaseException):
            logger.warning("Model %s raised %r", name, result)
        elif isinstance(result, list):
            successful_models[name] = result
    if not successful_models:
        raise HTTPException(503, "All upstream models failed.")

    analysis = ensemble_logic(successful_models)
    return {
        "emotion": analysis["primary"],
        "confidence": analysis["confidence"],
        "scores": analysis["distribution"],
        "meta": {
            "duration_sec": round(duration, 2),
            "latency_sec": round(time.perf_counter() - start_time, 2),
            "models_responding": len(successful_models),
        },
    }
if __name__ == "__main__":
    # Direct-run entry point; the PORT environment variable overrides the default 7860.
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)