Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python3 | |
# app.py — ESpeech-TTS с поддержкой ZeroGPU (Hugging Face Spaces) | |
# ----------------- ZeroGPU / spaces импорт + fallback ----------------- | |
# В среде ZeroGPU доступен пакет `spaces`, который предоставляет декоратор GPU. | |
# Для локальной отладки мы делаем fallback — noop-декоратор. | |
import spaces # provided by Spaces/ZeroGPU environment | |
GPU_DECORATOR = spaces.GPU | |
print("spaces module available — ZeroGPU features enabled") | |
import os | |
import gc | |
import json | |
import tempfile | |
import traceback | |
from pathlib import Path | |
import gradio as gr | |
import numpy as np | |
import soundfile as sf | |
import torch | |
import torchaudio | |
from huggingface_hub import hf_hub_download | |
# Ваши зависимости / локальные импорты | |
from ruaccent import RUAccent | |
import onnx_asr | |
from f5_tts.infer.utils_infer import ( | |
infer_process, | |
load_model, | |
load_vocoder, | |
preprocess_ref_audio_text, | |
remove_silence_for_generated_wav, | |
save_spectrogram, | |
tempfile_kwargs, | |
) | |
from f5_tts.model import DiT | |
# Явно включаем ленивый режим кеширования примеров, чтобы примеры не запускались на старте | |
# (ZeroGPU по умолчанию использует lazy — делаем это явным). | |
os.environ.setdefault("GRADIO_CACHE_MODE", "lazy") | |
os.environ.setdefault("GRADIO_CACHE_EXAMPLES", "lazy") | |
# ----------------- HF hub / модели ----------------- | |
# Настройте репозитории и имена файлов в Hub под себя | |
MODEL_REPOS = { | |
"ESpeech-TTS-1 [RL] V2": { | |
"repo_id": "ESpeech/ESpeech-TTS-1_RL-V2", | |
"filename": "espeech_tts_rlv2.pt", | |
}, | |
"ESpeech-TTS-1 [RL] V1": { | |
"repo_id": "ESpeech/ESpeech-TTS-1_RL-V1", | |
"filename": "espeech_tts_rlv1.pt", | |
}, | |
"ESpeech-TTS-1 [SFT] 95K": { | |
"repo_id": "ESpeech/ESpeech-TTS-1_SFT-95K", | |
"filename": "espeech_tts_95k.pt", | |
}, | |
"ESpeech-TTS-1 [SFT] 265K": { | |
"repo_id": "ESpeech/ESpeech-TTS-1_SFT-256K", | |
"filename": "espeech_tts_256k.pt", | |
}, | |
"ESpeech-TTS-1 PODCASTER [SFT]": { | |
"repo_id": "ESpeech/ESpeech-TTS-1_podcaster", | |
"filename": "espeech_tts_podcaster.pt", | |
}, | |
} | |
# где лежит общий vocab в Hub | |
VOCAB_REPO = "ESpeech/ESpeech-TTS-1_podcaster" | |
VOCAB_FILENAME = "vocab.txt" | |
# токен, если репозитории приватные (в Spaces обычно берут из Secrets) | |
HF_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or None | |
MODEL_CFG = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) | |
# кэш локальных путей после hf_hub_download | |
_cached_local_paths = {} | |
loaded_models = {} # хранит объекты моделей в памяти (по имени выбора) | |
# Пример текста для демонстрации | |
EXAMPLE_TEXT = "Экспериментальный центр напоминает вам о том, что кубы не умеют разговаривать. В случае, если грузовой куб все же заговорит, центр настоятельно рекомендует вам игнорировать его советы." | |
EXAMPLE_REF_AUDIO = "ref/example.mp3" | |
# ----------------- Вспомогательные функции HF ----------------- | |
def hf_download_file(repo_id: str, filename: str, token: str = None): | |
try: | |
print(f"hf_hub_download: {repo_id}/{filename}") | |
p = hf_hub_download(repo_id=repo_id, filename=filename, token=token, repo_type="model") | |
print(" ->", p) | |
return p | |
except Exception as e: | |
print("Download error:", e) | |
raise | |
def get_vocab_path(): | |
key = f"{VOCAB_REPO}::{VOCAB_FILENAME}" | |
if key in _cached_local_paths and Path(_cached_local_paths[key]).exists(): | |
return _cached_local_paths[key] | |
p = hf_download_file(VOCAB_REPO, VOCAB_FILENAME, token=HF_TOKEN) | |
_cached_local_paths[key] = p | |
return p | |
def get_model_local_path(choice: str): | |
if choice not in MODEL_REPOS: | |
raise KeyError("Unknown model choice: " + repr(choice)) | |
repo = MODEL_REPOS[choice] | |
key = f"{repo['repo_id']}::{repo['filename']}" | |
if key in _cached_local_paths and Path(_cached_local_paths[key]).exists(): | |
return _cached_local_paths[key] | |
p = hf_download_file(repo["repo_id"], repo["filename"], token=HF_TOKEN) | |
_cached_local_paths[key] = p | |
return p | |
def load_model_if_needed(choice: str): | |
""" | |
Лениво: если модель уже загружена в loaded_models — вернуть. | |
Иначе скачать файл (если нужно) и вызвать вашу load_model (возвращает PyTorch модель в CPU). | |
Не переводим на GPU здесь — это делается внутри GPU-декорированной функции. | |
""" | |
if choice in loaded_models: | |
return loaded_models[choice] | |
model_file = get_model_local_path(choice) | |
vocab_file = get_vocab_path() | |
print(f"Loading model into CPU memory: {choice} from {model_file}") | |
model = load_model(DiT, MODEL_CFG, model_file, vocab_file=vocab_file) | |
loaded_models[choice] = model | |
return model | |
# ----------------- общие ресурсы (vocoder, RUAccent, ASR) ----------------- | |
print("Loading RUAccent...") | |
accentizer = RUAccent() | |
accentizer.load(omograph_model_size='turbo3.1', use_dictionary=True, tiny_mode=False) | |
print("RUAccent loaded.") | |
print("Loading ASR (onnx) ...") | |
asr_model = onnx_asr.load_model("nemo-fastconformer-ru-rnnt") | |
print("ASR ready.") | |
print("Loading vocoder (CPU) ...") | |
vocoder = load_vocoder() | |
print("Vocoder loaded.") | |
# ----------------- Функция для обработки текста с учетом "+" ----------------- | |
def process_text_with_accent(text, accentizer): | |
""" | |
Обрабатывает текст через RUAccent, если в нем нет символа '+'. | |
Если есть '+' - пользователь сам проставил ударения, не трогаем. | |
""" | |
if not text or not text.strip(): | |
return text | |
if '+' in text: | |
# Пользователь сам проставил ударения | |
return text | |
else: | |
# Прогоняем через RUAccent | |
return accentizer.process_all(text) | |
# ----------------- Функция для обработки текста без синтеза ----------------- | |
def process_texts_only(ref_text, gen_text): | |
""" | |
Обрабатывает только тексты через RUAccent, не делая синтез. | |
Возвращает обработанные тексты для обновления полей ввода. | |
""" | |
processed_ref_text = process_text_with_accent(ref_text, accentizer) | |
processed_gen_text = process_text_with_accent(gen_text, accentizer) | |
return processed_ref_text, processed_gen_text | |
# ----------------- Основная функция синтеза (GPU-aware) ----------------- | |
# Декорируем synthesize, чтобы при вызове Space выделял GPU (если доступно). | |
# duration — сколько секунд просим GPU (адаптируйте под ваш инференс). | |
def synthesize( | |
model_choice, | |
ref_audio, | |
ref_text, | |
gen_text, | |
remove_silence, | |
seed, | |
cross_fade_duration=0.15, | |
nfe_step=32, | |
speed=1.0, | |
): | |
""" | |
Эта функция будет выполняться с выделенным GPU в ZeroGPU Spaces. | |
Подход: | |
- лениво загружаем модель (в CPU) если надо | |
- переносим модель и (если требуется) vocoder на cuda | |
- делаем infer | |
- возвращаем модели на CPU и очищаем cuda cache | |
""" | |
if not ref_audio: | |
gr.Warning("Please provide reference audio.") | |
return None, None, ref_text, gen_text | |
if seed is None or seed < 0 or seed > 2**31 - 1: | |
seed = np.random.randint(0, 2**31 - 1) | |
torch.manual_seed(int(seed)) | |
if not gen_text or not gen_text.strip(): | |
gr.Warning("Please enter text to generate.") | |
return None, None, ref_text, gen_text | |
# ASR если нужно | |
if not ref_text or not ref_text.strip(): | |
gr.Info("Reference text is empty. Running ASR to transcribe reference audio...") | |
try: | |
waveform, sample_rate = torchaudio.load(ref_audio) | |
waveform = waveform.numpy() | |
if waveform.dtype == np.int16: | |
waveform = waveform / 2**15 | |
elif waveform.dtype == np.int32: | |
waveform = waveform / 2**31 | |
if waveform.ndim == 2: | |
waveform = waveform.mean(axis=0) | |
transcribed_text = asr_model.recognize(waveform, sample_rate=sample_rate) | |
ref_text = transcribed_text | |
gr.Info(f"ASR transcription: {ref_text}") | |
except Exception as e: | |
gr.Warning(f"ASR failed: {e}") | |
return None, None, ref_text, gen_text | |
# Акцентирование с учетом наличия символа "+" | |
processed_ref_text = process_text_with_accent(ref_text, accentizer) | |
processed_gen_text = process_text_with_accent(gen_text, accentizer) | |
# Ленивая загрузка модели (в CPU) | |
try: | |
model = load_model_if_needed(model_choice) | |
except Exception as e: | |
gr.Warning(f"Failed to download/load model {model_choice}: {e}") | |
return None, None, processed_ref_text, processed_gen_text | |
# Определяем устройство (в ZeroGPU внутри декоратора должен быть доступен CUDA) | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
moved_to_cuda = [] | |
try: | |
# Переносим модель на GPU (если есть) | |
if device.type == "cuda": | |
try: | |
model.to(device) | |
moved_to_cuda.append(("model", model)) | |
# если vocoder использует torch — переносим его тоже | |
try: | |
vocoder.to(device) | |
moved_to_cuda.append(("vocoder", vocoder)) | |
except Exception: | |
# если vocoder не torch-объект — ок | |
pass | |
except Exception as e: | |
print("Warning: failed to move model/vocoder to cuda:", e) | |
# Препроцессинг рефа (оно ожидает путь/файл) | |
try: | |
ref_audio_proc, processed_ref_text_final = preprocess_ref_audio_text( | |
ref_audio, | |
processed_ref_text, | |
show_info=gr.Info | |
) | |
except Exception as e: | |
gr.Warning(f"Preprocess failed: {e}") | |
traceback.print_exc() | |
return None, None, processed_ref_text, processed_gen_text | |
# Инференс (предполагается, что infer_process корректно работает и на GPU) | |
try: | |
final_wave, final_sample_rate, combined_spectrogram = infer_process( | |
ref_audio_proc, | |
processed_ref_text_final, | |
processed_gen_text, | |
model, | |
vocoder, | |
cross_fade_duration=cross_fade_duration, | |
nfe_step=nfe_step, | |
speed=speed, | |
show_info=gr.Info, | |
progress=gr.Progress(), | |
) | |
except Exception as e: | |
gr.Warning(f"Infer failed: {e}") | |
traceback.print_exc() | |
return None, None, processed_ref_text, processed_gen_text | |
# Удаление тишин (на CPU) | |
if remove_silence: | |
try: | |
with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f: | |
temp_path = f.name | |
sf.write(temp_path, final_wave, final_sample_rate) | |
remove_silence_for_generated_wav(temp_path) | |
final_wave_tensor, _ = torchaudio.load(temp_path) | |
final_wave = final_wave_tensor.squeeze().cpu().numpy() | |
except Exception as e: | |
print("Remove silence failed:", e) | |
# Сохраняем спектрограмму | |
try: | |
with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram: | |
spectrogram_path = tmp_spectrogram.name | |
save_spectrogram(combined_spectrogram, spectrogram_path) | |
except Exception as e: | |
print("Save spectrogram failed:", e) | |
spectrogram_path = None | |
return (final_sample_rate, final_wave), spectrogram_path, processed_ref_text_final, processed_gen_text | |
finally: | |
# Переносим всё обратно на CPU и очищаем GPU память | |
if device.type == "cuda": | |
try: | |
for name, obj in moved_to_cuda: | |
try: | |
obj.to("cpu") | |
except Exception: | |
pass | |
torch.cuda.empty_cache() | |
# немножко сборки мусора | |
gc.collect() | |
except Exception as e: | |
print("Warning during cuda cleanup:", e) | |
# ----------------- Gradio UI (как у вас) ----------------- | |
with gr.Blocks(title="ESpeech-TTS") as app: | |
gr.Markdown("# ESpeech-TTS") | |
gr.Markdown("Подробнее см. на https://huggingface.co/ESpeech") | |
gr.Markdown("💡 **Совет:** Добавьте символ '+' в тексте, чтобы указать пользовательское ударение (например, 'прив+ет'). Текст с '+' не будет обрабатываться RUAccent.") | |
gr.Markdown("❌ **Совет:** Референс должен быть не БОЛЕЕ 12-ти секунд. Иначе модель сломается.") | |
# Описание моделей на русском языке | |
gr.Markdown(""" | |
## 📋 Описание моделей: | |
- **ESpeech-TTS-1 [RL] V1** - Первая версия модели с RL | |
- **ESpeech-TTS-1 [RL] V2** - Вторая версия модели с RL | |
- **ESpeech-TTS-1 PODCASTER [SFT]** - Модель обученная только на подкастах, лучше генерирует спонтанную речь | |
- **ESpeech-TTS-1 [SFT] 95K** - чекпоинт с 95000 шагов (на нем основана RL V1) | |
- **ESpeech-TTS-1 [SFT] 265K** - чекпоинт с 265000 шагов (на нем основана RL V2) | |
""") | |
model_choice = gr.Dropdown( | |
choices=list(MODEL_REPOS.keys()), | |
label="Select Model", | |
value=list(MODEL_REPOS.keys())[0], | |
interactive=True | |
) | |
with gr.Row(): | |
with gr.Column(): | |
ref_audio_input = gr.Audio(label="Reference Audio", type="filepath") | |
ref_text_input = gr.Textbox( | |
label="Reference Text", | |
lines=2, | |
placeholder="leave empty → ASR will transcribe" | |
) | |
with gr.Column(): | |
gen_text_input = gr.Textbox( | |
label="Text to Generate", | |
lines=5, | |
max_lines=20, | |
placeholder="Enter text to synthesize..." | |
) | |
# Кнопка для обработки текста без синтеза | |
process_text_btn = gr.Button("✏️ Process Text (Add Accents)", variant="secondary") | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Accordion("Advanced Settings", open=False): | |
seed_input = gr.Number(label="Seed (-1 for random)", value=-1, precision=0) | |
remove_silence = gr.Checkbox(label="Remove Silences", value=False) | |
speed_slider = gr.Slider(label="Speed", minimum=0.3, maximum=2.0, value=1.0, step=0.1) | |
nfe_slider = gr.Slider(label="NFE Steps", minimum=4, maximum=64, value=48, step=2) | |
cross_fade_slider = gr.Slider(label="Cross-Fade Duration (s)", minimum=0.0, maximum=1.0, value=0.15, step=0.01) | |
generate_btn = gr.Button("🎤 Generate Speech", variant="primary", size="lg") | |
with gr.Row(): | |
audio_output = gr.Audio(label="Generated Audio", type="numpy") | |
spectrogram_output = gr.Image(label="Spectrogram", type="filepath") | |
# Примеры | |
gr.Markdown("## 🎯 Example") | |
gr.Examples( | |
examples=[ | |
[ | |
EXAMPLE_REF_AUDIO, # ref_audio | |
"", # ref_text (empty for ASR) | |
EXAMPLE_TEXT, # gen_text | |
False, # remove_silence | |
42, # seed | |
0.15, # cross_fade | |
48, # nfe_step | |
1.0, # speed | |
] | |
], | |
inputs=[ | |
ref_audio_input, | |
ref_text_input, | |
gen_text_input, | |
remove_silence, | |
seed_input, | |
cross_fade_slider, | |
nfe_slider, | |
speed_slider, | |
], | |
outputs=[audio_output, spectrogram_output, ref_text_input, gen_text_input], | |
fn=lambda *args: synthesize(model_choice.value, *args), | |
cache_examples=True, | |
run_on_click=True, | |
) | |
# Обработка текста без синтеза | |
process_text_btn.click( | |
process_texts_only, | |
inputs=[ref_text_input, gen_text_input], | |
outputs=[ref_text_input, gen_text_input] | |
) | |
# Основная генерация | |
generate_btn.click( | |
synthesize, | |
inputs=[ | |
model_choice, | |
ref_audio_input, | |
ref_text_input, | |
gen_text_input, | |
remove_silence, | |
seed_input, | |
cross_fade_slider, | |
nfe_slider, | |
speed_slider, | |
], | |
outputs=[audio_output, spectrogram_output, ref_text_input, gen_text_input] | |
) | |
if __name__ == "__main__": | |
#app.launch(server_name="0.0.0.0", server_port=7860) | |
app.launch() |