import os

# Redirect the Hugging Face, matplotlib, torch and XDG caches to /tmp
# before any library import, so downloads land in a writable location.
TMP_PATH = "/tmp"
os.environ["HF_HOME"] = os.path.join(TMP_PATH, "huggingface")
os.environ["TRANSFORMERS_CACHE"] = os.environ["HF_HOME"] + "/transformers"
os.environ["HF_DATASETS_CACHE"] = os.environ["HF_HOME"] + "/datasets"
os.environ["HF_METRICS_CACHE"] = os.environ["HF_HOME"] + "/metrics"
os.environ["MPLCONFIGDIR"] = os.path.join(TMP_PATH, "matplotlib")
os.environ["TORCH_HOME"] = os.path.join(TMP_PATH, "torch")
os.environ["XDG_CACHE_HOME"] = os.path.join(TMP_PATH, "xdg-cache")
os.environ["HOME"] = TMP_PATH

# Create the cache directories if they do not already exist
for path in [
    os.environ["HF_HOME"],
    os.environ["MPLCONFIGDIR"],
    os.environ["TORCH_HOME"],
    os.environ["XDG_CACHE_HOME"],
    TMP_PATH
]:
    os.makedirs(path, exist_ok=True)
# --- Main imports ---
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, set_seed
from deep_translator import GoogleTranslator
from TTS.api import TTS
import whisper
import io
import torch
import scipy.io.wavfile
import numpy as np
import traceback
# --- Init FastAPI ---
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Config ---
device = "cuda" if torch.cuda.is_available() else "cpu"
TEMP_DIR = "/tmp"
os.makedirs(TEMP_DIR, exist_ok=True)
# --- Models loaded on demand ---
loaded_tts_models = {}
loaded_tts_lingala = None
asr_model = None
# --- Utilities ---
def load_tts_model(lang: str):
    # Lazily load and cache one MMS text-to-audio pipeline per language
    if lang not in loaded_tts_models:
        model_id = f"OlameMend/mms-tts-{lang}"
        try:
            loaded_tts_models[lang] = pipeline("text-to-audio", model=model_id)
        except Exception as e:
            raise RuntimeError(f"Failed to load TTS model '{model_id}': {e}")
    return loaded_tts_models[lang]
def load_asr_model():
    # Lazily load the Whisper "tiny" model for speech recognition
    global asr_model
    if asr_model is None:
        asr_model = whisper.load_model("tiny")
    return asr_model

def load_tts_lingala_model():
    # Lazily load the Coqui VITS Lingala voice
    global loaded_tts_lingala
    if loaded_tts_lingala is None:
        loaded_tts_lingala = TTS("tts_models/lin/openbible/vits")
    return loaded_tts_lingala
def preprocess_text_tts(text: str) -> str:
    return text.strip()

def generate_tts_audio(lang: str, text: str) -> io.BytesIO:
    # Synthesize `text` with the MMS model for `lang` and return an in-memory WAV
    lang = lang.lower()
    synthesizer = load_tts_model(lang)
    processed_text = preprocess_text_tts(text)
    set_seed(555)  # fixed seed for reproducible synthesis
    speech = synthesizer(processed_text)
    wav_io = io.BytesIO()
    scipy.io.wavfile.write(wav_io, rate=speech["sampling_rate"], data=speech["audio"][0])
    wav_io.seek(0)
    return wav_io
def speech_2_speech_ling(source_audio_path: str, lang: str) -> io.BytesIO:
    # Speech-to-speech pipeline: transcribe with Whisper, translate the text
    # to Lingala, then resynthesize it with voice conversion from the source audio
    asr = load_asr_model()
    tts_lingala = load_tts_lingala_model()
    result = asr.transcribe(source_audio_path, language=lang)
    text = result["text"]
    translated_text = GoogleTranslator(source="auto", target="ln").translate(text)
    wav_io = io.BytesIO()
    tts_lingala.tts_with_vc_to_file(text=translated_text, speaker_wav=source_audio_path, file_path=wav_io)
    wav_io.seek(0)
    return wav_io
# --- Endpoints ---
# Route paths below ("/", "/tts", "/s2st") are assumed; adjust them to match
# the actual deployment.
@app.get("/")
def greet_json():
    return {"Hello": "World!"}
@app.post("/tts")  # assumed route path
async def api_tts(
    lang: str = Form(...),
    text: str = Form(None),
    file: UploadFile = File(None)
):
    try:
        # Accept the text either as a form field or as an uploaded UTF-8 file
        if file:
            content = await file.read()
            text = content.decode("utf-8")
        if not text:
            raise ValueError("No text provided (text field or file missing).")
        wav_io = generate_tts_audio(lang, text)
        wav_io.seek(0)
        return StreamingResponse(wav_io, media_type="audio/wav")
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"TTS error: {str(e)}")
@app.post("/s2st")  # assumed route path
async def api_s2st(
    source_audio: UploadFile = File(...),
    lang: str = Form(...)
):
    try:
        # Persist the uploaded audio to disk so Whisper can read it by path
        source_path = os.path.join(TEMP_DIR, "source_" + source_audio.filename)
        with open(source_path, "wb") as f:
            f.write(await source_audio.read())
        wav_io = speech_2_speech_ling(source_path, lang)
        wav_io.seek(0)
        return StreamingResponse(wav_io, media_type="audio/wav")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"S2ST error: {str(e)}")
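# --- Local run / example client (illustrative sketch) ---
# A minimal way to serve this app locally, assuming uvicorn is installed and the
# assumed "/tts" route above is kept. Port 7860 is the port commonly used by
# Spaces; adjust as needed.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example client call (run from a separate process; kept as a comment so the
# server module has no side effects on import). The language code "ewo" and the
# sample text are hypothetical.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/tts",
#       data={"lang": "ewo", "text": "Mbolo"},
#   )
#   resp.raise_for_status()
#   with open("out.wav", "wb") as f:
#       f.write(resp.content)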