# OlameMend's picture
# add audio stream
# c3284a8
import os
TMP_PATH = "/tmp"

# Point the Hugging Face, matplotlib, torch and XDG cache/config locations
# (and HOME itself) at directories under TMP_PATH.
# NOTE(review): presumably /tmp is the only writable location in the
# deployment environment, and these must be set before the ML libraries
# are imported further down — confirm.
_hf_home = os.path.join(TMP_PATH, "huggingface")
os.environ.update(
    {
        "HF_HOME": _hf_home,
        "TRANSFORMERS_CACHE": _hf_home + "/transformers",
        "HF_DATASETS_CACHE": _hf_home + "/datasets",
        "HF_METRICS_CACHE": _hf_home + "/metrics",
        "MPLCONFIGDIR": os.path.join(TMP_PATH, "matplotlib"),
        "TORCH_HOME": os.path.join(TMP_PATH, "torch"),
        "XDG_CACHE_HOME": os.path.join(TMP_PATH, "xdg-cache"),
        "HOME": TMP_PATH,
    }
)

# Create the directories if they do not exist yet.
for path in (
    os.environ["HF_HOME"],
    os.environ["MPLCONFIGDIR"],
    os.environ["TORCH_HOME"],
    os.environ["XDG_CACHE_HOME"],
    TMP_PATH,
):
    os.makedirs(path, exist_ok=True)
# --- Imports principaux ---
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, set_seed
from deep_translator import GoogleTranslator
from TTS.api import TTS
import whisper
import io
import torch
import scipy.io.wavfile
import numpy as np
import traceback
# --- Init FastAPI ---
app = FastAPI()

# CORS policy: any origin, any HTTP method and any request header are
# accepted, and credentialed requests are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# very permissive — confirm this is intended for the deployed API.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# --- Config ---
# Device string: "cuda" when torch reports a usable CUDA device, else "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Directory where uploaded source audio is written before processing.
TEMP_DIR = "/tmp"
os.makedirs(TEMP_DIR, exist_ok=True)

# --- Lazily loaded models ---
# Language code -> text-to-audio pipeline, filled on demand by load_tts_model.
loaded_tts_models = {}
# Singletons created on first use by their respective loader functions.
loaded_tts_lingala = None
asr_model = None
# --- Utilitaires ---
def load_tts_model(lang: str):
    """Return the text-to-audio pipeline for ``lang``, loading it on first use.

    Loaded pipelines are stored in the module-level ``loaded_tts_models``
    dict, keyed by ``lang`` exactly as given, and reused on later calls.

    Args:
        lang: Suffix appended to ``OlameMend/mms-tts-`` to form the model id.

    Returns:
        The cached pipeline object for that model id.

    Raises:
        RuntimeError: If the pipeline cannot be created. The underlying
            exception is chained as ``__cause__``.
    """
    if lang not in loaded_tts_models:
        model_id = f"OlameMend/mms-tts-{lang}"
        try:
            loaded_tts_models[lang] = pipeline("text-to-audio", model=model_id)
        except Exception as e:
            # Chain explicitly so the root cause is reported as the direct
            # cause of the RuntimeError in tracebacks.
            raise RuntimeError(
                f"Erreur lors du chargement du modèle TTS '{model_id}': {e}"
            ) from e
    return loaded_tts_models[lang]
def load_asr_model():
    """Return the shared Whisper model, loading the "tiny" checkpoint once.

    The model is kept in the module-level ``asr_model`` variable so later
    calls return the same object without reloading.
    """
    global asr_model
    if asr_model is not None:
        return asr_model
    asr_model = whisper.load_model("tiny")
    return asr_model
def load_tts_lingala_model():
    """Return the shared ``TTS`` instance for the Lingala VITS model.

    The instance is created on the first call and kept in the module-level
    ``loaded_tts_lingala`` variable for reuse.
    """
    global loaded_tts_lingala
    if loaded_tts_lingala is not None:
        return loaded_tts_lingala
    loaded_tts_lingala = TTS("tts_models/lin/openbible/vits")
    return loaded_tts_lingala
def preprocess_text_tts(text: str) -> str:
    """Prepare text for synthesis by trimming leading and trailing whitespace."""
    trimmed = text.strip()
    return trimmed
def generate_tts_audio(lang: str, text: str) -> io.BytesIO:
    """Synthesize ``text`` with the TTS model for ``lang`` and return WAV bytes.

    Args:
        lang: Language code; lower-cased before the model lookup.
        text: Text to speak; passed through ``preprocess_text_tts`` first.

    Returns:
        An in-memory WAV file, positioned at offset 0.

    Raises:
        RuntimeError: Propagated from ``load_tts_model`` when the model
            cannot be loaded.
    """
    synthesizer = load_tts_model(lang.lower())
    processed_text = preprocess_text_tts(text)

    # Fixed seed, re-applied on every call before synthesis.
    set_seed(555)
    speech = synthesizer(processed_text)

    # Drop the leading batch axis only when there is one. Indexing a 1-D
    # waveform with [0] would yield a single sample and break the WAV write.
    # NOTE(review): whether the pipeline returns (1, n) or (n,) appears to
    # depend on the transformers version — confirm against the pinned one.
    audio = np.asarray(speech["audio"])
    if audio.ndim > 1:
        audio = audio[0]

    wav_io = io.BytesIO()
    scipy.io.wavfile.write(wav_io, rate=speech["sampling_rate"], data=audio)
    wav_io.seek(0)
    return wav_io
