FSub / gradio_app.py
nhantrungsp's picture
Update gradio_app.py
4a8f5a5 verified
import spaces # <--- BẮT BUỘC DÒNG 1
import os
import time
import threading
import pickle
import hashlib
import tempfile
import numpy as np
# Các thư viện khác
import torch
import soundfile as sf
from pydub import AudioSegment
import gradio as gr
from vieneu_tts import VieNeuTTS
print("⏳ Đang khởi động Server Gradio...")
# --- 1. QUẢN LÝ MODEL (Lazy Loading) ---
tts_model = None
model_lock = threading.Lock()
def get_tts_model():
"""Chỉ tải model khi có người dùng gọi (Tiết kiệm tài nguyên khởi động)"""
global tts_model
with model_lock:
if tts_model is None:
print("📦 Đang khởi tạo model lần đầu (Lazy Load)...")
# ZeroGPU yêu cầu khởi tạo model trên CPU hoặc trong hàm @spaces.GPU
# Ở đây ta khởi tạo trên CPU cho an toàn
tts_model = VieNeuTTS(
backbone_repo="pnnbao-ump/VieNeu-TTS",
backbone_device="cpu",
codec_repo="neuphonic/neucodec",
codec_device="cpu"
)
print("✅ Model tải thành công!")
return tts_model
# --- 2. XỬ LÝ CACHE ---
CACHE_DIR = "./reference_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
reference_cache = {}
reference_cache_lock = threading.Lock()
def get_cache_path(cache_key):
key_hash = hashlib.md5(cache_key.encode()).hexdigest()
return os.path.join(CACHE_DIR, f"{key_hash}.pkl")
def load_cache_from_disk(cache_key):
cache_path = get_cache_path(cache_key)
if os.path.exists(cache_path):
try:
with open(cache_path, 'rb') as f: return pickle.load(f)
except: return None
return None
def save_cache_to_disk(cache_key, ref_codes):
cache_path = get_cache_path(cache_key)
try:
with open(cache_path, 'wb') as f: pickle.dump(ref_codes, f)
except Exception: pass
# --- 3. DỮ LIỆU GIỌNG NÓI ---
VOICE_SAMPLES = {
"Tuyên (nam miền Bắc)": {"audio": "./sample/Tuyên (nam miền Bắc).wav", "text": "./sample/Tuyên (nam miền Bắc).txt"},
"Vĩnh (nam miền Nam)": {"audio": "./sample/Vĩnh (nam miền Nam).wav", "text": "./sample/Vĩnh (nam miền Nam).txt"},
"Bình (nam miền Bắc)": {"audio": "./sample/Bình (nam miền Bắc).wav", "text": "./sample/Bình (nam miền Bắc).txt"},
"Nguyên (nam miền Nam)": {"audio": "./sample/Nguyên (nam miền Nam).wav", "text": "./sample/Nguyên (nam miền Nam).txt"},
"Sơn (nam miền Nam)": {"audio": "./sample/Sơn (nam miền Nam).wav", "text": "./sample/Sơn (nam miền Nam).txt"},
"Đoan (nữ miền Nam)": {"audio": "./sample/Đoan (nữ miền Nam).wav", "text": "./sample/Đoan (nữ miền Nam).txt"},
"Ngọc (nữ miền Bắc)": {"audio": "./sample/Ngọc (nữ miền Bắc).wav", "text": "./sample/Ngọc (nữ miền Bắc).txt"},
"Ly (nữ miền Bắc)": {"audio": "./sample/Ly (nữ miền Bắc).wav", "text": "./sample/Ly (nữ miền Bắc).txt"},
"Dung (nữ miền Nam)": {"audio": "./sample/Dung (nữ miền Nam).wav", "text": "./sample/Dung (nữ miền Nam).txt"},
"Nhỏ Ngọt Ngào": {"audio": "./sample/Nhỏ Ngọt Ngào.wav", "text": "./sample/Nhỏ Ngọt Ngào.txt"},
}
# --- 4. HÀM XỬ LÝ CHÍNH (GPU) ---
@spaces.GPU(duration=120)
def generate_speech(text, voice_choice, speed_factor):
"""
Hàm này sẽ được ZeroGPU cấp phát GPU khi chạy.
Nó cũng đóng vai trò là API endpoint chính.
"""
start_time = time.time()
# 1. Lấy Model (Tải nếu chưa có)
tts = get_tts_model()
# 2. Chuyển Model sang GPU (Chỉ làm trong hàm này)
if torch.cuda.is_available():
try:
if next(tts.backbone.parameters()).device.type != 'cuda':
tts.backbone.to("cuda")
tts.codec.to("cuda")
except: pass
# 3. Lấy thông tin giọng
voice_info = VOICE_SAMPLES.get(voice_choice)
if not voice_info:
# Fallback nếu không tìm thấy giọng
voice_choice = "Tuyên (nam miền Bắc)"
voice_info = VOICE_SAMPLES[voice_choice]
ref_audio_path = voice_info["audio"]
ref_text_path = voice_info["text"]
with open(ref_text_path, "r", encoding="utf-8") as f:
ref_text_raw = f.read()
# 4. Encode Reference (Có Cache)
cache_key = f"preset:{voice_choice}"
with reference_cache_lock:
if cache_key in reference_cache:
ref_codes = reference_cache[cache_key]
if isinstance(ref_codes, torch.Tensor) and torch.cuda.is_available():
ref_codes = ref_codes.to("cuda")
else:
ref_codes = load_cache_from_disk(cache_key)
if ref_codes is None:
# Encode
ref_codes = tts.encode_reference(ref_audio_path)
save_cache_to_disk(cache_key, ref_codes.cpu() if isinstance(ref_codes, torch.Tensor) else ref_codes)
if isinstance(ref_codes, torch.Tensor) and torch.cuda.is_available():
ref_codes = ref_codes.to("cuda")
reference_cache[cache_key] = ref_codes
# 5. Infer (Tạo giọng nói)
wav = tts.infer(text, ref_codes, ref_text_raw)
# 6. Xử lý tốc độ (Speed)
if speed_factor != 1.0:
with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
sf.write(tmp.name, wav, 24000)
tmp_path = tmp.name
sound = AudioSegment.from_wav(tmp_path)
new_frame_rate = int(sound.frame_rate * speed_factor)
sound_stretched = sound._spawn(sound.raw_data, overrides={'frame_rate': new_frame_rate})
sound_stretched = sound_stretched.set_frame_rate(24000)
wav = np.array(sound_stretched.get_array_of_samples()).astype(np.float32) / 32768.0
if sound_stretched.channels == 2:
wav = wav.reshape((-1, 2)).mean(axis=1)
os.unlink(tmp_path)
# 7. Lưu file kết quả
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
sf.write(tmp_file.name, wav, 24000)
output_path = tmp_file.name
return output_path, f"✅ Hoàn tất ({time.time() - start_time:.2f}s)"
# --- 5. GIAO DIỆN GRADIO ---
theme = gr.themes.Soft()
css = ".container { max-width: 900px; margin: auto; }"
with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo:
gr.Markdown("# 🎙️ VieNeu-TTS (ZeroGPU)")
with gr.Row():
with gr.Column():
inp_text = gr.Textbox(label="Văn bản", lines=3, value="Xin chào Việt Nam, đây là thử nghiệm giọng nói.")
inp_voice = gr.Dropdown(list(VOICE_SAMPLES.keys()), value="Tuyên (nam miền Bắc)", label="Chọn giọng")
inp_speed = gr.Slider(0.5, 2.0, value=1.0, label="Tốc độ")
btn = gr.Button("Đọc ngay", variant="primary")
with gr.Column():
out_audio = gr.Audio(label="Kết quả", autoplay=True)
out_status = gr.Textbox(label="Trạng thái")
# Map function vào button
btn.click(generate_speech, [inp_text, inp_voice, inp_speed], [out_audio, out_status])
# --- 6. KHỞI CHẠY ---
if __name__ == "__main__":
# Dùng demo.launch() chuẩn để ZeroGPU nhận diện được
demo.queue(default_concurrency_limit=40).launch(server_name="0.0.0.0", server_port=7860)