"""FastAPI service for French speech transcription and multimodal sentiment.

The module loads a Wav2Vec2 CTC model plus the project's speech/text encoders
and multimodal classifier at import time, then exposes:

* ``GET  /``             -- basic API information
* ``GET  /health``       -- liveness probe
* ``POST /predict``      -- transcription + sentiment for an uploaded audio file
* ``POST /predict_text`` -- sentiment for a raw text string
"""
import gc
import os
import tempfile
from datetime import datetime, timezone

import torch
import torch.nn.functional as F
import torchaudio
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from src.transcription import SpeechEncoder
from src.sentiment import TextEncoder
from src.multimodal import MultimodalSentimentClassifier

# Hugging Face Spaces configuration: set HF_SPACE=true when deployed there.
HF_SPACE = os.getenv("HF_SPACE", "false").lower() == "true"

# Pretrained French CTC checkpoint used for transcription.
CTC_MODEL_NAME = "jonatasgrosman/wav2vec2-large-xlsr-53-french"
# Use a local cache directory outside of Spaces; let the library decide on Spaces.
MODEL_CACHE_DIR = None if HF_SPACE else "./models"
# Sampling rate passed to the Wav2Vec2 processor.
TARGET_SAMPLE_RATE = 16000
# Upload constraints enforced by /predict.
MAX_UPLOAD_BYTES = 50 * 1024 * 1024  # 50MB
ALLOWED_EXTENSIONS = (".wav", ".flac", ".mp3")
# Label order used to name the classifier outputs.
SENTIMENT_LABELS = ("négatif", "neutre", "positif")

