Spaces:
Runtime error
Runtime error
import logging
import os

import requests
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from langdetect import detect
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# Configure the root logger to emit records at DEBUG level and above.
logging.basicConfig(level=logging.DEBUG)

# FastAPI application. CORS is left fully open: any origin, any method and
# any request header are accepted.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Sentiment models.
# Arabic: the tokenizer and the classifier must come from the same checkpoint.
# A tokenizer built for one BERT vocabulary feeds token ids that are
# meaningless (or out of range) for a model trained on a different vocabulary.
arabic_model_name = "aubmindlab/bert-base-arabertv02-twitter"
sentiment_tokenizer = AutoTokenizer.from_pretrained(arabic_model_name)
# num_labels=3 matches the ("negative", "neutral", "positive") index order used
# by arabic_sentiment_analysis; the default head size would never yield index 2.
# NOTE(review): as far as I know this checkpoint is a pretrained encoder without
# a sentiment fine-tune, so the classification head is freshly initialised and
# its predictions are not meaningful. Substitute a sentiment-fine-tuned Arabic
# checkpoint (for both lines above) once the intended model is confirmed.
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    arabic_model_name, num_labels=3
)
# English: off-the-shelf SST-2 classifier; presumably emits "POSITIVE" /
# "NEGATIVE" labels (map_sentiment_to_emotion matches them case-insensitively).
sentiment_analyzer = pipeline(
    "sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english"
)
# Request body schema for the TTS endpoint.
class Message(BaseModel):
    """POST body carrying the text to be spoken."""

    text: str
# Language detection.
def detect_language_safely(text):
    """Return a language code for *text*, never raising on undetectable input.

    Any character in the basic Arabic Unicode block (U+0600-U+06FF) returns
    ``"ar"`` without consulting langdetect. Otherwise ``langdetect.detect``
    decides; if it raises (e.g. on empty or feature-less text), ``"en"`` is
    returned.

    NOTE: langdetect is non-deterministic on short/ambiguous input unless
    ``langdetect.DetectorFactory.seed`` is set before the first detection.
    """
    # Decided up front so Arabic text never depends on langdetect; for a str
    # this scan cannot raise, so it needs no exception handling.
    if any("\u0600" <= ch <= "\u06FF" for ch in text):
        return "ar"
    try:
        return detect(text)
    except Exception:  # e.g. langdetect's LangDetectException; not BaseException
        return "en"
# Sentiment label -> voice emotion mapping.
def map_sentiment_to_emotion(sentiment, language="en"):
    """Translate a sentiment label into a TTS emotion name.

    For ``language == "ar"`` the label must equal ``"positive"`` or
    ``"negative"`` exactly. For any other language the test is a
    case-insensitive substring check, so labels like ``"POSITIVE"`` match.
    Positive wins over negative; the result is ``"happy"``, ``"sad"`` or
    ``"neutral"``.
    """
    if language == "ar":
        is_positive = sentiment == "positive"
        is_negative = sentiment == "negative"
    else:
        folded = sentiment.lower()
        is_positive = "positive" in folded
        is_negative = "negative" in folded

    if is_positive:
        return "happy"
    if is_negative:
        return "sad"
    return "neutral"
# Simple Arabic sentiment analysis: keyword lexicon first, model as tie-breaker.
def arabic_sentiment_analysis(text):
    """Classify Arabic *text* as ``"positive"``, ``"negative"`` or ``"neutral"``.

    Counts how many lexicon entries occur (as substrings) in *text*; the side
    with more hits wins. On a tie (including zero hits) the Arabic
    ``sentiment_model`` decides, and any failure there yields ``"neutral"``.
    """
    pos_words = ("سعيد", "فرح", "ممتاز", "رائع", "جيد", "حب", "جميل", "نجاح", "أحسنت", "شكرا")
    neg_words = ("حزين", "غاضب", "سيء", "فشل", "خطأ", "مشكلة", "صعب", "لا أحب", "سخيف", "مؤسف")

    # Arabic script is caseless, so no case folding is needed for these lexicons.
    neg_hits = [phrase for phrase in neg_words if phrase in text]
    # Blank out matched negative phrases before scoring positives: "لا أحب"
    # ("I don't like") contains the positive entry "حب" ("love") and would
    # otherwise cancel itself out into a tie.
    remaining = text
    for phrase in neg_hits:
        remaining = remaining.replace(phrase, " ")
    pos_count = sum(1 for word in pos_words if word in remaining)
    neg_count = len(neg_hits)

    if pos_count > neg_count:
        return "positive"
    if neg_count > pos_count:
        return "negative"

    # Tie: defer to the model. Label order must match its classification head.
    labels = ("negative", "neutral", "positive")
    try:
        inputs = sentiment_tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=128
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
        return labels[torch.argmax(logits, dim=-1)[0].item()]
    except Exception:
        # Best effort: a model failure must not break speech synthesis.
        logging.exception("Arabic sentiment model failed; falling back to neutral")
        return "neutral"
