PsalmsJava committed on
Commit
84f84c3
·
1 Parent(s): 8ef2cca

Updated Again

Browse files
Files changed (2) hide show
  1. DockerFile +1 -1
  2. app.py +146 -168
DockerFile CHANGED
@@ -31,7 +31,7 @@ ENV PORT=7860
31
 
32
  # Health check
33
  HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
34
- CMD curl -f http://localhost:${PORT}/health || exit 1
35
 
36
  # Run the application
37
  CMD uvicorn app:app --host 0.0.0.0 --port ${PORT}
 
# Health check
# Fix: the previous revision read `CMD curl apt-get -f ...` — the stray
# `apt-get` token made curl treat it as the URL, so the probe always failed
# and the container was reported unhealthy.
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
    CMD curl -f http://localhost:${PORT}/health || exit 1

# Run the application
CMD uvicorn app:app --host 0.0.0.0 --port ${PORT}
app.py CHANGED
@@ -1,201 +1,179 @@
1
  import os
2
  import time
3
  import jwt
 
 
4
  import hashlib
5
  import tempfile
6
  import subprocess
7
- import logging
8
- import asyncio
9
  from datetime import datetime, timedelta, timezone
10
- from typing import Dict, List, Any, Optional
11
- from collections import defaultdict
12
- from contextlib import asynccontextmanager
13
 
14
  import aiohttp
15
- import numpy as np
16
  import librosa
 
17
  from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, status
18
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
19
  from fastapi.middleware.cors import CORSMiddleware
20
- from fastapi.openapi.docs import get_swagger_ui_html
21
- from fastapi.openapi.utils import get_openapi
22
- from pydantic import BaseModel
23
 
24
- # ==================== CONFIGURATION ====================
25
- class Config:
 
26
  HF_TOKEN = os.getenv("HF_TOKEN", "")
27
- # Default secret for dev; HF Spaces should set this in Settings > Variables
28
- API_SECRET_KEY = os.getenv("API_SECRET_KEY", "hf_space_default_secret_123")
29
- ALGORITHM = "HS256"
30
- ACCESS_TOKEN_EXPIRE_MINUTES = 30
31
 
32
  MODELS = {
33
- "emotion2vec_plus": {"url": "https://api-inference.huggingface.co/models/emotion2vec/emotion2vec_plus_base", "weight": 0.50, "timeout": 30, "description": "Foundation SER model"},
34
- "meralion_ser": {"url": "https://api-inference.huggingface.co/models/MERaLiON/MERaLiON-SER-v1", "weight": 0.25, "timeout": 30, "description": "English/SEA optimized"},
35
- "wav2vec2_english": {"url": "https://api-inference.huggingface.co/models/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", "weight": 0.15, "timeout": 25, "description": "English fine-tuned"},
36
- "hubert_er": {"url": "https://api-inference.huggingface.co/models/superb/hubert-large-superb-er", "weight": 0.07, "timeout": 25, "description": "Acoustic specialist"},
37
- "gigam_emo": {"url": "https://api-inference.huggingface.co/models/salute-developers/GigaAM-emo", "weight": 0.03, "timeout": 20, "description": "Acoustic pattern expert"}
38
  }
39
 
40
- MAX_FILE_SIZE_MB = 10
41
- SUPPORTED_FORMATS = ["wav", "mp3", "m4a", "ogg", "flac", "aac"]
42
- TARGET_SAMPLE_RATE = 16000
43
- MAX_DURATION_SECONDS = 30
44
- EMOTION_MAPPING = {
45
- "angry": ["angry", "ang", "anger"],
46
- "happy": ["happy", "hap", "happiness", "joy"],
47
- "sad": ["sad", "sadness"],
48
- "fear": ["fear", "fearful"],
49
- "surprise": ["surprise", "surprised"],
50
- "disgust": ["disgust", "disgusted"],
51
- "neutral": ["neutral", "neu"]
52
  }
53
 
54
- config = Config()
55
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
56
- logger = logging.getLogger(__name__)
57
 
58
- # ==================== AUTH & UTILS ====================
59
  security = HTTPBearer()
60
 
