# audio-sentiment/api_app.py
import tempfile
import os
import gc
from datetime import datetime, timezone
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import torch.nn.functional as F
import torchaudio
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from src.transcription import SpeechEncoder
from src.sentiment import TextEncoder
from src.multimodal import MultimodalSentimentClassifier

# Configuration for Hugging Face Spaces
HF_SPACE = os.getenv("HF_SPACE", "false").lower() == "true"

app = FastAPI(
    title="API Multimodale de Transcription & Sentiment",
    description="API pour l'analyse de sentiment audio en français",
    version="1.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# CORS configuration for Hugging Face Spaces
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Preload the models
print("Chargement des modèles pour l'API...")
try:
    processor_ctc = Wav2Vec2Processor.from_pretrained(
        "jonatasgrosman/wav2vec2-large-xlsr-53-french",
        cache_dir="./models" if not HF_SPACE else None
    )
    model_ctc = Wav2Vec2ForCTC.from_pretrained(
        "jonatasgrosman/wav2vec2-large-xlsr-53-french",
        cache_dir="./models" if not HF_SPACE else None
    )
    speech_enc = SpeechEncoder()
    text_enc = TextEncoder()
    model_mm = MultimodalSentimentClassifier()
    print("✅ Modèles chargés avec succès pour l'API")
except Exception as e:
    print(f"❌ Erreur chargement modèles API: {e}")
    raise

def transcribe_ctc(wav_path: str) -> str:
    """Transcribe an audio file with Wav2Vec2."""
    try:
        waveform, sr = torchaudio.load(wav_path)
        # Resample to 16 kHz, the sampling rate expected by Wav2Vec2
        if sr != 16000:
            waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
        # Downmix multi-channel audio to mono
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        inputs = processor_ctc(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        )
        with torch.no_grad():
            logits = model_ctc(**inputs).logits
        pred_ids = torch.argmax(logits, dim=-1)
        transcription = processor_ctc.batch_decode(pred_ids)[0].lower()
        return transcription
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur transcription: {str(e)}")
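
# Example (sketch): transcribe_ctc can also be called directly on a local audio
# file; the path below is purely illustrative.
#
#     text = transcribe_ctc("samples/extrait_16k.wav")
#     print(text)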
@app.get("/")
async def root():
"""Endpoint racine avec informations sur l'API"""
return {
"message": "API Multimodale de Transcription & Sentiment",
"version": "1.0",
"endpoints": {
"docs": "/docs",
"predict": "/predict",
"health": "/health"
},
"supported_formats": ["wav", "flac", "mp3"]
}
@app.get("/health")
async def health_check():
"""Vérification de l'état de l'API"""
return {
"status": "healthy",
"models_loaded": True,
"timestamp": "2024-01-01T00:00:00Z"
}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
"""
Analyse de sentiment audio
Args:
file: Fichier audio (WAV, FLAC, MP3)
Returns:
JSON avec transcription et sentiment
"""
# 1. Vérifier le type de fichier
if not file.filename or not file.filename.lower().endswith((".wav", ".flac", ".mp3")):
raise HTTPException(
status_code=400,
detail="Seuls les fichiers audio WAV/FLAC/MP3 sont acceptés."
)
# 2. Vérifier la taille du fichier (max 50MB)
content = await file.read()
if len(content) > 50 * 1024 * 1024: # 50MB
raise HTTPException(
status_code=400,
detail="Fichier trop volumineux. Taille maximale: 50MB"
)
# 3. Sauvegarder temporairement
suffix = os.path.splitext(file.filename)[1]
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
tmp.write(content)
tmp_path = tmp.name
try:
# 4. Transcription
transcription = transcribe_ctc(tmp_path)
if not transcription.strip():
return JSONResponse({
"transcription": "",
"sentiment": {"négatif": 0.33, "neutre": 0.34, "positif": 0.33},
"warning": "Aucune transcription détectée"
})
# 5. Features multimodales
try:
audio_feat = speech_enc.extract_features(tmp_path)
text_feat = text_enc.extract_features([transcription])
# 6. Classification
logits = model_mm.classifier(torch.cat([audio_feat, text_feat], dim=1))
probs = F.softmax(logits, dim=1).squeeze().tolist()
labels = ["négatif", "neutre", "positif"]
sentiment = {labels[i]: round(probs[i], 3) for i in range(len(labels))}
except Exception as e:
# Fallback vers analyse textuelle uniquement
print(f"Erreur multimodal, fallback textuel: {e}")
sent_dict = TextEncoder.analyze_sentiment(transcription)
sentiment = {k: round(v, 3) for k, v in sent_dict.items()}
# 7. Nettoyage mémoire
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return JSONResponse({
"transcription": transcription,
"sentiment": sentiment,
"filename": file.filename,
"file_size": len(content)
})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur lors de l'analyse: {str(e)}")
finally:
# 8. Nettoyage fichier temporaire
try:
os.remove(tmp_path)
except:
pass
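
# Example client call for /predict (a sketch, not part of the API itself). It
# assumes the server runs locally on port 8000 and that "extrait.wav" exists on
# the client side; both names are illustrative.
#
#     import requests
#     with open("extrait.wav", "rb") as f:
#         resp = requests.post(
#             "http://127.0.0.1:8000/predict",
#             files={"file": ("extrait.wav", f, "audio/wav")},
#         )
#     print(resp.json())  # {"transcription": ..., "sentiment": {...}, ...}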
@app.post("/predict_text")
async def predict_text(text: str):
"""
Analyse de sentiment textuel uniquement
Args:
text: Texte à analyser
Returns:
JSON avec sentiment
"""
try:
sent_dict = TextEncoder.analyze_sentiment(text)
sentiment = {k: round(v, 3) for k, v in sent_dict.items()}
return JSONResponse({
"text": text,
"sentiment": sentiment
})
except Exception as e:
raise HTTPException(status_code=500, detail=f"Erreur analyse textuelle: {str(e)}")
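
# Example client call for /predict_text (sketch). Because "text" is declared as a
# plain str parameter, FastAPI exposes it as a query parameter; the sample
# sentence and local URL are illustrative.
#
#     import requests
#     resp = requests.post(
#         "http://127.0.0.1:8000/predict_text",
#         params={"text": "Je suis très content de ce produit."},
#     )
#     print(resp.json())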

# Entry point (configuration for Hugging Face Spaces)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0" if HF_SPACE else "127.0.0.1",
        port=8000,
        log_level="info"
    )
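
# Alternatively (sketch), the app can be served with the uvicorn CLI instead of
# the __main__ block above, assuming this file is named api_app.py:
#
#     uvicorn api_app:app --host 0.0.0.0 --port 8000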