from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os
import uuid
import torch
import torchaudio
import base64
from io import BytesIO
from transformers import AutoModelForCausalLM
import sys
import subprocess
from datetime import datetime, timedelta
app = FastAPI(title="Nigerian TTS API")
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, set this to your Next.js domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize necessary directories
os.makedirs("audio_files", exist_ok=True)
os.makedirs("models", exist_ok=True)
# Check if YarnGPT is installed; install it if missing
try:
    import yarngpt
    from yarngpt.audiotokenizer import AudioTokenizerV2
except ImportError:
    print("Installing YarnGPT and dependencies...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/saheedniyi02/yarngpt.git"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "outetts", "uroman", "transformers", "torchaudio"])
    from yarngpt.audiotokenizer import AudioTokenizerV2
# Model configuration
tokenizer_path = "saheedniyi/YarnGPT2"
# Check if the model files exist; download them if missing
wav_tokenizer_config_path = "./models/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
wav_tokenizer_model_path = "./models/wavtokenizer_large_speech_320_24k.ckpt"
if not os.path.exists(wav_tokenizer_config_path):
print("Downloading model config file...")
subprocess.check_call([
"wget", "-O", wav_tokenizer_config_path,
"https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml"
])
if not os.path.exists(wav_tokenizer_model_path):
print("Downloading model checkpoint file...")
subprocess.check_call([
"wget", "-O", wav_tokenizer_model_path,
"https://drive.google.com/uc?id=1-ASeEkrn4HY49yZWHTASgfGFNXdVnLTt&export=download"
])
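# Alternative fetch for the Google Drive checkpoint: a minimal sketch, assuming
# the optional gdown package is available (it is not imported or installed
# above); gdown handles Drive's large-file confirmation step, which plain wget
# may not.
#
#   import gdown
#   gdown.download(id="1-ASeEkrn4HY49yZWHTASgfGFNXdVnLTt",
#                  output=wav_tokenizer_model_path, quiet=False)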
print("Loading YarnGPT model and tokenizer...")
audio_tokenizer = AudioTokenizerV2(
    tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path
)
model = AutoModelForCausalLM.from_pretrained(tokenizer_path, torch_dtype="auto").to(audio_tokenizer.device)
print("Model loaded successfully!")
# Available voices and languages
AVAILABLE_VOICES = {
"female": ["zainab", "idera", "regina", "chinenye", "joke", "remi"],
"male": ["jude", "tayo", "umar", "osagie", "onye", "emma"]
}
AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"]
# Input validation model
class TTSRequest(BaseModel):
    text: str
    language: str = "english"
    voice: str = "idera"
# Output model with base64-encoded audio
class TTSResponse(BaseModel):
    audio_base64: str  # Base64-encoded audio data
    audio_url: str  # Keep for backward compatibility
    text: str
    voice: str
    language: str
@app.get("/")
async def root():
"""API health check and info"""
return {
"status": "ok",
"message": "Nigerian TTS API is running",
"available_languages": AVAILABLE_LANGUAGES,
"available_voices": AVAILABLE_VOICES
}
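# Quick check of this endpoint: a minimal sketch, assuming the API runs at
# http://localhost:8000 and the requests package is installed (neither is part
# of this file).
#
#   import requests
#   print(requests.get("http://localhost:8000/").json())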
@app.post("/tts", response_model=TTSResponse)
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks):
"""Convert text to Nigerian-accented speech"""
# Validate inputs
if request.language not in AVAILABLE_LANGUAGES:
raise HTTPException(status_code=400, detail=f"Language must be one of {AVAILABLE_LANGUAGES}")
all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"]
if request.voice not in all_voices:
raise HTTPException(status_code=400, detail=f"Voice must be one of {all_voices}")
# Generate unique filename
audio_id = str(uuid.uuid4())
output_path = f"audio_files/{audio_id}.wav"
try:
# Create prompt and generate audio
prompt = audio_tokenizer.create_prompt(request.text, lang=request.language, speaker_name=request.voice)
input_ids = audio_tokenizer.tokenize_prompt(prompt)
output = model.generate(
input_ids=input_ids,
temperature=0.1,
repetition_penalty=1.1,
max_length=4000,
)
codes = audio_tokenizer.get_codes(output)
audio = audio_tokenizer.get_audio(codes)
# Save audio file
torchaudio.save(output_path, audio, sample_rate=24000)
# Read the file and encode as base64
with open(output_path, "rb") as audio_file:
audio_bytes = audio_file.read()
audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
# Clean up old files after a while
background_tasks.add_task(cleanup_old_files)
return TTSResponse(
audio_base64=audio_base64,
audio_url=f"/audio/{audio_id}.wav", # Keep for compatibility
text=request.text,
voice=request.voice,
language=request.language
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}")
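# Example /tts call: a minimal sketch, assuming the API runs at
# http://localhost:8000 and the requests package is installed (neither is part
# of this file). The response carries the WAV bytes base64-encoded.
#
#   import base64, requests
#   resp = requests.post(
#       "http://localhost:8000/tts",
#       json={"text": "Good morning, how are you?", "language": "english", "voice": "idera"},
#   )
#   resp.raise_for_status()
#   with open("sample.wav", "wb") as f:
#       f.write(base64.b64decode(resp.json()["audio_base64"]))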
# File serving endpoint for direct audio access
@app.get("/audio/{filename}")
async def get_audio(filename: str):
file_path = f"audio_files/{filename}"
if not os.path.exists(file_path):
raise HTTPException(status_code=404, detail="Audio file not found")
def iterfile():
with open(file_path, "rb") as audio_file:
yield from audio_file
return StreamingResponse(iterfile(), media_type="audio/wav")
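# Example download via the audio_url returned by /tts: a sketch under the same
# localhost/requests assumptions as above.
#
#   import requests
#   tts = requests.post(
#       "http://localhost:8000/tts",
#       json={"text": "Kedu", "language": "igbo", "voice": "chinenye"},
#   ).json()
#   wav = requests.get("http://localhost:8000" + tts["audio_url"])
#   with open("from_url.wav", "wb") as f:
#       f.write(wav.content)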
# Endpoint to stream generated audio directly (useful for debugging)
@app.post("/stream-audio")
async def stream_audio(request: TTSRequest):
"""Stream audio directly without saving to disk"""
try:
# Create prompt and generate audio
prompt = audio_tokenizer.create_prompt(request.text, lang=request.language, speaker_name=request.voice)
input_ids = audio_tokenizer.tokenize_prompt(prompt)
output = model.generate(
input_ids=input_ids,
temperature=0.1,
repetition_penalty=1.1,
max_length=4000,
)
codes = audio_tokenizer.get_codes(output)
audio = audio_tokenizer.get_audio(codes)
# Create BytesIO object
buffer = BytesIO()
torchaudio.save(buffer, audio, sample_rate=24000, format="wav")
buffer.seek(0)
return StreamingResponse(buffer, media_type="audio/wav")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}")
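# Example /stream-audio call: a sketch under the same localhost/requests
# assumptions; the response body is the raw WAV stream, so it can be written
# to disk as-is.
#
#   import requests
#   with requests.post(
#       "http://localhost:8000/stream-audio",
#       json={"text": "Bawo ni?", "language": "yoruba", "voice": "tayo"},
#       stream=True,
#   ) as r:
#       r.raise_for_status()
#       with open("streamed.wav", "wb") as f:
#           for chunk in r.iter_content(chunk_size=8192):
#               f.write(chunk)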
# Cleanup function to remove old files
def cleanup_old_files():
"""Delete audio files older than 6 hours to manage disk space"""
try:
now = datetime.now()
audio_dir = "audio_files"
for filename in os.listdir(audio_dir):
if not filename.endswith(".wav"):
continue
file_path = os.path.join(audio_dir, filename)
file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path))
# Delete files older than 6 hours
if now - file_mod_time > timedelta(hours=6):
os.remove(file_path)
print(f"Deleted old audio file: {filename}")
except Exception as e:
print(f"Error cleaning up old files: {e}")
# For running locally with uvicorn
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
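# Alternatively, start the server with the uvicorn CLI (assuming this file is
# saved as app.py; the module name is an assumption, not fixed by this file):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000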