Spaces:
Runtime error
Runtime error
File size: 4,112 Bytes
66a046b 82a5961 66a046b 82a5961 66a046b 82a5961 66a046b 82a5961 66a046b 82a5961 66a046b 82a5961 66a046b 82a5961 66a046b 82a5961 66a046b 82a5961 66a046b 82a5961 66a046b 82a5961 66a046b 82a5961 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import logging
import os

import requests
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from langdetect import detect
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
logging.basicConfig(level=logging.DEBUG)

# FastAPI application. CORS is wide open: any origin may call any route with
# any method and any request header.
app = FastAPI()
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Sentiment models.
# Arabic: the tokenizer and the classifier MUST be loaded from the same
# checkpoint -- token ids produced by one vocabulary are meaningless to another
# model's embedding table.
arabic_model_name = "aubmindlab/bert-base-arabertv02-twitter"
sentiment_tokenizer = AutoTokenizer.from_pretrained(arabic_model_name)
# num_labels=3 matches the ["negative", "neutral", "positive"] index order that
# arabic_sentiment_analysis uses to read the argmax.
# NOTE(review): this checkpoint appears to be a pretrained encoder rather than a
# sentiment fine-tune, in which case the classification head is freshly
# initialised and its output is not meaningful -- confirm, and swap in a
# fine-tuned 3-class Arabic sentiment checkpoint if one is available.
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    arabic_model_name, num_labels=3
)
# English: off-the-shelf SST-2 classifier (labels presumably "POSITIVE"/"NEGATIVE").
sentiment_analyzer = pipeline(
    "sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english"
)
class Message(BaseModel):
    """JSON request body for the text-to-speech endpoint: the text to speak."""

    text: str
# Language detection (remains the same)
def detect_language_safely(text):
try:
if any('\u0600' <= c <= '\u06FF' for c in text):
return "ar"
return detect(text)
except:
return "ar" if any('\u0600' <= c <= '\u06FF' for c in text) else "en"
# Map a sentiment label onto one of the emotions sent to the TTS server.
def map_sentiment_to_emotion(sentiment, language="en"):
    """Translate a sentiment label into "happy", "sad" or "neutral".

    With ``language == "ar"`` the label must equal ``"positive"`` or
    ``"negative"`` exactly; these are the lowercase labels that
    ``arabic_sentiment_analysis`` returns. For every other language the test
    is a case-insensitive substring match, so ``"POSITIVE"`` matches.
    "positive" wins if a label contains both words.
    """
    if language == "ar":
        is_positive = sentiment == "positive"
        is_negative = sentiment == "negative"
    else:
        label = sentiment.lower()
        is_positive = "positive" in label
        is_negative = "negative" in label

    if is_positive:
        return "happy"
    if is_negative:
        return "sad"
    return "neutral"
# Simple Arabic sentiment analysis: keyword lexicon first, model as tie-breaker.
def arabic_sentiment_analysis(text):
    """Classify Arabic *text* as "positive", "negative" or "neutral".

    A small keyword lexicon is tried first. Each listed word or phrase found as
    a substring counts once. When positive and negative hits tie (including zero
    hits), ``sentiment_tokenizer``/``sentiment_model`` decide, and the argmax
    class is read as an index into ["negative", "neutral", "positive"]. Any
    failure of the model path is logged and yields "neutral".

    NOTE: substring matching is coarse; e.g. "حب" ("love") also matches inside
    "مرحبا" ("hello").
    """
    pos_words = ("سعيد", "فرح", "ممتاز", "رائع", "جيد", "حب", "جميل", "نجاح", "أحسنت", "شكرا")
    neg_words = ("حزين", "غاضب", "سيء", "فشل", "خطأ", "مشكلة", "صعب", "لا أحب", "سخيف", "مؤسف")

    lowered = text.lower()
    neg_hits = [word for word in neg_words if word in lowered]
    # Blank out matched negative phrases before counting positives, so a
    # negation such as "لا أحب" ("I don't like") is not also counted as the
    # positive "حب" it contains. A space keeps neighbours from fusing.
    positive_scope = lowered
    for word in neg_hits:
        positive_scope = positive_scope.replace(word, " ")
    pos_count = sum(1 for word in pos_words if word in positive_scope)
    neg_count = len(neg_hits)

    if pos_count > neg_count:
        return "positive"
    if neg_count > pos_count:
        return "negative"

    try:
        inputs = sentiment_tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=128
        )
        # Inference only: no_grad avoids building an autograd graph per request.
        with torch.no_grad():
            outputs = sentiment_model(**inputs)
        # Batch of one: take the best class for the single input row.
        sentiment_class = torch.argmax(outputs.logits, dim=-1)[0].item()
        return ["negative", "neutral", "positive"][sentiment_class]
    except Exception:
        # Best-effort fallback, but leave a trace instead of failing silently.
        logging.exception("Arabic sentiment model failed; defaulting to neutral")
        return "neutral"
