"""FastAPI service that converts text to speech with XTTS v2.

The service detects the language of the submitted text, derives a coarse
emotion ("happy" / "sad" / "neutral") from a sentiment analysis of the text,
synthesizes speech into a WAV file, and serves that file plus a static web
front end.
"""

import logging
import os

import soundfile as sf  # Import soundfile
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from langdetect import detect
from pydantic import BaseModel
from torch.serialization import add_safe_globals
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig

# Allowlist XttsConfig so torch.load doesn't raise UnpicklingError.
add_safe_globals([XttsConfig])

# Monkey-patch torch.load to always use weights_only=False.
# SECURITY NOTE(review): this disables PyTorch's safe unpickling for *every*
# torch.load call in the process, including callers that explicitly asked for
# weights_only=True. Only load checkpoints from trusted sources.
_original_torch_load = torch.load


def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False."""
    kwargs["weights_only"] = False
    return _original_torch_load(*args, **kwargs)


torch.load = patched_torch_load

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Device shared by every TTS load path below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Synthesis output and voice-cloning reference clip.
# NOTE(review): every request writes to the same output file, so concurrent
# requests can overwrite each other's audio before it is fetched via /audio.
OUTPUT_FILENAME = "output.wav"
SPEAKER_WAV = "/app/audio/speaker_reference.wav"

# Keyword lists used by the lightweight Arabic sentiment heuristic.
_ARABIC_POSITIVE_WORDS = (
    "سعيد", "فرح", "ممتاز", "رائع", "جيد", "حب", "جميل", "نجاح", "أحسنت", "شكرا",
)
_ARABIC_NEGATIVE_WORDS = (
    "حزين", "غاضب", "سيء", "فشل", "خطأ", "مشكلة", "صعب", "لا أحب", "سخيف", "مؤسف",
)
# Label order used to decode the Arabic classifier's argmax index.
_ARABIC_MODEL_LABELS = ("negative", "neutral", "positive")

# Initialize FastAPI.
app = FastAPI()

# NOTE(review): CORS is wide open (any origin, method and header).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the TTS model from local files, falling back to loading it by name.
try:
    model_dir = "/app/models/xtts_v2"
    config_path = os.path.join(model_dir, "config.json")
    # When providing config_path, TTS might expect the directory for model_path.
    tts = TTS(model_path=model_dir, config_path=config_path).to(DEVICE)
    print("XTTS v2 model loaded successfully from local files.")
except Exception as e:
    print(f"Error loading XTTS v2 model from local files: {e}")
    print("Falling back to loading by model name (license might be required).")
    tts = TTS("tts_models/multilingual/multi-dataset-xtts_v2").to(DEVICE)

# Load sentiment models.
# NOTE(review): the tokenizer and the classifier below come from two different
# checkpoints, so the token ids fed to the model are unlikely to match its
# vocabulary; the classifier also appears to be a base checkpoint without a
# sentiment-tuned head. Left unchanged because switching checkpoints alters
# what must be downloaded/cached -- confirm the intended checkpoint.
arabic_model_name = "aubmindlab/bert-base-arabertv02-twitter"
sentiment_tokenizer = AutoTokenizer.from_pretrained(arabic_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained("UBC-NLP/MARBERT")
sentiment_analyzer = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)


class Message(BaseModel):
    """Request body of the text-to-speech endpoint."""

    # Text to analyze and synthesize.
    text: str


def _contains_arabic(text: str) -> bool:
    """Return True if *text* has a character in the U+0600..U+06FF block."""
    return any("\u0600" <= c <= "\u06FF" for c in text)


def detect_language_safely(text: str) -> str:
    """Return a language code for *text* without ever raising.

    Any character in the Arabic Unicode block short-circuits to ``"ar"``.
    Otherwise the result of ``langdetect.detect`` is returned; if detection
    fails (for example on empty input) the fallback is ``"en"``.
    """
    if _contains_arabic(text):
        return "ar"
    try:
        return detect(text)
    except Exception:
        # Best effort: detection failures must not fail the request.
        logger.warning("Language detection failed; defaulting to 'en'.", exc_info=True)
        return "en"


def map_sentiment_to_emotion(sentiment: str, language: str = "en") -> str:
    """Map a sentiment label to ``"happy"``, ``"sad"`` or ``"neutral"``.

    Matching is case-insensitive and substring based, so it accepts both the
    lowercase labels produced by ``arabic_sentiment_analysis`` and the
    labels returned by the English pipeline. *language* is accepted for
    backward compatibility; the same rule applies to every language.
    """
    label = sentiment.lower()
    if "positive" in label:
        return "happy"
    if "negative" in label:
        return "sad"
    return "neutral"


def arabic_sentiment_analysis(text: str) -> str:
    """Classify *text* as ``"positive"``, ``"negative"`` or ``"neutral"``.

    A keyword count decides first. Only when positive and negative keyword
    counts tie is the transformer classifier consulted; if that fails the
    result is ``"neutral"``.
    """
    lowered = text.lower()
    pos_count = sum(word in lowered for word in _ARABIC_POSITIVE_WORDS)
    neg_count = sum(word in lowered for word in _ARABIC_NEGATIVE_WORDS)
    if pos_count > neg_count:
        return "positive"
    if neg_count > pos_count:
        return "negative"

    try:
        inputs = sentiment_tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
        )
        # Inference only: no need to build an autograd graph.
        with torch.no_grad():
            outputs = sentiment_model(**inputs)
        sentiment_class = torch.argmax(outputs.logits).item()
        return _ARABIC_MODEL_LABELS[sentiment_class]
    except Exception:
        logger.exception("Arabic sentiment model failed; defaulting to 'neutral'.")
        return "neutral"


# Main TTS endpoint.
@app.post("/text-to-speech/")
def text_to_speech(msg: Message):
    """Synthesize ``msg.text`` to the output WAV file and report the outcome.

    Returns a dict with ``status`` set to ``"success"`` (plus ``audio_file``
    and ``url``) or ``"error"`` (plus ``message``). Sentiment analysis is
    best effort: if it fails the emotion stays ``"neutral"``.
    """
    text = msg.text
    if not text.strip():
        # Nothing to synthesize; report it in the endpoint's usual error shape.
        return {"status": "error", "message": "Text must not be empty."}

    language = detect_language_safely(text)
    emotion = "neutral"
    try:
        if language == "en":
            sentiment_result = sentiment_analyzer(text)[0]
            emotion = map_sentiment_to_emotion(sentiment_result["label"])
        else:
            # Every non-English language goes through the Arabic analysis.
            sentiment_result = arabic_sentiment_analysis(text)
            emotion = map_sentiment_to_emotion(sentiment_result, language="ar")
    except Exception:
        logger.exception("Sentiment analysis failed; using neutral emotion.")

    try:
        tts.tts_to_file(
            text=text,
            file_path=OUTPUT_FILENAME,
            emotion=emotion,
            speaker_wav=SPEAKER_WAV,
            language=language,
        )
    except Exception as e:
        logger.exception("Speech synthesis failed.")
        return {"status": "error", "message": str(e)}

    return {
        "status": "success",
        "audio_file": OUTPUT_FILENAME,
        "url": "/audio",
    }


# Serve the audio file.
@app.get("/audio")
def get_audio():
    """Return the most recently synthesized WAV file, or 404 if none exists."""
    if not os.path.isfile(OUTPUT_FILENAME):
        raise HTTPException(status_code=404, detail="No audio has been generated yet.")
    return FileResponse(OUTPUT_FILENAME, media_type="audio/wav", filename=OUTPUT_FILENAME)


# Serve static files (your web page) from the 'web' directory.
# Mounted last so the API routes above take precedence over the catch-all "/".
app.mount("/", StaticFiles(directory="web", html=True), name="static")