61
- class AuthHandler:
62
- @staticmethod
63
- def create_token(client_id: str = "api_client") -> str:
64
- expire = datetime.now(timezone.utc) + timedelta(minutes=config.ACCESS_TOKEN_EXPIRE_MINUTES)
65
- payload = {"sub": client_id, "exp": expire, "iat": datetime.now(timezone.utc), "type": "access"}
66
- return jwt.encode(payload, config.API_SECRET_KEY, algorithm=config.ALGORITHM)
67
-
68
- @staticmethod
69
- def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
70
- token = credentials.credentials
71
- try:
72
- payload = jwt.decode(token, config.API_SECRET_KEY, algorithms=[config.ALGORITHM])
73
- return payload.get("sub", "anonymous")
74
- except Exception:
75
- raise HTTPException(status_code=401, detail="Invalid or expired token")
76
-
77
- # ==================== CORE LOGIC ====================
78
- class AudioProcessor:
79
- @staticmethod
80
- async def validate_and_process(file: UploadFile) -> tuple:
81
- contents = await file.read()
82
- if len(contents) / (1024 * 1024) > config.MAX_FILE_SIZE_MB:
83
- raise HTTPException(413, "File too large")
84
-
85
- ext = file.filename.split('.')[-1].lower()
86
- with tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}") as f_in:
87
- f_in.write(contents)
88
- input_path = f_in.name
89
-
90
- output_path = input_path + ".wav"
91
- try:
92
- cmd = ["ffmpeg", "-i", input_path, "-ar", str(config.TARGET_SAMPLE_RATE), "-ac", "1", "-y", output_path]
93
- subprocess.run(cmd, capture_output=True, check=True, timeout=30)
94
-
95
- y, sr = librosa.load(output_path, sr=config.TARGET_SAMPLE_RATE)
96
- duration = len(y) / sr
97
- if duration > config.MAX_DURATION_SECONDS:
98
- raise HTTPException(400, "Audio too long")
99
-
100
- with open(output_path, "rb") as f:
101
- return f.read(), {"duration": round(duration, 2), "format": ext}
102
- finally:
103
- for p in [input_path, output_path]:
104
- if os.path.exists(p): os.unlink(p)
105
-
106
- class EmotionEnsemble:
107
- def __init__(self):
108
- self.models = config.MODELS
109
-
110
- async def predict(self, audio_bytes: bytes) -> Dict[str, Any]:
111
- if not config.HF_TOKEN:
112
- raise HTTPException(503, "HF_TOKEN missing")
113
-
114
- headers = {"Authorization": f"Bearer {config.HF_TOKEN}"}
115
- async with aiohttp.ClientSession() as session:
116
- tasks = [self._query(session, name, m_cfg, audio_bytes, headers) for name, m_cfg in self.models.items()]
117
- results = await asyncio.gather(*tasks)
118
-
119
- model_outputs = {name: res for name, res in zip(self.models.keys(), results) if res}
120
- if not model_outputs:
121
- raise HTTPException(503, "All models failed to respond")
122
-
123
- return self._fuse(model_outputs)
124
-
125
- async def _query(self, session, name, cfg, data, headers):
126
- try:
127
- async with session.post(cfg["url"], headers=headers, data=data, timeout=cfg["timeout"]) as resp:
128
- if resp.status == 200: return await resp.json()
129
- except: return None
130
-
131
- def _fuse(self, model_outputs):
132
- scores = defaultdict(float)
133
- for name, preds in model_outputs.items():
134
- w = self.models[name]["weight"]
135
- for p in preds:
136
- label = self._map(p['label'])
137
- scores[label] += p['score'] * w
138
 
139
- sorted_scores = dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))
140
- primary = list(sorted_scores.items())[0]
141
- return {"primary_emotion": primary[0], "confidence": round(primary[1], 3), "all_emotions": sorted_scores}
142
-
143
- def _map(self, label: str) -> str:
144
- label = label.lower()
145
- for std, vars in config.EMOTION_MAPPING.items():
146
- if any(v in label for v in vars): return std
147
- return "neutral"
148
-
149
- # ==================== APP SETUP ====================
150
- @asynccontextmanager
151
- async def lifespan(app: FastAPI):
152
- logger.info("🚀 API Starting Up...")
153
- yield
154
- logger.info("🛑 API Shutting Down...")
155
-
156
- app = FastAPI(title="Emotion API", lifespan=lifespan, docs_url=None)
157
- auth_handler = AuthHandler()
158
- audio_proc = AudioProcessor()
159
- ensemble = EmotionEnsemble()
160
- cache = {}
161
-
162
- @app.get("/")
163
- async def root(): return {"message": "Emotion API Active", "docs": "/docs"}
164
-
165
- @app.get("/auth/token")
166
- async def get_token(client_id: str = "api_client"):
167
- return {"access_token": auth_handler.create_token(client_id)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  @app.post("/analyze")
170
- async def analyze(file: UploadFile = File(...), user: str = Depends(auth_handler.verify_token)):
171
- content = await file.read()
172
- await file.seek(0) # Reset for the processor
173
- ckey = hashlib.md5(content).hexdigest()
174
-
175
- if ckey in cache: return cache[ckey]
176
 
177
- audio_bytes, info = await audio_proc.validate_and_process(file)
178
- res = await ensemble.predict(audio_bytes)
179
- res.update({"audio_info": info, "user": user})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
- if len(cache) < 100: cache[ckey] = res
182
- return res
183
-
184
- @app.get("/docs", include_in_schema=False)
185
- async def custom_docs():
186
- return get_swagger_ui_html(openapi_url="/openapi.json", title="API Docs")
187
-
188
- @app.get("/openapi.json", include_in_schema=False)
189
- async def get_open_api_endpoint():
190
- if app.openapi_schema: return app.openapi_schema
191
- schema = get_openapi(title="Emotion Ensemble API", version="1.0.0", routes=app.routes)
192
- schema["components"]["securitySchemes"] = {
193
- "bearerAuth": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}
194
  }
