from dataclasses import asdict
from pathlib import Path

import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr

from config import ModelConfig, VocosConfig, MelConfig
from models.model import StableTTS
from vocos_pytorch.models.model import Vocos
from utils.audio import LogMelSpectrogram
from text import symbols, cleaned_text_to_sequence
from text.mandarin import chinese_to_cnm3
from datas.dataset import intersperse

device = 'cuda' if torch.cuda.is_available() else 'cpu'

@torch.inference_mode()
def inference(text: str, ref_audio: str, checkpoint_path: str, step: int = 10) -> tuple:
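    """Synthesise speech for `text`, conditioned on the timbre of `ref_audio`."""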
global last_checkpoint_path
    # hot-swap the TTS weights when a different checkpoint is selected
    if checkpoint_path != last_checkpoint_path:
        tts_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        last_checkpoint_path = checkpoint_path
phonemizer = chinese_to_cnm3
# prepare input for tts model
x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)
waveform, sr = torchaudio.load(ref_audio)
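    # assumption: the model conditions on a mono reference; average the channels
    # so a multi-channel file is not interpreted as a batch of waveforms
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)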
if sr != sample_rate:
waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
y = mel_extractor(waveform).to(device)
# inference
mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']
audio = vocoder(mel)
# process output for gradio
audio_output = (sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16)) # (samplerate, int16 audio) for gr.Audio
mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy()) # get the plot of mel
return audio_output, mel_output

def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path, vocoder_checkpoint_path):
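    """Instantiate the TTS model, mel extractor and vocoder.

    Only the vocoder weights are loaded here; TTS checkpoints are loaded
    lazily in inference() when a checkpoint is selected in the UI.
    """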
tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
mel_extractor = LogMelSpectrogram(mel_config)
vocoder = Vocos(vocoder_config, mel_config)
    # TTS weights are not loaded here: inference() loads them lazily from the selected checkpoint
tts_model.to(device)
tts_model.eval()
vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
vocoder.to(device)
vocoder.eval()
return tts_model, mel_extractor, vocoder

def plot_mel_spectrogram(mel_spectrogram):
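    """Render a mel spectrogram as a borderless matplotlib figure for gr.Plot."""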
fig, ax = plt.subplots(figsize=(20, 8))
ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
    ax.axis('off')
fig.subplots_adjust(left=0, right=1, top=1, bottom=0) # remove white edges
return fig

def main():
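    """Build the synthesis pipeline and launch the Gradio web UI."""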
tts_model_config = ModelConfig()
mel_config = MelConfig()
vocoder_config = VocosConfig()
tts_checkpoint_path = './checkpoints' # the folder that contains StableTTS checkpoints
vocoder_checkpoint_path = './checkpoints/vocoder.pt'
global tts_model, mel_extractor, vocoder, sample_rate, last_checkpoint_path
sample_rate = mel_config.sample_rate
last_checkpoint_path = None
tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)
    # collect selectable TTS checkpoints, skipping optimizer states and the vocoder weights
    tts_checkpoint_paths = [str(path) for path in Path(tts_checkpoint_path).rglob('*.pt') if 'optimizer' not in path.name and 'vocoder' not in path.name]
    audios = [str(path) for path in Path('./audios').rglob('*.wav')]
    # gradio webui
gui_title = 'StableTTS'
gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
with gr.Blocks(analytics_enabled=False) as demo:
with gr.Row():
with gr.Column():
gr.Markdown(f"# {gui_title}")
gr.Markdown(gui_description)
with gr.Row():
with gr.Column():
input_text_gr = gr.Textbox(
label="Input Text",
info="One or two sentences at a time is better. Up to 200 text characters.",
value="在绝望的深渊里为你照亮前路,你害怕阳光,那我就化作繁星,永远陪伴你。",
)
                ref_audio_gr = gr.Dropdown(
                    label='Reference Audio',
                    choices=audios,
                    value=audios[0] if audios else None
                )
                checkpoint_gr = gr.Dropdown(
                    label='Checkpoint',
                    choices=tts_checkpoint_paths,
                    value=tts_checkpoint_paths[0] if tts_checkpoint_paths else None
                )
step_gr = gr.Slider(
label='Step',
minimum=1,
maximum=40,
value=8,
step=1
)
tts_button = gr.Button("Send", elem_id="send-btn", visible=True)
with gr.Column():
                mel_gr = gr.Plot(label="Mel Spectrogram")
audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
tts_button.click(inference, [input_text_gr, ref_audio_gr, checkpoint_gr, step_gr], outputs=[audio_gr, mel_gr])
demo.queue()
demo.launch(debug=True, show_api=True)

if __name__ == '__main__':
main()