File size: 4,634 Bytes
f268d74
4258aad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3284a8
f268d74
ca4390d
c3284a8
ca4390d
 
 
 
f268d74
ca4390d
 
 
 
548d9f6
ca4390d
f268d74
 
ca4390d
 
 
 
 
 
 
 
f268d74
ca4390d
 
 
 
b98abdc
f268d74
 
 
ca4390d
f268d74
 
 
 
 
 
 
b98abdc
f268d74
ca4390d
f268d74
 
 
 
 
ca4390d
f268d74
 
 
c6793c5
f268d74
ca4390d
 
 
 
 
 
f268d74
ca4390d
 
 
 
 
849e3ef
ca4390d
 
 
79d8ea8
f268d74
 
 
79d8ea8
ca4390d
5b7803a
ca4390d
 
c3284a8
ca4390d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
849e3ef
b98abdc
c3284a8
ca4390d
 
 
 
548d9f6
ca4390d
 
 
5b7803a
 
b98abdc
5b7803a
ca4390d
5b7803a
 
 
 
79d8ea8
849e3ef
b98abdc
c3284a8
ca4390d
 
548d9f6
ca4390d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os

TMP_PATH = "/tmp"

# Point HOME and every library cache (Hugging Face, Matplotlib, PyTorch, XDG)
# into TMP_PATH. This runs before the third-party imports further down,
# presumably so those libraries read these locations at import time
# (and presumably because the default home is not writable — TODO confirm).
_hf_home = os.path.join(TMP_PATH, "huggingface")
_cache_env = {
    "HF_HOME": _hf_home,
    "TRANSFORMERS_CACHE": _hf_home + "/transformers",
    "HF_DATASETS_CACHE": _hf_home + "/datasets",
    "HF_METRICS_CACHE": _hf_home + "/metrics",
    "MPLCONFIGDIR": os.path.join(TMP_PATH, "matplotlib"),
    "TORCH_HOME": os.path.join(TMP_PATH, "torch"),
    "XDG_CACHE_HOME": os.path.join(TMP_PATH, "xdg-cache"),
    "HOME": TMP_PATH,
}
os.environ.update(_cache_env)

# Create the top-level cache directories if they do not exist yet
# (the HF sub-caches above are not pre-created here).
for _dir in (
    _cache_env["HF_HOME"],
    _cache_env["MPLCONFIGDIR"],
    _cache_env["TORCH_HOME"],
    _cache_env["XDG_CACHE_HOME"],
    TMP_PATH,
):
    os.makedirs(_dir, exist_ok=True)

# --- Imports principaux ---
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, set_seed
from deep_translator import GoogleTranslator
from TTS.api import TTS
import whisper
import io
import torch
import scipy.io.wavfile
import numpy as np
import traceback

# --- FastAPI application ---
app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True lets any
# web page send credentialed requests to this API (Starlette reflects the
# caller's Origin in that situation — confirm for the installed version).
# Kept unchanged for compatibility; restrict allow_origins to the known
# front-end(s) before exposing this service publicly.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# --- Runtime configuration ---
# NOTE(review): `device` is not passed to any model loader in the code visible
# here, so the models use their libraries' default placement — confirm intent.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Scratch directory where uploaded audio is written before processing.
TEMP_DIR = "/tmp"
os.makedirs(TEMP_DIR, exist_ok=True)

# --- Lazily loaded models (filled on first use by the load_* helpers) ---
loaded_tts_models = dict()  # lang code -> transformers "text-to-audio" pipeline
loaded_tts_lingala = asr_model = None  # Coqui TTS instance / Whisper model