195
- schema["security"] = [{"bearerAuth": []}]
196
- app.openapi_schema = schema
197
- return schema
198
 
199
  if __name__ == "__main__":
200
- import uvicorn
201
  uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
 
import os
import time
import logging
import asyncio
import hashlib
import tempfile
import subprocess
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Any

import aiohttp
import jwt
import librosa
import uvicorn
from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
 
 
 
18
 
19
# --- 1. CONFIGURATION ---
class GlobalConfig:
    """Central settings, read once at import time from the environment."""

    # Set these in HF Space Secrets
    HF_TOKEN = os.getenv("HF_TOKEN", "")
    API_SECRET = os.getenv("API_SECRET_KEY", "default_secret_change_me_in_production")

    # Ensemble members: each entry is an HF Inference API endpoint plus its
    # fusion weight "w" (weights sum to 1.0).
    MODELS = {
        "emotion2vec": {"url": "https://api-inference.huggingface.co/models/emotion2vec/emotion2vec_plus_base", "w": 0.50},
        "meralion": {"url": "https://api-inference.huggingface.co/models/MERaLiON/MERaLiON-SER-v1", "w": 0.25},
        "wav2vec2": {"url": "https://api-inference.huggingface.co/models/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", "w": 0.15},
        "hubert": {"url": "https://api-inference.huggingface.co/models/superb/hubert-large-superb-er", "w": 0.07},
        "gigam": {"url": "https://api-inference.huggingface.co/models/salute-developers/GigaAM-emo", "w": 0.03},
    }

    # Standardized internal labels: substring keywords on the left match the
    # raw model labels (lower-cased) to one of our four canonical emotions.
    MAPPING = {
        "angry": ["ang", "fear"],  # Merging high-arousal negative
        "happy": ["hap", "joy", "surp"],
        "sad": ["sad"],
        "neutral": ["neu", "calm"],
    }
40
 
41
# Module-level singletons: config snapshot, logger, and bearer-auth scheme.
cfg = GlobalConfig()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("EmotionAPI")

# --- 2. AUTHENTICATION ---
security = HTTPBearer()
47
 
48
def create_access_token(data: dict):
    """Return a signed HS256 JWT carrying *data* plus a 60-minute expiry."""
    claims = dict(data)
    claims["exp"] = datetime.now(timezone.utc) + timedelta(minutes=60)
    return jwt.encode(claims, cfg.API_SECRET, algorithm="HS256")