# URL of the external TTS server. Can be overridden with the TTS_SERVER_URL
# environment variable; defaults to a server on localhost port 5002.
TTS_SERVER_URL = os.environ.get("TTS_SERVER_URL", "http://localhost:5002/tts")
# (connect, read) timeouts in seconds for the TTS server. The read timeout
# bounds the wait between bytes received, not the whole synthesis.
# NOTE(review): values are a guess; tune to the TTS server's real latency.
_TTS_TIMEOUT = (5, 120)


def _detect_emotion(text, language):
    """Return "happy", "sad" or "neutral" for *text* in *language*.

    English goes through ``sentiment_analyzer`` (truncated so long inputs do
    not exceed the model's length limit). Arabic goes through
    ``arabic_sentiment_analysis``. Other languages, and any analysis failure,
    yield "neutral".
    """
    try:
        if language == "en":
            label = sentiment_analyzer(text, truncation=True)[0]["label"]
            return map_sentiment_to_emotion(label)
        if language == "ar":
            return map_sentiment_to_emotion(arabic_sentiment_analysis(text), language="ar")
    except Exception:
        # Emotion is best-effort: log it and fall back rather than fail the request.
        logging.exception("Sentiment analysis failed; using neutral emotion")
    return "neutral"


def _iter_and_close(response, chunk_size=1024):
    """Yield the upstream body in chunks, closing the connection when done or abandoned."""
    try:
        yield from response.iter_content(chunk_size=chunk_size)
    finally:
        response.close()


# Main TTS endpoint: forwards the text plus detected language/emotion to the TTS server.
@app.post("/text-to-speech/")
async def text_to_speech(msg: Message):
    """Synthesize speech for ``msg.text`` through the external TTS server.

    Detects the text's language, derives an emotion from sentiment analysis,
    and POSTs text, language, reference speaker WAV path and emotion to
    ``TTS_SERVER_URL``. The server's body is streamed back as ``audio/wav``.

    Raises:
        HTTPException: 500 if the TTS server is unreachable, times out, or
            answers with an HTTP error status.
    """
    text = msg.text
    language = detect_language_safely(text)
    # Model inference is synchronous and CPU-heavy; keep it off the event loop.
    emotion = await run_in_threadpool(_detect_emotion, text, language)
    speaker_wav_path = "/app/audio/speaker_reference.wav"
    payload = {
        "text": text,
        "language": language,
        "speaker_wav": speaker_wav_path,
        "emotion": emotion,
    }

    response = None
    try:
        # requests blocks until the response headers arrive, so run it in a
        # worker thread instead of stalling every other request.
        response = await run_in_threadpool(
            requests.post, TTS_SERVER_URL, json=payload, stream=True, timeout=_TTS_TIMEOUT
        )
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Release the streamed connection of an error response before bailing out.
        if response is not None:
            response.close()
        logging.error("Error communicating with TTS server: %s", e)
        raise HTTPException(status_code=500, detail="Failed to generate speech") from e

    return StreamingResponse(_iter_and_close(response), media_type="audio/wav")
# Serve the web page from the "web" directory next to this file, independent of
# the process's working directory (StaticFiles raises at startup if the
# directory does not exist). Mounted last so the API routes registered above
# are matched before this catch-all "/" mount.
WEB_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "web")
app.mount("/", StaticFiles(directory=WEB_DIR, html=True), name="static")