TTS / app.py
randusertry's picture
Update app.py
c5fe5fc verified
from fastapi import FastAPI, Response, HTTPException
from fastapi.responses import StreamingResponse
import numpy as np
from piper import PiperVoice
import sherpa_onnx
import base64
import io
import os
import httpx
import wave
from pydantic import BaseModel
from typing import Optional, Literal
# Directory that holds the model files. The path is relative to the process
# working directory and is created at import time if it does not exist yet.
MODEL_DIR = "./models"
os.makedirs(MODEL_DIR, exist_ok=True)

# ASGI application object that the route decorators in this module attach to.
app = FastAPI(title="TTS App for my projects")
def _voice_pair(*, male: str, female: str) -> dict:
    """Build the catalogue entry for a language with male and female models."""
    return {"gendered": True, "male": male, "female": female}


def _single_voice(model: str) -> dict:
    """Build the catalogue entry for a language served by one default model."""
    return {"gendered": False, "default": model}


# Language code -> voice entry. Every value names a model; the ``gendered``
# flag tells the caller whether to pick ``male``/``female`` or ``default``.
VOICE_MAP = {
    # Languages with separate male and female models.
    "en": _voice_pair(male="en_GB-alan-medium", female="en_GB-semaine-medium"),
    "es": _voice_pair(male="es_ES-sharvard-medium", female="es_ES-davefx-medium"),
    "fr": _voice_pair(male="fr_FR-upmc-medium", female="fr_FR-siwis-medium"),
    "de": _voice_pair(male="de_DE-thorsten-medium", female="de_DE-kerstin-low"),
    "it": _voice_pair(male="it_IT-riccardo-x_low", female="it_IT-paola-medium"),
    "pl": _voice_pair(male="pl_PL-darkman-medium", female="pl_PL-gosia-medium"),
    "uk": _voice_pair(male="uk_UA-ukrainian_tts-medium", female="uk_UA-lada-x_low"),
    "nl": _voice_pair(male="nl_NL-ronnie-medium", female="nl_NL-mls-medium"),
    "eu": _voice_pair(male="eu_ES-antton-medium", female="eu_ES-maider-medium"),
    # Languages with a single model, used regardless of the requested gender.
    "bg": _single_voice("bg_BG-dimitar-medium"),
    "ca": _single_voice("ca_ES-upc_ona-medium"),
    "cs": _single_voice("cs_CZ-jirka-medium"),
    "da": _single_voice("da_DK-talesyntese-medium"),
    "fi": _single_voice("fi_FI-harri-medium"),
    "el": _single_voice("el_GR-rapunzelina-low"),
    "hu": _single_voice("hu_HU-anna-medium"),
    "is": _single_voice("is_IS-ugla-medium"),
    "lv": _single_voice("lv_LV-aivars-medium"),
    "ro": _single_voice("ro_RO-mihai-medium"),
    "sk": _single_voice("sk_SK-lili-medium"),
    "sl": _single_voice("sl_SI-artur-medium"),
    "sv": _single_voice("sv_SE-lisa-medium"),
    "cy": _single_voice("cy_GB-gwryw_gogleddol-medium"),
}
# Irish dialect name -> voice entry. Entries flagged ``gendered`` offer a
# ``male`` and a ``female`` voice identifier; the others expose a single
# ``default`` identifier.
IRISH_MAP = {
    "Donegal": {
        "gendered": True,
        "male": "ga_UL_doc_piper",
        "female": "ga_UL_anb_piper",
    },
    "Kerry": {
        "gendered": True,
        "male": "ga_MU_cmg_piper",
        "female": "ga_MU_nnc_piper",
    },
    "Ring": {
        "gendered": False,
        "default": "ga_MU_ar_fnm_piper",
    },
    "Connemara": {
        "gendered": False,
        "default": "ga_CO_snc_piper",
    },
}
# Voices that have already been loaded, keyed by model name, so each model is
# read from disk at most once per process.
loaded_voices = {}