53
+
54
async def verify_jwt(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: decode the bearer JWT and return its claims dict.

    Raises:
        HTTPException(401): when the token is malformed, mis-signed or expired.
    """
    try:
        return jwt.decode(credentials.credentials, cfg.API_SECRET, algorithms=["HS256"])
    except jwt.PyJWTError:
        # Narrowed from a bare `except Exception`: only genuine token problems
        # become 401s; programming errors now surface as 500s instead of being
        # silently reported as auth failures.
        raise HTTPException(status_code=401, detail="Invalid/Expired Token")
60
+
61
# --- 3. CORE LOGIC ---
async def process_audio(file: UploadFile):
    """Handles format conversion and validation.

    Writes the upload to a temp file, converts it to 16 kHz mono WAV with
    ffmpeg, and returns ``(wav_bytes, duration_seconds)``. Both temp files
    are always removed, even on failure.

    Raises:
        RuntimeError: when ffmpeg exits non-zero (stderr included).
        subprocess.TimeoutExpired: when ffmpeg hangs past 30 s.
    """
    # Keep the original extension so ffmpeg can sniff the container format;
    # a missing filename previously crashed with AttributeError (-> HTTP 500).
    suffix = f".{file.filename.split('.')[-1]}" if file.filename else ".bin"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_in:
        tmp_in.write(await file.read())
        input_path = tmp_in.name

    output_path = input_path + ".wav"
    try:
        # Standardize to 16kHz Mono WAV. The timeout keeps a wedged ffmpeg
        # from stalling the request forever (the previous app version had
        # timeout=30 here; this revision dropped it).
        proc = subprocess.run(
            ["ffmpeg", "-i", input_path, "-ar", "16000", "-ac", "1", "-y", output_path],
            capture_output=True, text=True, timeout=30,
        )
        if proc.returncode != 0:
            # RuntimeError instead of bare Exception: still caught by the
            # caller's `except Exception`, but a concrete, documented type.
            raise RuntimeError(f"FFmpeg error: {proc.stderr}")

        with open(output_path, "rb") as f:
            audio_bytes = f.read()

        duration = librosa.get_duration(path=output_path)
        return audio_bytes, duration
    finally:
        for p in (input_path, output_path):
            if os.path.exists(p):
                os.unlink(p)
88
+
89
async def query_hf(session, name, url, data):
    """Individual model call with retry for 'loading' status.

    Returns the parsed JSON prediction list on success, or None after three
    failed attempts — callers drop None entries from the ensemble.
    """
    headers = {"Authorization": f"Bearer {cfg.HF_TOKEN}"}
    for _ in range(3):  # Simple retry if model is loading
        try:
            async with session.post(url, headers=headers, data=data) as resp:
                if resp.status == 200:
                    # Parse JSON only on success: the previous code called
                    # resp.json() before checking the status, so an HTML
                    # error page raised and crashed the whole gather().
                    return await resp.json()
                if resp.status == 503:  # Model loading
                    await asyncio.sleep(5)
        except (aiohttp.ClientError, asyncio.TimeoutError):
            # Network hiccups count as one failed attempt, not a 500.
            await asyncio.sleep(1)
    return None
101
+
102
def ensemble_logic(responses: dict):
    """Weighted average of results.

    Fuses per-model prediction lists into one score distribution over the
    canonical labels in cfg.MAPPING, weighted by each model's "w" value.
    Returns {"primary", "confidence", "distribution"}.
    """
    # Plain dict instead of collections.defaultdict: this revision of the
    # module dropped the `from collections import defaultdict` import, so
    # the original body raised NameError on the first real request.
    final_scores = {}
    for name, preds in responses.items():
        if not isinstance(preds, list):
            continue  # malformed/error payloads are skipped, not fatal
        weight = cfg.MODELS[name]["w"]
        for p in preds:
            label = p['label'].lower()
            # Map labels to our standard set; anything unmatched falls back
            # to "neutral" so every score lands somewhere.
            mapped = "neutral"
            for std, keywords in cfg.MAPPING.items():
                if any(k in label for k in keywords):
                    mapped = std
                    break
            final_scores[mapped] = final_scores.get(mapped, 0.0) + p['score'] * weight

    sorted_res = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)
    return {
        "primary": sorted_res[0][0] if sorted_res else "unknown",
        "confidence": round(sorted_res[0][1], 3) if sorted_res else 0,
        "distribution": {k: round(v, 3) for k, v in sorted_res}
    }
124
+
125
# --- 4. API ENDPOINTS ---
app = FastAPI(title="Emotion Ensemble API")

# NOTE(review): CORS is wide open (any origin/method/header) — acceptable for
# a public demo Space, but confirm before fronting real credentials with this.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
134
+
135
@app.get("/health")
def health():
    """Liveness probe; also reports whether an HF token is configured."""
    return {"status": "online", "hf_configured": bool(cfg.HF_TOKEN)}
138
+
139
@app.get("/token")
def get_token(user: str = "hf_user"):
    """Mint a demo JWT for *user*.

    NOTE(review): this endpoint itself is unauthenticated, so anyone can mint
    a token — presumably intentional for a demo Space; confirm.
    """
    return {"token": create_access_token({"sub": user})}
142
 
143
@app.post("/analyze")
async def analyze(file: UploadFile = File(...), auth=Depends(verify_jwt)):
    """Full pipeline: convert audio, fan out to all HF models, fuse scores."""
    start_time = time.time()

    # 1. Process Audio
    try:
        audio_bytes, duration = await process_audio(file)
    except Exception as e:
        raise HTTPException(400, f"Audio processing failed: {str(e)}")

    # 2. Run Parallel Inference
    async with aiohttp.ClientSession() as session:
        tasks = {name: query_hf(session, name, m["url"], audio_bytes)
                 for name, m in cfg.MODELS.items()}
        # return_exceptions=True: one model call raising must not 500 the
        # whole request — an exception is treated like any failed model below.
        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
        raw_responses = dict(zip(tasks.keys(), results))

    # 3. Ensemble & Format
    successful_models = {
        k: v for k, v in raw_responses.items()
        if v is not None and not isinstance(v, BaseException)
    }
    if not successful_models:
        raise HTTPException(503, "All upstream models failed.")

    analysis = ensemble_logic(successful_models)

    return {
        "emotion": analysis["primary"],
        "confidence": analysis["confidence"],
        "scores": analysis["distribution"],
        "meta": {
            "duration_sec": round(duration, 2),
            "latency_sec": round(time.time() - start_time, 2),
            "models_responding": len(successful_models)
        }
    }
 
 
 
if __name__ == "__main__":
    # Local dev entry point; on HF Spaces uvicorn is launched by the
    # Dockerfile CMD instead.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))