# caash_api/app.py: Somali text-to-speech API built on SpeechT5 (FastAPI)
import os
import re
import uuid
import torch
import torchaudio
import soundfile as sf
from fastapi import FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from speechbrain.inference.speaker import EncoderClassifier
app = FastAPI()
device = "cuda" if torch.cuda.is_available() else "cpu"
CACHE_DIR = "/tmp/hf-cache"
# Load the shared SpeechT5 processor, the HiFi-GAN vocoder, and the two fine-tuned Somali voice models
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts", cache_dir=CACHE_DIR)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan", cache_dir=CACHE_DIR).to(device)
model_male = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/5aad", cache_dir=CACHE_DIR).to(device)
model_female = SpeechT5ForTextToSpeech.from_pretrained("Somalitts/8aad", cache_dir=CACHE_DIR).to(device)
# Speaker encoder
speaker_model = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-xvect-voxceleb",
    run_opts={"device": device},
    savedir="/tmp/spk_model",
)
# Speaker embeddings: computed once from a reference recording, then cached on disk.
def get_embedding(wav_path, pt_path):
    if os.path.exists(pt_path):
        return torch.load(pt_path).to(device)
    audio, sr = torchaudio.load(wav_path)
    # Resample to 16 kHz and downmix to mono, keeping a batch dimension for encode_batch.
    audio = torchaudio.functional.resample(audio, sr, 16000).mean(dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        emb = speaker_model.encode_batch(audio)
        emb = torch.nn.functional.normalize(emb, dim=2).squeeze()  # -> (512,)
    torch.save(emb.cpu(), pt_path)
    return emb
# Reference recordings for the two voices; the wav files must be available in the working directory.
embedding_male = get_embedding("Hussein.wav", "/tmp/male_embedding.pt")
embedding_female = get_embedding("caasho.wav", "/tmp/female_embedding.pt")
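# Sanity check: spkrec-xvect-voxceleb x-vectors are 512-dimensional, which is
# the speaker-embedding size the SpeechT5 TTS models expect.
assert embedding_male.shape == (512,) and embedding_female.shape == (512,)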
# Text normalization: spell digits out as Somali words, then strip punctuation.
number_words = {
    0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
    6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
    20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
    60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
    100: "boqol", 1000: "kun",
}
def number_to_words(n):
    # Note: 11-19 have no entries in number_words, so they fall back to digits.
    if n < 20:
        return number_words.get(n, str(n))
    elif n < 100:
        tens, unit = divmod(n, 10)
        return number_words[tens * 10] + (" " + number_words[unit] if unit else "")
    elif n < 1000:
        hundreds, rem = divmod(n, 100)
        return (number_words[hundreds] + " boqol" if hundreds > 1 else "boqol") + (" " + number_to_words(rem) if rem else "")
    elif n < 1_000_000:
        th, rem = divmod(n, 1000)
        return number_to_words(th) + " kun" + (" " + number_to_words(rem) if rem else "")
    else:
        return str(n)
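# Illustrative traces of number_to_words as defined above:
#   number_to_words(7)    -> "todobo"
#   number_to_words(45)   -> "afartan shan"
#   number_to_words(145)  -> "boqol afartan shan"
#   number_to_words(2000) -> "labo kun"
#   number_to_words(15)   -> "15"  (11-19 have no entries, so the digits pass through)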
def replace_numbers_with_words(text):
    return re.sub(r'\b\d+\b', lambda m: number_to_words(int(m.group())), text)

def normalize_text(text):
    text = text.lower()
    text = replace_numbers_with_words(text)
    text = re.sub(r'[^\w\s]', '', text)  # strip punctuation; keep word characters and whitespace
    return text
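# Example (illustrative):
#   normalize_text("Waxaan arkay 45 geel!") -> "waxaan arkay afartan shan geel"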
# API request schema
class TTSRequest(BaseModel):
    text: str
    voice: str  # "Male" or "Female" (case-insensitive; anything other than "male" falls back to the female voice)
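# Example payload (illustrative): {"text": "Waxaan arkay 45 geel", "voice": "Male"}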
# Endpoint: normalize the text, pick the requested voice, synthesize, and return the wav file.
@app.post("/speak")
def speak(payload: TTSRequest):
    clean_text = normalize_text(payload.text)
    inputs = processor(text=clean_text, return_tensors="pt").to(device)
    is_male = payload.voice.lower() == "male"
    model = model_male if is_male else model_female
    embedding = embedding_male if is_male else embedding_female
    with torch.no_grad():
        waveform = model.generate_speech(inputs["input_ids"], embedding.unsqueeze(0), vocoder=vocoder)
    # SpeechT5 outputs 16 kHz audio; write it to a uniquely named temp file.
    out_path = f"/tmp/{uuid.uuid4().hex}.wav"
    sf.write(out_path, waveform.cpu().numpy(), 16000)
    return FileResponse(out_path, media_type="audio/wav", filename="voice.wav")
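
# Local run sketch (assumptions: uvicorn is installed, and port 7860 is used,
# as is conventional for Hugging Face Spaces; adjust for your deployment):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request once the server is up:
#   curl -X POST http://localhost:7860/speak \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Waxaan maqlay 145 qof", "voice": "Female"}' \
#        --output voice.wav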