def get_voice(model_name: str):
    """Return the Piper voice for ``model_name``, loading it on first use.

    The model is expected at ``<MODEL_DIR>/<model_name>.onnx`` with its config
    next to it at ``<MODEL_DIR>/<model_name>.onnx.json``.

    Raises:
        FileNotFoundError: if the ``.onnx`` file does not exist.
    """
    if model_name in loaded_voices:
        return loaded_voices[model_name]

    onnx_file = os.path.join(MODEL_DIR, f"{model_name}.onnx")
    if not os.path.exists(onnx_file):
        raise FileNotFoundError(f"Model {model_name} not found.")

    voice = PiperVoice.load(onnx_file, f"{onnx_file}.json")
    loaded_voices[model_name] = voice
    return voice
class TTSRequest(BaseModel):
    """Request body accepted by the synthesis endpoints."""

    # Text to be spoken.
    text: str

    # Language code; the Piper endpoint lower-cases it before the lookup.
    language: str

    # Requested voice gender; falls back to "male" when omitted.
    gender: Literal["male", "female"] = "male"

    # Irish dialect; only the Irish endpoint reads it.
    dialect: Optional[Literal["Kerry", "Donegal", "Ring", "Connemara"]] = None
@app.post("/tts/piper")
async def tts_post(request: TTSRequest):
    """Synthesise ``request.text`` with a Piper voice and return a WAV file.

    The language code is lower-cased and looked up in ``VOICE_MAP``. For a
    gendered entry ``request.gender`` selects the model (falling back to the
    male one); otherwise the entry's default model is used.

    Returns:
        A ``Response`` with media type ``audio/wav`` containing mono 16-bit PCM
        at the voice's configured sample rate.

    Raises:
        HTTPException: 400 when the language is not in ``VOICE_MAP``; 500 when
            the model cannot be loaded or synthesis fails.
    """
    # Validate the request outside the try block so the 400 below reaches the
    # client as a 400 instead of being rewrapped as a 500 by the catch-all.
    lang_code = request.language.lower()
    lang_entry = VOICE_MAP.get(lang_code)
    if not lang_entry:
        raise HTTPException(status_code=400, detail=f"Language '{lang_code}' not supported.")

    # Determine model name
    if lang_entry["gendered"]:
        model_name = lang_entry.get(request.gender.lower(), lang_entry["male"])
    else:
        model_name = lang_entry["default"]

    try:
        voice = get_voice(model_name)

        # Assemble the WAV file in memory.
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)  # 16-bit audio
            wav_file.setframerate(voice.config.sample_rate)
            for chunk in voice.synthesize(request.text):
                # Clip before scaling: a float sample outside [-1, 1] would
                # otherwise wrap around when cast to int16.
                clipped = np.clip(chunk.audio_float_array, -1.0, 1.0)
                audio_int16 = (clipped * 32767).astype("int16")
                wav_file.writeframes(audio_int16.tobytes())
    except Exception as e:
        print(f"Error during TTS: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return Response(content=wav_buffer.getvalue(), media_type="audio/wav")
ABAIR_URL = "https://synthesis.abair.ie/api/synthesise"


@app.post("/tts/irish")
async def get_irish_tts(request: TTSRequest):
    """Fetch Irish speech from the ABAIR synthesis API and return it as WAV.

    The voice is chosen from ``IRISH_MAP`` using ``request.dialect`` (default
    "Donegal") and, for gendered dialects, ``request.gender``.
    ``request.language`` is not consulted by this endpoint.

    Returns:
        A ``Response`` with media type ``audio/wav`` holding the decoded
        ``audioContent`` field of the ABAIR reply.

    Raises:
        HTTPException: 502 when ABAIR answers with a non-200 status; 503 when
            the request to ABAIR fails; 500 when the reply lacks
            ``audioContent`` or any other error occurs.
    """
    dialect = request.dialect or "Donegal"

    # 1. Determine the correct voice string
    entry = IRISH_MAP.get(dialect, IRISH_MAP["Donegal"])
    if entry.get("gendered"):
        voice = entry.get(request.gender.lower(), entry["male"])
    else:
        voice = entry["default"]

    # 2. Build the upstream request
    params = {
        "input": request.text,
        "voice": voice,
        "normalise": "true",
        "speed": 0.9
    }
    # Browser-like headers; presumably required by the upstream service -
    # TODO confirm.
    headers = {
        "Origin": "https://abair.ie",
        "Referer": "https://abair.ie/",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        "Accept": "*/*"
    }

    async with httpx.AsyncClient() as client:
        try:
            # The endpoint is queried with GET and query-string parameters.
            response = await client.get(ABAIR_URL, params=params, headers=headers, timeout=15.0)
            if response.status_code != 200:
                print(f"ABAIR Error: {response.status_code} - {response.text}")
                raise HTTPException(status_code=502, detail=f"ABAIR service error: {response.status_code}")

            data = response.json()

            # 3. Handle Base64 decoding
            if not isinstance(data, dict) or "audioContent" not in data:
                raise HTTPException(status_code=500, detail="Invalid response format from ABAIR")
            audio_bytes = base64.b64decode(data["audioContent"])

            # 4. Return the decoded WAV binary
            return Response(content=audio_bytes, media_type="audio/wav")
        except HTTPException:
            # The 502/500 raised above are deliberate; let them through
            # unchanged rather than rewrapping them in the catch-all below.
            raise
        except httpx.RequestError as exc:
            raise HTTPException(status_code=503, detail=f"Could not connect to ABAIR: {exc}") from exc
        except Exception as e:
            print(f"Internal Error: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e
# Lazily created sherpa-onnx engine shared by all Breton requests.
breton_engine = None


def get_breton_engine():
    """Return the shared Breton TTS engine, building it on the first call.

    The engine is a ``sherpa_onnx.OfflineTts`` configured with the VITS model
    ``breton-model.onnx`` and token file ``breton-tokens.txt`` from
    ``MODEL_DIR``, running on the CPU provider with one thread.
    """
    global breton_engine
    if breton_engine is not None:
        return breton_engine

    # The configuration is nested: VITS settings sit inside the model config,
    # which in turn sits inside the top-level TTS config.
    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=sherpa_onnx.OfflineTtsModelConfig(
            vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                model=os.path.join(MODEL_DIR, "breton-model.onnx"),
                tokens=os.path.join(MODEL_DIR, "breton-tokens.txt"),
                data_dir="",
                noise_scale=0.667,
                noise_scale_w=0.8,
                length_scale=1.0,
            ),
            num_threads=1,
            debug=False,
            provider="cpu",
        ),
        # No rule FSTs are supplied for this model.
        rule_fsts="",
        max_num_sentences=1,
    )
    breton_engine = sherpa_onnx.OfflineTts(tts_config)
    return breton_engine
