from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
import logging
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from langdetect import detect
import requests

logging.basicConfig(level=logging.DEBUG)

# Initialize FastAPI
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load sentiment models: MARBERT for Arabic, DistilBERT fine-tuned on SST-2 for English.
# The tokenizer must come from the same checkpoint as the model it feeds; num_labels=3
# matches the three-way label list used in arabic_sentiment_analysis below.
arabic_model_name = "UBC-NLP/MARBERT"
sentiment_tokenizer = AutoTokenizer.from_pretrained(arabic_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(arabic_model_name, num_labels=3)
sentiment_analyzer = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

# Input class for POST body
class Message(BaseModel):
    text: str

# Language detection: treat any Arabic-script character as Arabic, otherwise fall back to langdetect
def detect_language_safely(text):
    try:
        if any('\u0600' <= c <= '\u06FF' for c in text):
            return "ar"
        return detect(text)
    except Exception:
        return "ar" if any('\u0600' <= c <= '\u06FF' for c in text) else "en"

# Map a sentiment label onto the emotion tag sent to the TTS server
def map_sentiment_to_emotion(sentiment, language="en"):
    if language == "ar":
        return "happy" if sentiment == "positive" else "sad" if sentiment == "negative" else "neutral"
    return "happy" if "positive" in sentiment.lower() else "sad" if "negative" in sentiment.lower() else "neutral"

# Simple keyword-based Arabic sentiment analysis, falling back to the MARBERT classifier on a tie
def arabic_sentiment_analysis(text):
    pos_words = ["سعيد", "فرح", "ممتاز", "رائع", "جيد", "حب", "جميل", "نجاح", "أحسنت", "شكرا"]
    neg_words = ["حزين", "غاضب", "سيء", "فشل", "خطأ", "مشكلة", "صعب", "لا أحب", "سخيف", "مؤسف"]
    pos_count = sum(1 for word in pos_words if word in text.lower())
    neg_count = sum(1 for word in neg_words if word in text.lower())

    if pos_count > neg_count:
        return "positive"
    elif neg_count > pos_count:
        return "negative"
    else:
        try:
            inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
            with torch.no_grad():  # inference only, no gradients needed
                outputs = sentiment_model(**inputs)
            sentiment_class = torch.argmax(outputs.logits).item()
            return ["negative", "neutral", "positive"][sentiment_class]
        except Exception:
            return "neutral"

# TTS Server URL
TTS_SERVER_URL = "http://localhost:5002/tts"

# Main TTS endpoint: detect language and emotion, then forward the request to the TTS server.
# Declared as a plain (sync) def so FastAPI runs the blocking `requests` call in a worker thread
# instead of blocking the event loop.
@app.post("/text-to-speech/")
def text_to_speech(msg: Message):
    text = msg.text
    language = detect_language_safely(text)
    emotion = "neutral"
    speaker_wav_path = "/app/audio/speaker_reference.wav"

    if language == "en":
        try:
            sentiment_result = sentiment_analyzer(text)[0]
            emotion = map_sentiment_to_emotion(sentiment_result["label"])
        except Exception:
            logging.warning("English sentiment analysis failed; defaulting to neutral emotion")
    else:
        try:
            sentiment_result = arabic_sentiment_analysis(text)
            emotion = map_sentiment_to_emotion(sentiment_result, language="ar")
        except Exception:
            logging.warning("Arabic sentiment analysis failed; defaulting to neutral emotion")

    payload = {
        "text": text,
        "language": language,
        "speaker_wav": speaker_wav_path,
        "emotion": emotion,
    }
    try:
        response = requests.post(TTS_SERVER_URL, json=payload, stream=True)
        response.raise_for_status()
        return StreamingResponse(response.iter_content(chunk_size=1024), media_type="audio/wav")
    except requests.exceptions.RequestException as e:
        logging.error(f"Error communicating with TTS server: {e}")
        raise HTTPException(status_code=500, detail="Failed to generate speech")

# Serve the static web client from the 'web' directory (mounted last so the API routes above take precedence)
app.mount("/", StaticFiles(directory="web", html=True), name="static")
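
# A minimal sketch of how this app might be run and exercised locally. The uvicorn host/port
# and the sample request below are assumptions for illustration, not part of the original file.
#
#   curl -X POST http://localhost:8000/text-to-speech/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": "مرحبا"}' --output reply.wav
#
if __name__ == "__main__":
    import uvicorn

    # Serve the API (and the static site mounted at "/") on port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)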