# URL of the TTS server's synthesis endpoint. Can be overridden with the
# TTS_SERVER_URL environment variable; defaults to port 5002 on localhost.
TTS_SERVER_URL = os.environ.get("TTS_SERVER_URL", "http://localhost:5002/tts")
# Main TTS endpoint: classifies the text's emotion locally, then proxies the
# synthesis request to the TTS server and streams its audio back.

# (connect, read) timeouts in seconds for the TTS server call. The read timeout
# bounds each wait for data, not the whole transfer.
# NOTE(review): values are a guess - tune to how long the TTS server takes to
# start sending audio.
_TTS_TIMEOUT = (5, 120)


def _analyze_text(text):
    """Return ``(language, emotion)`` for *text*.

    English text is scored by ``sentiment_analyzer`` and Arabic text by
    ``arabic_sentiment_analysis``. Any other language, or any failure of the
    sentiment step, yields ``"neutral"``.
    """
    language = detect_language_safely(text)
    emotion = "neutral"
    try:
        if language == "en":
            # truncation=True keeps over-long inputs within the model's limit.
            label = sentiment_analyzer(text, truncation=True)[0]["label"]
            emotion = map_sentiment_to_emotion(label)
        elif language == "ar":
            emotion = map_sentiment_to_emotion(arabic_sentiment_analysis(text), language="ar")
    except Exception:
        # Best effort: emotion only styles the voice, so fall back rather than fail.
        logging.exception("Sentiment analysis failed (language=%r); using neutral", language)
    return language, emotion


def _open_tts_stream(payload):
    """POST *payload* to the TTS server and return the open streaming response.

    Raises ``requests.exceptions.RequestException`` (``HTTPError`` for non-2xx
    replies, after closing the response). On success the caller owns the
    returned response and must close it.
    """
    response = requests.post(TTS_SERVER_URL, json=payload, stream=True, timeout=_TTS_TIMEOUT)
    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError:
        response.close()  # release the pooled connection before propagating
        raise
    return response


def _stream_and_close(response, chunk_size=1024):
    """Yield *response*'s body in chunks; close it when iteration ends or is abandoned."""
    try:
        yield from response.iter_content(chunk_size=chunk_size)
    finally:
        response.close()


# NOTE(review): route path is assumed - confirm it matches the request made by
# the page served from web/.
@app.post("/tts")
async def text_to_speech(msg: Message):
    """Synthesise speech for ``msg.text`` and stream it back as ``audio/wav``.

    Raises:
        HTTPException: 500 if the TTS server is unreachable, times out, or
            replies with an HTTP error status.
    """
    text = msg.text
    # Language detection, the sentiment models and the upstream request all
    # block; run them in the threadpool so the event loop stays responsive.
    language, emotion = await run_in_threadpool(_analyze_text, text)
    speaker_wav_path = "/app/audio/speaker_reference.wav"
    payload = {
        "text": text,
        "language": language,
        "speaker_wav": speaker_wav_path,
        "emotion": emotion,
    }
    try:
        response = await run_in_threadpool(_open_tts_stream, payload)
    except requests.exceptions.RequestException as e:
        logging.error("Error communicating with TTS server: %s", e)
        raise HTTPException(status_code=500, detail="Failed to generate speech") from e
    return StreamingResponse(_stream_and_close(response), media_type="audio/wav")
# Serve the static front end from the "web" directory (resolved against the
# current working directory); html=True serves index.html for directory paths.
# NOTE: Starlette matches routes in registration order, so this catch-all
# mount at "/" must stay after every API route.
app.mount(path="/", app=StaticFiles(html=True, directory="web"), name="static")