# --- Utilitaires ---
def load_tts_model(lang: str):
    """Return the cached MMS text-to-speech pipeline for *lang*, loading it on first use.

    The model comes from the Hugging Face repo ``OlameMend/mms-tts-{lang}`` and is
    stored in the module-level ``loaded_tts_models`` dict so later calls reuse it.

    Args:
        lang: Language code appended to the repo name (generate_tts_audio passes
            it lower-cased).

    Returns:
        The ``transformers`` ``"text-to-audio"`` pipeline for that language.

    Raises:
        RuntimeError: if the pipeline cannot be created. The underlying exception
            is chained as ``__cause__`` so its traceback is not lost.
    """
    if lang in loaded_tts_models:
        return loaded_tts_models[lang]

    model_id = f"OlameMend/mms-tts-{lang}"
    try:
        synthesizer = pipeline("text-to-audio", model=model_id)
    except Exception as e:
        # Chain the original error so logs show the real loading failure.
        raise RuntimeError(f"Erreur lors du chargement du modèle TTS '{model_id}': {e}") from e

    # Only successful loads are cached; a failed load will be retried next call.
    loaded_tts_models[lang] = synthesizer
    return synthesizer

def load_asr_model():
    """Return the shared Whisper ASR model, loading the "tiny" checkpoint on first call."""
    global asr_model
    if asr_model is not None:
        return asr_model
    asr_model = whisper.load_model("tiny")
    return asr_model

def load_tts_lingala_model():
    """Return the shared Coqui TTS Lingala model, creating it on the first call.

    Uses the ``tts_models/lin/openbible/vits`` model and keeps it in the
    module-level ``loaded_tts_lingala`` variable for reuse.
    """
    global loaded_tts_lingala
    if loaded_tts_lingala is not None:
        return loaded_tts_lingala
    loaded_tts_lingala = TTS("tts_models/lin/openbible/vits")
    return loaded_tts_lingala

def preprocess_text_tts(text: str) -> str:
    """Normalise *text* before speech synthesis.

    Currently this only removes leading and trailing whitespace.
    """
    cleaned = text.strip()
    return cleaned

def generate_tts_audio(lang: str, text: str) -> io.BytesIO:
    """Synthesise *text* in language *lang* and return it as an in-memory WAV file.

    Args:
        lang: Language code; lower-cased before selecting the model.
        text: Text to speak; it is passed through preprocess_text_tts first.

    Returns:
        A BytesIO positioned at offset 0 that contains a WAV file.

    Raises:
        ValueError: if the text is empty once preprocessed (e.g. whitespace only).
        RuntimeError: if the TTS model for *lang* cannot be loaded.
    """
    lang = lang.lower()
    processed_text = preprocess_text_tts(text)
    if not processed_text:
        # Reject before loading a model: there is nothing to synthesise.
        raise ValueError("Le texte à synthétiser est vide.")

    synthesizer = load_tts_model(lang)
    # Fixed seed, presumably so identical requests yield identical audio.
    set_seed(555)
    speech = synthesizer(processed_text)

    wav_io = io.BytesIO()
    # NOTE(review): assumes speech["audio"] is (channels, samples); [0] keeps the
    # first channel — confirm for the pipeline in use.
    scipy.io.wavfile.write(wav_io, rate=speech["sampling_rate"], data=speech["audio"][0])
    wav_io.seek(0)
    return wav_io

def speech_2_speech_ling(source_audio_path: str, lang: str) -> io.BytesIO:
    """Translate a spoken recording into Lingala speech in the source speaker's voice.

    Steps:
        1. Whisper transcribes the audio, with *lang* as the source language.
        2. GoogleTranslator translates the transcript to Lingala ("ln").
        3. Coqui TTS synthesises the translation with voice conversion, using the
           source recording as the reference speaker.

    Args:
        source_audio_path: Path to the source recording. It is read by Whisper and
            also used as ``speaker_wav`` for voice conversion.
        lang: Language passed to Whisper's ``transcribe``.

    Returns:
        A BytesIO positioned at offset 0 that contains the generated WAV audio.

    Raises:
        ValueError: if no text was transcribed or the translation came back empty.
    """
    asr = load_asr_model()
    result = asr.transcribe(source_audio_path, language=lang)
    # Strip surrounding whitespace so a blank transcript is detected below.
    text = result["text"].strip()
    if not text:
        raise ValueError("Aucune parole détectée dans l'audio source.")

    translated_text = GoogleTranslator(source="auto", target="ln").translate(text)
    if not translated_text or not translated_text.strip():
        raise ValueError("La traduction en lingala est vide.")

    # Load the TTS model only once there is actually something to synthesise.
    tts_lingala = load_tts_lingala_model()
    wav_io = io.BytesIO()
    tts_lingala.tts_with_vc_to_file(text=translated_text, speaker_wav=source_audio_path, file_path=wav_io)
    wav_io.seek(0)
    return wav_io

