import tempfile
import os
import gc
from datetime import datetime, timezone

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import torch.nn.functional as F
import torchaudio
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

from src.transcription import SpeechEncoder
from src.sentiment import TextEncoder
from src.multimodal import MultimodalSentimentClassifier

# Configuration for Hugging Face Spaces
HF_SPACE = os.getenv("HF_SPACE", "false").lower() == "true"

app = FastAPI(
    title="API Multimodale de Transcription & Sentiment",
    description="API pour l'analyse de sentiment audio en français",
    version="1.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# CORS configuration for Hugging Face Spaces
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Preload the models once at startup
print("Loading models for the API...")
try:
    processor_ctc = Wav2Vec2Processor.from_pretrained(
        "jonatasgrosman/wav2vec2-large-xlsr-53-french",
        cache_dir="./models" if not HF_SPACE else None
    )
    model_ctc = Wav2Vec2ForCTC.from_pretrained(
        "jonatasgrosman/wav2vec2-large-xlsr-53-french",
        cache_dir="./models" if not HF_SPACE else None
    )
    speech_enc = SpeechEncoder()
    text_enc = TextEncoder()
    model_mm = MultimodalSentimentClassifier()
    print("✅ Models loaded successfully for the API")
except Exception as e:
    print(f"❌ Failed to load API models: {e}")
    raise


def transcribe_ctc(wav_path: str) -> str:
    """Transcribe an audio file with Wav2Vec2 (CTC decoding)."""
    try:
        waveform, sr = torchaudio.load(wav_path)
        # Resample to 16 kHz and downmix to mono, as expected by the model
        if sr != 16000:
            waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        inputs = processor_ctc(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        )
        with torch.no_grad():
            logits = model_ctc(**inputs).logits
        pred_ids = torch.argmax(logits, dim=-1)
        transcription = processor_ctc.batch_decode(pred_ids)[0].lower()
        return transcription
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur transcription: {str(e)}")


@app.get("/")
async def root():
    """Root endpoint with basic information about the API."""
    return {
        "message": "API Multimodale de Transcription & Sentiment",
        "version": "1.0",
        "endpoints": {
            "docs": "/docs",
            "predict": "/predict",
            "health": "/health"
        },
        "supported_formats": ["wav", "flac", "mp3"]
    }


@app.get("/health")
async def health_check():
    """Report the API's health status."""
    return {
        "status": "healthy",
        "models_loaded": True,
        "timestamp": datetime.now(timezone.utc).isoformat()
    }


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Audio sentiment analysis.

    Args:
        file: audio file (WAV, FLAC, MP3)

    Returns:
        JSON with the transcription and the sentiment scores
    """
    # 1. Check the file type
    if not file.filename or not file.filename.lower().endswith((".wav", ".flac", ".mp3")):
        raise HTTPException(
            status_code=400,
            detail="Seuls les fichiers audio WAV/FLAC/MP3 sont acceptés."
        )

    # 2. Check the file size (max 50 MB)
    content = await file.read()
    if len(content) > 50 * 1024 * 1024:  # 50 MB
        raise HTTPException(
            status_code=400,
            detail="Fichier trop volumineux. Taille maximale: 50MB"
        )

    # 3. Save the upload to a temporary file
    suffix = os.path.splitext(file.filename)[1]
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(content)
        tmp_path = tmp.name
    try:
        # 4. Transcription
        transcription = transcribe_ctc(tmp_path)
        if not transcription.strip():
            return JSONResponse({
                "transcription": "",
                "sentiment": {"négatif": 0.33, "neutre": 0.34, "positif": 0.33},
                "warning": "Aucune transcription détectée"
            })

        # 5. Multimodal features
        try:
            audio_feat = speech_enc.extract_features(tmp_path)
            text_feat = text_enc.extract_features([transcription])
            # 6. Classification
            logits = model_mm.classifier(torch.cat([audio_feat, text_feat], dim=1))
            probs = F.softmax(logits, dim=1).squeeze().tolist()
            labels = ["négatif", "neutre", "positif"]
            sentiment = {labels[i]: round(probs[i], 3) for i in range(len(labels))}
        except Exception as e:
            # Fall back to text-only sentiment analysis
            print(f"Multimodal error, falling back to text-only analysis: {e}")
            sent_dict = TextEncoder.analyze_sentiment(transcription)
            sentiment = {k: round(v, 3) for k, v in sent_dict.items()}

        # 7. Free memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return JSONResponse({
            "transcription": transcription,
            "sentiment": sentiment,
            "filename": file.filename,
            "file_size": len(content)
        })
    except HTTPException:
        # Let errors already raised as HTTPException (e.g. from transcribe_ctc) pass through
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur lors de l'analyse: {str(e)}")
    finally:
        # 8. Remove the temporary file
        try:
            os.remove(tmp_path)
        except OSError:
            pass


@app.post("/predict-text")  # route path assumed; adjust if the project uses a different one
async def predict_text(text: str):
    """
    Text-only sentiment analysis.

    Args:
        text: text to analyse

    Returns:
        JSON with the sentiment scores
    """
    try:
        sent_dict = TextEncoder.analyze_sentiment(text)
        sentiment = {k: round(v, 3) for k, v in sent_dict.items()}
        return JSONResponse({
            "text": text,
            "sentiment": sentiment
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur analyse textuelle: {str(e)}")


# Entry point for local runs and Hugging Face Spaces
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0" if HF_SPACE else "127.0.0.1",
        port=8000,
        log_level="info"
    )
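
# Example client usage (a sketch — assumes the API is running locally on port 8000
# and that a file "sample.wav" exists; adjust the URL, port and file path to your setup):
#
#   import requests
#
#   with open("sample.wav", "rb") as f:
#       resp = requests.post("http://127.0.0.1:8000/predict", files={"file": f})
#   print(resp.json())   # {"transcription": "...", "sentiment": {"négatif": ..., ...}, ...}
#
#   resp = requests.post("http://127.0.0.1:8000/predict-text", params={"text": "Très bonne expérience"})
#   print(resp.json())   # {"text": "...", "sentiment": {...}}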