def speech_2_speech_ling(source_audio_path: str, lang: str) -> io.BytesIO:
    """Turn a spoken recording into synthesized speech in the target "ln".

    Steps: transcribe the file with the ASR model, translate the transcript
    with ``GoogleTranslator`` (auto-detected source, target ``"ln"``), then
    synthesize the translation with the Lingala TTS model.

    Args:
        source_audio_path: Path of the recording to transcribe; also passed
            as ``speaker_wav`` to the synthesis call.
        lang: Language hint forwarded to the ASR ``transcribe`` call.

    Returns:
        An in-memory buffer holding the synthesized audio, rewound to 0.
    """
    # Resolve both models up front, before any processing starts.
    asr_engine = load_asr_model()
    lingala_voice = load_tts_lingala_model()

    # Speech -> text.
    transcription = asr_engine.transcribe(source_audio_path, language=lang)
    spoken_text = transcription["text"]

    # Text -> translated text.
    translator = GoogleTranslator(source="auto", target="ln")
    translated_text = translator.translate(spoken_text)

    # Translated text -> audio, written straight into an in-memory buffer.
    output_buffer = io.BytesIO()
    lingala_voice.tts_with_vc_to_file(
        text=translated_text,
        speaker_wav=source_audio_path,
        file_path=output_buffer,
    )
    output_buffer.seek(0)
    return output_buffer
# --- Endpoints ---
@app.get("/")
def greet_json():
    """Root endpoint: return a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/tts/")
async def api_tts(
    lang: str = Form(...),
    text: str = Form(None),
    file: UploadFile = File(None)
):
    """Synthesize speech for the submitted text and stream it back as WAV.

    The text comes from the ``text`` form field or from an uploaded UTF-8
    file. When both are sent, the file content replaces the form field.

    Args:
        lang: Language code forwarded to ``generate_tts_audio``.
        text: Optional text to synthesize.
        file: Optional uploaded file whose content is decoded as UTF-8.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``.

    Raises:
        HTTPException: 400 for any ``ValueError`` raised while handling the
            request, including missing, blank or undecodable text
            (``UnicodeDecodeError`` is a ``ValueError`` subclass). 500 for
            any other failure.
    """
    try:
        if file:
            content = await file.read()
            text = content.decode("utf-8")
        # Whitespace-only text is stripped to an empty string before
        # synthesis, so reject it here with the same 400 as missing text.
        if not text or not text.strip():
            raise ValueError("Aucun texte fourni (champ texte ou fichier manquant).")
        wav_io = generate_tts_audio(lang, text)
        wav_io.seek(0)
        return StreamingResponse(wav_io, media_type="audio/wav")
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        # Log the full traceback server-side; the client only gets the message.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Erreur TTS : {str(e)}") from e
@app.post("/speech-to-speech/")
async def api_s2st(
    source_audio: UploadFile = File(...),
    lang: str = Form(...)
):
    """Convert an uploaded recording to synthesized speech, streamed as WAV.

    The upload is saved to a uniquely named temporary file in ``TEMP_DIR``,
    passed to ``speech_2_speech_ling``, and removed once processing ends.

    Args:
        source_audio: Uploaded recording to process.
        lang: Language hint forwarded to ``speech_2_speech_ling``.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``.

    Raises:
        HTTPException: 500 for any failure while saving or processing.
    """
    import tempfile

    source_path = None
    try:
        # Keep only the extension of the client-supplied name. The name is
        # untrusted and may be missing, so it must not shape the path.
        suffix = os.path.splitext(os.path.basename(source_audio.filename or ""))[1]
        # A unique name stops concurrent uploads overwriting each other.
        with tempfile.NamedTemporaryFile(
            mode="wb", dir=TEMP_DIR, prefix="source_", suffix=suffix, delete=False
        ) as tmp:
            source_path = tmp.name
            tmp.write(await source_audio.read())
        wav_io = speech_2_speech_ling(source_path, lang)
        wav_io.seek(0)
        return StreamingResponse(wav_io, media_type="audio/wav")
    except Exception as e:
        # Log the full traceback server-side; the client only gets the message.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Erreur S2ST : {str(e)}") from e
    finally:
        # The result lives in memory, so the source file is no longer needed.
        if source_path is not None:
            try:
                os.remove(source_path)
            except OSError:
                # Best-effort cleanup: never mask the real outcome.
                pass