app = FastAPI(
    title="API Multimodale de Transcription & Sentiment",
    description="API pour l'analyse de sentiment audio en français",
    version="1.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS configuration for Hugging Face Spaces.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site issue credentialed requests; restrict allow_origins if cookies or auth
# headers are ever used. Behaviour is left unchanged here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Preload every model once at import time so requests do not pay the cost.
print("Chargement des modèles pour l'API...")
try:
    processor_ctc = Wav2Vec2Processor.from_pretrained(
        CTC_MODEL_NAME,
        cache_dir=MODEL_CACHE_DIR,
    )
    model_ctc = Wav2Vec2ForCTC.from_pretrained(
        CTC_MODEL_NAME,
        cache_dir=MODEL_CACHE_DIR,
    )
    speech_enc = SpeechEncoder()
    text_enc = TextEncoder()
    model_mm = MultimodalSentimentClassifier()
    print("✅ Modèles chargés avec succès pour l'API")
except Exception as e:
    # Without its models the API is useless: report and abort startup.
    print(f"❌ Erreur chargement modèles API: {e}")
    raise


def transcribe_ctc(wav_path: str) -> str:
    """Transcribe an audio file with the preloaded Wav2Vec2 CTC model.

    The audio is resampled to ``TARGET_SAMPLE_RATE`` when needed and
    multi-channel audio is averaged down to a single channel before decoding.

    Args:
        wav_path: Path of the audio file to load with ``torchaudio``.

    Returns:
        The lower-cased transcription of the first (only) batch item.

    Raises:
        HTTPException: Status 500 wrapping any failure during loading,
            preprocessing or inference.
    """
    try:
        waveform, sr = torchaudio.load(wav_path)
        if sr != TARGET_SAMPLE_RATE:
            waveform = torchaudio.transforms.Resample(sr, TARGET_SAMPLE_RATE)(waveform)
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        inputs = processor_ctc(
            # Drop only the channel axis: a bare squeeze() would also collapse
            # a one-sample clip into a 0-d array.
            waveform.squeeze(0).numpy(),
            sampling_rate=TARGET_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            logits = model_ctc(**inputs).logits
        pred_ids = torch.argmax(logits, dim=-1)
        return processor_ctc.batch_decode(pred_ids)[0].lower()
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Erreur transcription: {str(e)}"
        ) from e


# FastAPI publishes endpoint docstrings in the OpenAPI schema. The explicit
# ``description=`` arguments below keep the original French text visible to
# API users while the docstrings themselves are written in English.
@app.get("/", description="Endpoint racine avec informations sur l'API")
async def root():
    """Return static information about the API and its endpoints."""
    return {
        "message": "API Multimodale de Transcription & Sentiment",
        "version": "1.0",
        "endpoints": {
            "docs": "/docs",
            "predict": "/predict",
            "health": "/health",
        },
        "supported_formats": [ext.lstrip(".") for ext in ALLOWED_EXTENSIONS],
    }


@app.get("/health", description="Vérification de l'état de l'API")
async def health_check():
    """Report that the service is up.

    ``models_loaded`` is always true here because a model-loading failure
    re-raises at import time, so the application would not be running.
    The timestamp is the current UTC time.
    """
    return {
        "status": "healthy",
        "models_loaded": True,
        "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
    }


@app.post(
    "/predict",
    description=(
        "Analyse de sentiment audio\n\n"
        "Args:\n"
        "    file: Fichier audio (WAV, FLAC, MP3)\n\n"
        "Returns:\n"
        "    JSON avec transcription et sentiment"
    ),
)
async def predict(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and score its sentiment.

    Args:
        file: Uploaded audio file; its name must end in .wav, .flac or .mp3
            and its size must not exceed ``MAX_UPLOAD_BYTES``.

    Returns:
        A JSON response with the transcription, a label -> probability
        mapping, the file name and the file size. When nothing is transcribed,
        a near-uniform distribution is returned together with a warning.

    Raises:
        HTTPException: 400 for an unsupported extension or an oversized file,
            500 for any failure during transcription or analysis.
    """
    # 1. Check the file type (by extension only).
    if not file.filename or not file.filename.lower().endswith(ALLOWED_EXTENSIONS):
        raise HTTPException(
            status_code=400,
            detail="Seuls les fichiers audio WAV/FLAC/MP3 sont acceptés.",
        )

    # 2. Check the file size. Reading at most limit + 1 bytes is enough to
    #    detect an oversized upload without loading all of it into memory.
    content = await file.read(MAX_UPLOAD_BYTES + 1)
    if len(content) > MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=400,
            detail="Fichier trop volumineux. Taille maximale: 50MB",
        )

    suffix = os.path.splitext(file.filename)[1]
    tmp_path = None
    try:
        # 3. Save to a temporary file. This happens inside the try block so
        #    the file is removed even if writing it fails.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(content)

        # 4. Transcription.
        transcription = transcribe_ctc(tmp_path)
        if not transcription.strip():
            return JSONResponse({
                "transcription": "",
                "sentiment": {"négatif": 0.33, "neutre": 0.34, "positif": 0.33},
                "warning": "Aucune transcription détectée",
            })

        try:
            # 5. Multimodal features.
            audio_feat = speech_enc.extract_features(tmp_path)
            text_feat = text_enc.extract_features([transcription])

            # 6. Classification. Inference only, so no autograd graph is needed.
            with torch.no_grad():
                logits = model_mm.classifier(
                    torch.cat([audio_feat, text_feat], dim=1)
                )
                probs = F.softmax(logits, dim=1).squeeze().tolist()
            sentiment = {
                label: round(probs[i], 3)
                for i, label in enumerate(SENTIMENT_LABELS)
            }
        except Exception as e:
            # Deliberate best-effort: fall back to text-only analysis.
            # NOTE(review): analyze_sentiment is called on the TextEncoder class
            # rather than on the text_enc instance -- confirm it is a
            # static/class method.
            print(f"Erreur multimodal, fallback textuel: {e}")
            sent_dict = TextEncoder.analyze_sentiment(transcription)
            sentiment = {k: round(v, 3) for k, v in sent_dict.items()}

        return JSONResponse({
            "transcription": transcription,
            "sentiment": sentiment,
            "filename": file.filename,
            "file_size": len(content),
        })
    except HTTPException:
        # Already a well-formed API error (e.g. from transcribe_ctc): do not
        # wrap it a second time.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Erreur lors de l'analyse: {str(e)}"
        ) from e
    finally:
        # 7. Remove the temporary file (best-effort, but no longer silent).
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError as e:
                print(f"Impossible de supprimer le fichier temporaire {tmp_path}: {e}")

        # 8. Memory cleanup, on success and failure alike.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


@app.post(
    "/predict_text",
    description=(
        "Analyse de sentiment textuel uniquement\n\n"
        "Args:\n"
        "    text: Texte à analyser\n\n"
        "Returns:\n"
        "    JSON avec sentiment"
    ),
)
async def predict_text(text: str):
    """Score the sentiment of a raw text string.

    Args:
        text: Text to analyse.

    Returns:
        A JSON response echoing the text with a label -> score mapping,
        each score rounded to three decimals.

    Raises:
        HTTPException: Status 500 wrapping any failure of the analysis.
    """
    try:
        sent_dict = TextEncoder.analyze_sentiment(text)
        sentiment = {k: round(v, 3) for k, v in sent_dict.items()}
        return JSONResponse({
            "text": text,
            "sentiment": sentiment,
        })
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Erreur analyse textuelle: {str(e)}"
        ) from e


# Launcher for local runs and Hugging Face Spaces.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        # Listen on all interfaces only when deployed on Spaces.
        host="0.0.0.0" if HF_SPACE else "127.0.0.1",
        # The PORT environment variable overrides the default of 8000.
        port=int(os.getenv("PORT", "8000")),
        log_level="info",
    )