Spaces:
Running
on
Zero
Running
on
Zero
| import spaces # <--- BẮT BUỘC DÒNG 1 | |
| import os | |
| import time | |
| import threading | |
| import pickle | |
| import hashlib | |
| import tempfile | |
| import numpy as np | |
| # Các thư viện khác | |
| import torch | |
| import soundfile as sf | |
| from pydub import AudioSegment | |
| import gradio as gr | |
| from vieneu_tts import VieNeuTTS | |
| print("⏳ Đang khởi động Server Gradio...") | |
| # --- 1. QUẢN LÝ MODEL (Lazy Loading) --- | |
| tts_model = None | |
| model_lock = threading.Lock() | |
| def get_tts_model(): | |
| """Chỉ tải model khi có người dùng gọi (Tiết kiệm tài nguyên khởi động)""" | |
| global tts_model | |
| with model_lock: | |
| if tts_model is None: | |
| print("📦 Đang khởi tạo model lần đầu (Lazy Load)...") | |
| # ZeroGPU yêu cầu khởi tạo model trên CPU hoặc trong hàm @spaces.GPU | |
| # Ở đây ta khởi tạo trên CPU cho an toàn | |
| tts_model = VieNeuTTS( | |
| backbone_repo="pnnbao-ump/VieNeu-TTS", | |
| backbone_device="cpu", | |
| codec_repo="neuphonic/neucodec", | |
| codec_device="cpu" | |
| ) | |
| print("✅ Model tải thành công!") | |
| return tts_model | |
| # --- 2. XỬ LÝ CACHE --- | |
| CACHE_DIR = "./reference_cache" | |
| os.makedirs(CACHE_DIR, exist_ok=True) | |
| reference_cache = {} | |
| reference_cache_lock = threading.Lock() | |
| def get_cache_path(cache_key): | |
| key_hash = hashlib.md5(cache_key.encode()).hexdigest() | |
| return os.path.join(CACHE_DIR, f"{key_hash}.pkl") | |
| def load_cache_from_disk(cache_key): | |
| cache_path = get_cache_path(cache_key) | |
| if os.path.exists(cache_path): | |
| try: | |
| with open(cache_path, 'rb') as f: return pickle.load(f) | |
| except: return None | |
| return None | |
| def save_cache_to_disk(cache_key, ref_codes): | |
| cache_path = get_cache_path(cache_key) | |
| try: | |
| with open(cache_path, 'wb') as f: pickle.dump(ref_codes, f) | |
| except Exception: pass | |
| # --- 3. DỮ LIỆU GIỌNG NÓI --- | |
| VOICE_SAMPLES = { | |
| "Tuyên (nam miền Bắc)": {"audio": "./sample/Tuyên (nam miền Bắc).wav", "text": "./sample/Tuyên (nam miền Bắc).txt"}, | |
| "Vĩnh (nam miền Nam)": {"audio": "./sample/Vĩnh (nam miền Nam).wav", "text": "./sample/Vĩnh (nam miền Nam).txt"}, | |
| "Bình (nam miền Bắc)": {"audio": "./sample/Bình (nam miền Bắc).wav", "text": "./sample/Bình (nam miền Bắc).txt"}, | |
| "Nguyên (nam miền Nam)": {"audio": "./sample/Nguyên (nam miền Nam).wav", "text": "./sample/Nguyên (nam miền Nam).txt"}, | |
| "Sơn (nam miền Nam)": {"audio": "./sample/Sơn (nam miền Nam).wav", "text": "./sample/Sơn (nam miền Nam).txt"}, | |
| "Đoan (nữ miền Nam)": {"audio": "./sample/Đoan (nữ miền Nam).wav", "text": "./sample/Đoan (nữ miền Nam).txt"}, | |
| "Ngọc (nữ miền Bắc)": {"audio": "./sample/Ngọc (nữ miền Bắc).wav", "text": "./sample/Ngọc (nữ miền Bắc).txt"}, | |
| "Ly (nữ miền Bắc)": {"audio": "./sample/Ly (nữ miền Bắc).wav", "text": "./sample/Ly (nữ miền Bắc).txt"}, | |
| "Dung (nữ miền Nam)": {"audio": "./sample/Dung (nữ miền Nam).wav", "text": "./sample/Dung (nữ miền Nam).txt"}, | |
| "Nhỏ Ngọt Ngào": {"audio": "./sample/Nhỏ Ngọt Ngào.wav", "text": "./sample/Nhỏ Ngọt Ngào.txt"}, | |
| } | |
| # --- 4. HÀM XỬ LÝ CHÍNH (GPU) --- | |
| def generate_speech(text, voice_choice, speed_factor): | |
| """ | |
| Hàm này sẽ được ZeroGPU cấp phát GPU khi chạy. | |
| Nó cũng đóng vai trò là API endpoint chính. | |
| """ | |
| start_time = time.time() | |
| # 1. Lấy Model (Tải nếu chưa có) | |
| tts = get_tts_model() | |
| # 2. Chuyển Model sang GPU (Chỉ làm trong hàm này) | |
| if torch.cuda.is_available(): | |
| try: | |
| if next(tts.backbone.parameters()).device.type != 'cuda': | |
| tts.backbone.to("cuda") | |
| tts.codec.to("cuda") | |
| except: pass | |
| # 3. Lấy thông tin giọng | |
| voice_info = VOICE_SAMPLES.get(voice_choice) | |
| if not voice_info: | |
| # Fallback nếu không tìm thấy giọng | |
| voice_choice = "Tuyên (nam miền Bắc)" | |
| voice_info = VOICE_SAMPLES[voice_choice] | |
| ref_audio_path = voice_info["audio"] | |
| ref_text_path = voice_info["text"] | |
| with open(ref_text_path, "r", encoding="utf-8") as f: | |
| ref_text_raw = f.read() | |
| # 4. Encode Reference (Có Cache) | |
| cache_key = f"preset:{voice_choice}" | |
| with reference_cache_lock: | |
| if cache_key in reference_cache: | |
| ref_codes = reference_cache[cache_key] | |
| if isinstance(ref_codes, torch.Tensor) and torch.cuda.is_available(): | |
| ref_codes = ref_codes.to("cuda") | |
| else: | |
| ref_codes = load_cache_from_disk(cache_key) | |
| if ref_codes is None: | |
| # Encode | |
| ref_codes = tts.encode_reference(ref_audio_path) | |
| save_cache_to_disk(cache_key, ref_codes.cpu() if isinstance(ref_codes, torch.Tensor) else ref_codes) | |
| if isinstance(ref_codes, torch.Tensor) and torch.cuda.is_available(): | |
| ref_codes = ref_codes.to("cuda") | |
| reference_cache[cache_key] = ref_codes | |
| # 5. Infer (Tạo giọng nói) | |
| wav = tts.infer(text, ref_codes, ref_text_raw) | |
| # 6. Xử lý tốc độ (Speed) | |
| if speed_factor != 1.0: | |
| with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp: | |
| sf.write(tmp.name, wav, 24000) | |
| tmp_path = tmp.name | |
| sound = AudioSegment.from_wav(tmp_path) | |
| new_frame_rate = int(sound.frame_rate * speed_factor) | |
| sound_stretched = sound._spawn(sound.raw_data, overrides={'frame_rate': new_frame_rate}) | |
| sound_stretched = sound_stretched.set_frame_rate(24000) | |
| wav = np.array(sound_stretched.get_array_of_samples()).astype(np.float32) / 32768.0 | |
| if sound_stretched.channels == 2: | |
| wav = wav.reshape((-1, 2)).mean(axis=1) | |
| os.unlink(tmp_path) | |
| # 7. Lưu file kết quả | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file: | |
| sf.write(tmp_file.name, wav, 24000) | |
| output_path = tmp_file.name | |
| return output_path, f"✅ Hoàn tất ({time.time() - start_time:.2f}s)" | |
| # --- 5. GIAO DIỆN GRADIO --- | |
| theme = gr.themes.Soft() | |
| css = ".container { max-width: 900px; margin: auto; }" | |
| with gr.Blocks(theme=theme, css=css, title="VieNeu-TTS") as demo: | |
| gr.Markdown("# 🎙️ VieNeu-TTS (ZeroGPU)") | |
| with gr.Row(): | |
| with gr.Column(): | |
| inp_text = gr.Textbox(label="Văn bản", lines=3, value="Xin chào Việt Nam, đây là thử nghiệm giọng nói.") | |
| inp_voice = gr.Dropdown(list(VOICE_SAMPLES.keys()), value="Tuyên (nam miền Bắc)", label="Chọn giọng") | |
| inp_speed = gr.Slider(0.5, 2.0, value=1.0, label="Tốc độ") | |
| btn = gr.Button("Đọc ngay", variant="primary") | |
| with gr.Column(): | |
| out_audio = gr.Audio(label="Kết quả", autoplay=True) | |
| out_status = gr.Textbox(label="Trạng thái") | |
| # Map function vào button | |
| btn.click(generate_speech, [inp_text, inp_voice, inp_speed], [out_audio, out_status]) | |
| # --- 6. KHỞI CHẠY --- | |
| if __name__ == "__main__": | |
| # Dùng demo.launch() chuẩn để ZeroGPU nhận diện được | |
| demo.queue(default_concurrency_limit=40).launch(server_name="0.0.0.0", server_port=7860) |