# --- Endpoints ---
@app.get("/")
def greet_json():
    """Root endpoint: return a fixed JSON greeting (can double as a liveness check)."""
    greeting = {"Hello": "World!"}
    return greeting

@app.post("/tts/")
async def api_tts(
    lang: str = Form(...),
    text: str = Form(None),
    file: UploadFile = File(None)
):
    """Synthesise speech from text and stream it back as ``audio/wav``.

    The text comes from the ``text`` form field or from an uploaded UTF-8 text
    file. When a file is sent, its content replaces the ``text`` field.

    Raises:
        HTTPException(400): no usable text (missing, blank, or not valid UTF-8),
            or any other ValueError raised during synthesis.
        HTTPException(500): any other failure (model loading, synthesis, ...).
    """
    # NOTE(review): model loading and synthesis run synchronously inside this
    # async endpoint, blocking the event loop for their whole duration. Consider
    # offloading to a thread pool once the model caches and set_seed are safe
    # to use concurrently.
    try:
        if file:
            content = await file.read()
            # "utf-8-sig" decodes plain UTF-8 identically and also drops a
            # leading BOM (common in Windows-saved files), which would
            # otherwise reach the TTS model as part of the text.
            text = content.decode("utf-8-sig")
        # Whitespace-only input counts as missing: it is empty after preprocessing.
        if not text or not text.strip():
            raise ValueError("Aucun texte fourni (champ texte ou fichier manquant).")

        wav_io = generate_tts_audio(lang, text)
        wav_io.seek(0)

        return StreamingResponse(wav_io, media_type="audio/wav")

    except ValueError as ve:
        # Also covers UnicodeDecodeError (a ValueError subclass) for bad uploads.
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Erreur TTS : {str(e)}") from e

@app.post("/speech-to-speech/")
async def api_s2st(
    source_audio: UploadFile = File(...),
    lang: str = Form(...)
):
    """Translate uploaded speech into Lingala speech and stream it back as ``audio/wav``.

    The upload is written to a uniquely named temporary file in TEMP_DIR. Only the
    extension of the client-supplied filename is kept. The file is deleted once
    processing finishes, whether it succeeded or failed.

    Raises:
        HTTPException(500): on any failure during transcription, translation or synthesis.
    """
    import tempfile

    source_path = None
    try:
        # Never build the path from the client's filename: it may contain path
        # separators, may be None, and identical names from concurrent requests
        # would overwrite each other. Keep only a short alphanumeric extension.
        ext = os.path.splitext(os.path.basename(source_audio.filename or ""))[1]
        if not (ext[1:].isalnum() and len(ext) <= 10):
            ext = ""
        fd, source_path = tempfile.mkstemp(prefix="source_", suffix=ext, dir=TEMP_DIR)
        with os.fdopen(fd, "wb") as f:
            f.write(await source_audio.read())

        wav_io = speech_2_speech_ling(source_path, lang)
        wav_io.seek(0)

        return StreamingResponse(wav_io, media_type="audio/wav")

    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Erreur S2ST : {str(e)}") from e
    finally:
        # The generated audio is held in memory, so the uploaded copy can be
        # removed immediately. Cleanup is best-effort and must not mask the
        # real response or error.
        if source_path is not None:
            try:
                os.remove(source_path)
            except OSError:
                pass