@app.post("/tts/breton")
async def get_breton_tts(request: TTSRequest):
    """Synthesise ``request.text`` with the Breton engine and return a WAV file.

    Only ``request.text`` and ``request.gender`` are read; the gender picks the
    speaker id passed to the engine.

    Returns:
        A ``Response`` with media type ``audio/wav`` containing mono 16-bit PCM
        at the sample rate reported by the engine.

    Raises:
        HTTPException: 500 when the engine cannot be created or synthesis
            fails.
    """
    try:
        engine = get_breton_engine()
        # Speaker id 0 is used for "female" and 1 otherwise; presumably this
        # matches the speaker order of the model - TODO confirm.
        sid = 0 if request.gender.lower() == "female" else 1

        audio = engine.generate(request.text, sid=sid)

        # Clip before scaling: a float sample outside [-1, 1] would otherwise
        # wrap around when cast to int16.
        samples = np.clip(np.asarray(audio.samples, dtype=np.float32), -1.0, 1.0)
        audio_int16 = (samples * 32767).astype("int16")

        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)  # 16-bit
            wav_file.setframerate(audio.sample_rate)
            wav_file.writeframes(audio_int16.tobytes())
    except Exception as e:
        print(f"Breton TTS Error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return Response(content=wav_buffer.getvalue(), media_type="audio/wav")
@app.get("/health")
def health():
    """Liveness probe: return a fixed status payload.

    Named ``health`` rather than ``home`` so it no longer shares a function
    name with the handler registered for ``/``.
    """
    return {"status": "Piper TTS is running"}
@app.get("/")
def home():
    """Describe the service: status message, model folder contents, languages.

    If the model directory cannot be listed, the error text is reported in
    place of the file list instead of failing the request.
    """
    try:
        model_files = os.listdir(MODEL_DIR)
    except Exception as err:
        model_files = [f"Error reading directory: {str(err)}"]

    # Every language in VOICE_MAP, followed by "ga" and "br".
    languages = list(VOICE_MAP)
    languages.extend(("ga", "br"))

    return {
        "message": "Piper TTS API is running",
        "models_in_folder": model_files,
        "supported_languages": languages,
    }