PsalmsJava committed on
Commit
9434231
·
1 Parent(s): c6d98fa

Made Some Changes

Browse files
DockerFile ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
# Slim Python base image; ffmpeg is needed by main.py to convert uploads.
FROM python:3.9-slim

RUN apt-get update \
    && apt-get install -y --no-install-recommends ffmpeg \
    && rm -rf /var/lib/apt/lists/*

# python-multipart is required by FastAPI for UploadFile / File(...) routes;
# without it the application raises while the /analyze route is declared.
RUN pip install --no-cache-dir fastapi uvicorn aiohttp numpy python-multipart

WORKDIR /app
COPY main.py .

# Exec form: uvicorn runs as PID 1 and receives SIGTERM directly on shutdown.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
app/audio/preprocessor.py DELETED
@@ -1,92 +0,0 @@
1
- import os
2
- import tempfile
3
- import subprocess
4
- import numpy as np
5
- import librosa
6
- from fastapi import UploadFile, HTTPException
7
- from typing import Tuple
8
- import logging
9
- from ..config import config
10
-
11
- logger = logging.getLogger(__name__)
12
-
13
class AudioPreprocessor:
    """Validate an uploaded audio file and convert it to normalized mono audio.

    The target sample rate, maximum duration and maximum upload size are read
    from ``config.AUDIO_CONFIG`` when the instance is created.
    """

    def __init__(self):
        self.target_sr = config.AUDIO_CONFIG["target_sample_rate"]
        self.max_duration = config.AUDIO_CONFIG["max_duration"]
        self.max_size_mb = config.AUDIO_CONFIG["max_file_size_mb"]

    async def validate_and_preprocess(self, file: UploadFile) -> Tuple[np.ndarray, int, dict]:
        """Validate, convert and normalize an uploaded audio file.

        The upload is size-checked, converted with FFmpeg to a mono 16-bit PCM
        WAV at ``self.target_sr``, loaded with librosa, truncated to
        ``self.max_duration`` seconds and peak-normalized.

        Args:
            file: The uploaded audio file.

        Returns:
            A tuple ``(audio, sample_rate, metadata)`` where ``metadata`` holds
            the filename, duration (seconds), sample rate and size in MB.

        Raises:
            HTTPException: 400 when the file is too large or FFmpeg cannot
                convert it; 500 for any other processing failure.
        """
        contents = await file.read()
        file_size_mb = len(contents) / (1024 * 1024)

        if file_size_mb > self.max_size_mb:
            raise HTTPException(
                status_code=400,
                detail=f"File too large: {file_size_mb:.1f}MB (max: {self.max_size_mb}MB)"
            )

        # The filename can be missing on an upload; fall back to no suffix
        # instead of crashing on ``None.split``.
        suffix = os.path.splitext(file.filename or "")[1]

        # Both paths start as None so the cleanup in ``finally`` is safe even
        # when a temporary file could not be created.
        input_path = None
        output_path = None

        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_input:
                tmp_input.write(contents)
                input_path = tmp_input.name

            # Reserve a path for the converted WAV; FFmpeg overwrites it (-y).
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_output:
                output_path = tmp_output.name

            cmd = [
                "ffmpeg",
                "-i", input_path,
                "-ar", str(self.target_sr),
                "-ac", "1",
                "-acodec", "pcm_s16le",
                "-y",
                output_path
            ]

            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode != 0:
                raise HTTPException(status_code=400, detail="Audio conversion failed")

            audio, sr = librosa.load(output_path, sr=self.target_sr)

            # Truncate anything longer than the configured maximum duration.
            duration = len(audio) / sr
            if duration > self.max_duration:
                audio = audio[:int(self.max_duration * sr)]

            # Peak normalization; skipped for empty or all-zero audio so we
            # neither call np.max on an empty array nor divide by zero.
            peak = np.max(np.abs(audio)) if audio.size else 0
            if peak > 0:
                audio = audio / peak

            metadata = {
                "filename": file.filename,
                "duration": round(len(audio) / sr, 2),
                "sample_rate": sr,
                "size_mb": round(file_size_mb, 2)
            }

            return audio, sr, metadata

        except HTTPException:
            # Deliberate client errors (e.g. the 400 above) must not be
            # rewritten into a generic 500 by the handler below.
            raise
        except Exception as e:
            logger.error(f"Audio processing failed: {str(e)}")
            raise HTTPException(status_code=500, detail="Audio processing failed") from e
        finally:
            # Remove whichever temporary files were actually created.
            for path in (input_path, output_path):
                if path and os.path.exists(path):
                    os.unlink(path)


audio_preprocessor = AudioPreprocessor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/auth.py DELETED
File without changes
app/config.py DELETED
File without changes
app/main.py DELETED
@@ -1,192 +0,0 @@
1
- from fastapi import FastAPI, File, UploadFile, Depends, HTTPException, Request
2
- from fastapi.security import HTTPBearer
3
- from fastapi.responses import JSONResponse
4
- from fastapi.middleware.cors import CORSMiddleware
5
- from fastapi.openapi.docs import get_swagger_ui_html
6
- import aiohttp
7
- import time
8
- import hashlib
9
- import logging
10
- from datetime import datetime
11
- from typing import Optional
12
- import os
13
-
14
- from .config import config
15
- from .auth import auth_handler
16
- from .audio.preprocessor import audio_preprocessor
17
- from .models.ensemble import ensemble_fusion
18
-
19
# Logging: config.LOG_LEVEL is looked up as an attribute of the logging module.
logging.basicConfig(level=getattr(logging, config.LOG_LEVEL))
logger = logging.getLogger(__name__)

# FastAPI application; docs_url=None because a custom /docs route is defined
# further down in this module.
app = FastAPI(
    docs_url=None,
    version=config.API_VERSION,
    title=config.API_TITLE,
)

# CORS: every origin, method and header is accepted, credentials included.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Module-level state shared by the handlers: an in-memory prediction cache
# and the aiohttp session that is created on startup.
prediction_cache = {}
http_session = None
42
-
43
@app.on_event("startup")
async def startup():
    """Create the shared aiohttp session and log the startup configuration."""
    global http_session

    session = aiohttp.ClientSession()
    http_session = session

    logger.info(f"🚀 API started with {len(config.MODELS)} models")
    # Only the presence of the token is logged, never its value.
    logger.info(f"HF_TOKEN present: {bool(config.HF_TOKEN)}")
50
-
51
@app.on_event("shutdown")
async def shutdown():
    """Close the shared aiohttp session if one was created."""
    if not http_session:
        return
    await http_session.close()
56
-
57
@app.get("/", include_in_schema=False)
async def root():
    """Return the service name and version plus the docs and health paths."""
    info = {}
    info["service"] = config.API_TITLE
    info["version"] = config.API_VERSION
    info["docs"] = "/docs"
    info["health"] = "/health"
    return info
66
-
67
@app.get("/health")
async def health_check():
    """Report service health, model count and whether an HF token is set."""
    now = datetime.utcnow()
    payload = {"status": "healthy"}
    payload["timestamp"] = now.isoformat()
    payload["models"] = len(config.MODELS)
    payload["hf_token_configured"] = bool(config.HF_TOKEN)
    return payload
76
-
77
@app.get("/models")
async def list_models():
    """List every configured model with its weight and optional description."""
    entries = []
    for name, model in config.MODELS.items():
        entry = {
            "name": name,
            "weight": model["weight"],
            # The description is optional in the model configuration.
            "description": model.get("description", ""),
        }
        entries.append(entry)
    return {"models": entries}
90
-
91
@app.post("/analyze")
async def analyze_emotion(
    request: Request,
    file: UploadFile = File(...),
    token: str = Depends(auth_handler.verify_token)
):
    """Analyze the emotion expressed in an uploaded audio file.

    Results are cached in memory, keyed by the MD5 digest of the uploaded
    bytes, so an identical upload returns the stored result.

    Args:
        request: The incoming request (not read by this handler).
        file: The uploaded audio file.
        token: Value produced by the ``auth_handler.verify_token`` dependency.

    Returns:
        The fused prediction dict, extended with ``processing_time``,
        ``audio_metadata`` and ``timestamp``.

    Raises:
        HTTPException: 503 when too few models respond; errors raised by the
            preprocessor are passed through; 500 for unexpected failures.
    """
    start_time = time.time()

    try:
        # Read the upload once: the same bytes feed the cache key and the
        # model queries, so no second read is needed later.
        audio_bytes = await file.read()

        # MD5 is used only as a cache key here, not for security.
        cache_key = hashlib.md5(audio_bytes).hexdigest()
        if cache_key in prediction_cache:
            logger.info(f"Cache hit for {cache_key}")
            return prediction_cache[cache_key]

        # Rewind so the preprocessor can read the upload from the start.
        await file.seek(0)

        logger.info(f"Processing: {file.filename}")
        audio, sr, metadata = await audio_preprocessor.validate_and_preprocess(file)

        logger.info("Querying models...")
        model_outputs = await ensemble_fusion.query_all_models(http_session, audio_bytes)

        if len(model_outputs) < config.ENSEMBLE_CONFIG["min_models_for_prediction"]:
            raise HTTPException(
                status_code=503,
                detail=f"Only {len(model_outputs)}/{len(config.MODELS)} models responded"
            )

        result = ensemble_fusion.fuse_predictions(model_outputs)

        result["processing_time"] = round(time.time() - start_time, 2)
        result["audio_metadata"] = metadata
        result["timestamp"] = datetime.utcnow().isoformat()

        # The cache never evicts; once it is full, new results are not stored.
        if len(prediction_cache) < config.CACHE_CONFIG["max_size"]:
            prediction_cache[cache_key] = result

        logger.info(f"Analysis complete: {result['primary_emotion']} ({result['processing_time']}s)")
        return result

    except HTTPException:
        raise
    except Exception as e:
        # Full details go to the log only; the client gets a generic message
        # so internal error text is not exposed.
        logger.error(f"Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Internal server error") from e
150
-
151
@app.get("/docs", include_in_schema=False)
async def custom_docs():
    """Serve the Swagger UI page, loading its JS and CSS from jsDelivr."""
    page_title = f"{config.API_TITLE} - Docs"
    bundle_js = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js"
    bundle_css = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css"
    return get_swagger_ui_html(
        openapi_url="/openapi.json",
        title=page_title,
        swagger_js_url=bundle_js,
        swagger_css_url=bundle_css,
    )
160
-
161
def _build_openapi_schema():
    """Build the OpenAPI schema with a bearer-auth scheme, caching the result."""
    from fastapi.openapi.utils import get_openapi as generate_schema

    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = generate_schema(
        title=config.API_TITLE,
        version=config.API_VERSION,
        description="Emotion detection API using ensemble of 5 models",
        routes=app.routes,
    )

    # "components" can be absent from a generated schema; create it on demand
    # instead of raising KeyError.
    openapi_schema.setdefault("components", {})["securitySchemes"] = {
        "bearerAuth": {
            "type": "http",
            "scheme": "bearer",
            "description": "Enter your API token"
        }
    }
    openapi_schema["security"] = [{"bearerAuth": []}]

    app.openapi_schema = openapi_schema
    return app.openapi_schema


# The app is constructed with FastAPI's default openapi_url, so a built-in
# /openapi.json route already exists and is matched before the route below.
# Overriding app.openapi makes that built-in route serve the customized
# schema as well.
app.openapi = _build_openapi_schema


@app.get("/openapi.json", include_in_schema=False)
async def get_openapi():
    """Return the customized OpenAPI schema."""
    return _build_openapi_schema()
188
-
189
if __name__ == "__main__":
    import uvicorn

    # The listening port comes from the PORT environment variable; 8000 is
    # used when it is unset.
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, port=port, host="0.0.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/models/ensemble.py DELETED
File without changes
app/utils/logger.py DELETED
@@ -1,12 +0,0 @@
1
- import logging
2
- import sys
3
-
4
def setup_logging():
    """Configure logging at INFO level with a single stdout stream handler."""
    stdout_handler = logging.StreamHandler(sys.stdout)
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[stdout_handler],
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import hmac
import logging
import os
import subprocess
import tempfile
from datetime import datetime

import aiohttp
import numpy as np
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse
10
+
11
# Logging and the FastAPI application object.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(title="Emotion Detection API", docs_url="/docs")

# Configuration is read from the environment.
HF_TOKEN = os.getenv("HF_TOKEN", "")    # token sent to the model endpoints
API_TOKEN = os.getenv("API_TOKEN", "test123")    # token clients must present
if "API_TOKEN" not in os.environ:
    # The fallback value is visible in the source, so make the risk obvious
    # in the logs instead of falling back silently.
    logger.warning(
        "API_TOKEN is not set; using the built-in default token. "
        "Set API_TOKEN before exposing this service."
    )

# Models - using only 2 for reliability. Each entry has the inference URL and
# the weight applied to that model's scores in the ensemble.
MODELS = {
    "wav2vec2_english": {
        "url": "https://api-inference.huggingface.co/models/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
        "weight": 0.7,
    },
    "gigam_emo": {
        "url": "https://api-inference.huggingface.co/models/salute-developers/GigaAM-emo",
        "weight": 0.3,
    },
}

# Standard emotion name -> substrings that identify it in a model's label.
EMOTION_MAPPING = {
    "angry": ["angry", "ang"],
    "happy": ["happy", "hap"],
    "sad": ["sad"],
    "fear": ["fear"],
    "surprise": ["surprise"],
    "disgust": ["disgust"],
    "neutral": ["neutral", "neu"],
}
42
+
43
@app.get("/health")
async def health():
    """Report liveness and whether an HF token is configured."""
    token_configured = bool(HF_TOKEN)
    return {"status": "ok", "hf_token": token_configured}
46
+
47
@app.get("/")
async def root():
    """Return the API name, the docs path and the available endpoints."""
    endpoints = ["POST /analyze"]
    return {"message": "Emotion Detection API", "docs": "/docs", "endpoints": endpoints}
54
+
55
def _is_authorized(header_value):
    """Return True when the Authorization header value carries API_TOKEN."""
    prefix = "Bearer "
    if header_value.startswith(prefix):
        token = header_value[len(prefix):]
    else:
        token = header_value
    if not token or not API_TOKEN:
        return False
    # Constant-time comparison; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    return hmac.compare_digest(token.encode("utf-8"), API_TOKEN.encode("utf-8"))


def _convert_to_wav(input_path, output_path):
    """Convert input_path to a 16 kHz mono WAV at output_path using ffmpeg."""
    subprocess.run([
        "ffmpeg", "-i", input_path,
        "-ar", "16000", "-ac", "1",
        "-y", output_path
    ], check=True, capture_output=True)


async def _query_models(audio_bytes):
    """POST the audio to every model in MODELS; return {name: predictions}."""
    headers = {"Authorization": f"Bearer {HF_TOKEN}"}
    results = {}

    # 10 s total per request, expressed as a ClientTimeout (a bare number is
    # deprecated in aiohttp).
    timeout = aiohttp.ClientTimeout(total=10)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        for name, model in MODELS.items():
            try:
                async with session.post(
                    model["url"],
                    headers=headers,
                    data=audio_bytes,
                ) as resp:
                    if resp.status != 200:
                        continue
                    payload = await resp.json()
            except Exception as e:
                # Best effort: one failing model must not fail the request.
                logger.warning(f"{name} failed: {e}")
                continue

            # Only a list of predictions can be fused; skip anything else.
            if isinstance(payload, list):
                results[name] = payload
            else:
                logger.warning(f"{name} returned an unexpected payload")

    return results


def _fuse_predictions(results):
    """Combine per-model predictions into weighted, normalized emotion scores."""
    emotion_scores = {}
    total_weight = 0

    for name, predictions in results.items():
        weight = MODELS[name]["weight"]
        total_weight += weight

        for pred in predictions:
            if not isinstance(pred, dict):
                continue
            label = pred.get("label", "").lower()
            score = pred.get("score", 0)

            # Map the model's label to the first standard emotion whose
            # variation appears in it.
            for std_emo, variations in EMOTION_MAPPING.items():
                if any(v in label for v in variations):
                    emotion_scores[std_emo] = emotion_scores.get(std_emo, 0) + score * weight
                    break

    # Normalize by the weight of the models that actually responded.
    if total_weight > 0:
        emotion_scores = {k: v / total_weight for k, v in emotion_scores.items()}

    return emotion_scores


@app.post("/analyze")
async def analyze(file: UploadFile = File(...), request: Request = None):
    """Analyze emotion from audio file.

    Args:
        file: The uploaded audio file.
        request: The incoming request, injected by FastAPI; used to read the
            Authorization header. The ``None`` default only keeps ``file`` as
            the first parameter.

    Returns:
        A dict with ``primary_emotion``, ``confidence``, ``all_emotions`` and
        ``models_used``; or a JSONResponse carrying an ``error`` message with
        status 401 (bad or missing token), 400 (audio could not be converted)
        or 500 (unexpected failure).
    """
    # The token is sent in the HTTP request headers. UploadFile.headers only
    # holds the headers of the multipart part, so it cannot be used here.
    auth = request.headers.get("authorization", "") if request is not None else ""
    if not _is_authorized(auth):
        return JSONResponse(
            status_code=401,
            content={"error": "Invalid or missing Authorization header"}
        )

    # Both paths start as None so cleanup is safe if a temp file is never made.
    input_path = None
    output_path = None

    try:
        content = await file.read()

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(content)
            input_path = tmp.name

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out:
            output_path = out.name

        try:
            # ffmpeg is blocking; run it in a worker thread so the event loop
            # keeps serving other requests.
            await asyncio.to_thread(_convert_to_wav, input_path, output_path)
        except subprocess.CalledProcessError:
            # Undecodable input is a client error, not a server failure.
            logger.warning("ffmpeg could not convert the uploaded file")
            return JSONResponse(
                status_code=400,
                content={"error": "Audio conversion failed"}
            )

        with open(output_path, "rb") as f:
            audio_bytes = f.read()

        results = await _query_models(audio_bytes)
        emotion_scores = _fuse_predictions(results)

        # Highest-scoring emotion, or a placeholder when nothing was scored.
        primary = max(emotion_scores.items(), key=lambda x: x[1]) if emotion_scores else ("unknown", 0)

        return {
            "primary_emotion": primary[0],
            "confidence": round(primary[1], 3),
            "all_emotions": {k: round(v, 3) for k, v in emotion_scores.items()},
            "models_used": list(results.keys())
        }

    except Exception:
        # Log the details; do not return internal error text to the client.
        logger.exception("Error while analyzing upload")
        return JSONResponse(status_code=500, content={"error": "Internal server error"})
    finally:
        # Remove whichever temporary files were actually created.
        for path in (input_path, output_path):
            if path and os.path.exists(path):
                os.unlink(path)
146
+
147
# For Hugging Face: accept cross-origin requests from any origin, with any
# method and any header.
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
requirements.txt CHANGED
@@ -1,28 +1,4 @@
1
- # Core Framework
2
- fastapi==0.104.1
3
- uvicorn[standard]==0.24.0
4
- pydantic==2.5.0
5
- python-dotenv==1.0.0
6
-
7
- # Authentication
8
- python-jose[cryptography]==3.3.0
9
- passlib[bcrypt]==1.7.4
10
- python-multipart==0.0.6
11
-
12
- # HTTP & Async
13
- aiohttp==3.9.1
14
- httpx==0.25.1
15
-
16
- # Audio Processing
17
- librosa==0.10.1
18
- soundfile==0.12.1
19
- pydub==0.25.1
20
- ffmpeg-python==0.2.0
21
-
22
- # Scientific Computing
23
- numpy==1.24.3
24
- scipy==1.11.4
25
- scikit-learn==1.3.2
26
-
27
- # Rate Limiting
28
- slowapi==0.1.8
 
1
fastapi
uvicorn
aiohttp
numpy
# Required by FastAPI to parse multipart/form-data uploads (UploadFile / File).
python-multipart