from dataclasses import asdict
from pathlib import Path

import torch
import torchaudio
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

from text import symbols
from text import cleaned_text_to_sequence
from text.english import english_to_ipa2
from utils.audio import LogMelSpectrogram
from config import ModelConfig, VocosConfig, MelConfig
from models.model import StableTTS
from vocos_pytorch.models.model import Vocos
from datas.dataset import intersperse

device = 'cpu'

@torch.inference_mode()
def inference(text: str, ref_audio: str, checkpoint_path: str, step: int = 10):
    # reload the TTS checkpoint only when the selected checkpoint changes
    global last_checkpoint_path
    if checkpoint_path != last_checkpoint_path:
        tts_model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        last_checkpoint_path = checkpoint_path

    phonemizer = english_to_ipa2

    # prepare input for the tts model: phonemize, convert to token ids, intersperse with blanks
    x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
    x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)

    # load the reference audio and extract its mel spectrogram as the speaker prompt
    waveform, sr = torchaudio.load(ref_audio)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    y = mel_extractor(waveform).to(device)

    # inference
    mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']
    audio = vocoder(mel)

    # process output for gradio
    audio_output = (sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16))  # (sample_rate, int16 audio) for gr.Audio
    mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy())  # matplotlib figure of the mel spectrogram
    return audio_output, mel_output

def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path, vocoder_checkpoint_path):
    tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
    mel_extractor = LogMelSpectrogram(mel_config)
    vocoder = Vocos(vocoder_config, mel_config)
    # the TTS checkpoint is loaded lazily in inference() when a checkpoint is selected
    # tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))
    tts_model.to(device)
    tts_model.eval()
    vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
    vocoder.to(device)
    vocoder.eval()
    return tts_model, mel_extractor, vocoder

def plot_mel_spectrogram(mel_spectrogram):
    fig, ax = plt.subplots(figsize=(20, 8))
    ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
    plt.axis('off')
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)  # remove white edges
    return fig

def main():
    tts_model_config = ModelConfig()
    mel_config = MelConfig()
    vocoder_config = VocosConfig()

    tts_checkpoint_path = './checkpoints'  # the folder that contains StableTTS checkpoints
    vocoder_checkpoint_path = './checkpoints/vocoder.pt'

    global tts_model, mel_extractor, vocoder, sample_rate, last_checkpoint_path
    sample_rate = mel_config.sample_rate
    last_checkpoint_path = None
    tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)

    # collect selectable TTS checkpoints (skip optimizer and vocoder checkpoints) and reference audios
    tts_checkpoints = [str(path) for path in Path(tts_checkpoint_path).rglob('*.pt') if 'optimizer' not in path.name and 'vocoder' not in path.name]
    audios = [str(path) for path in list(Path('./audios').rglob('*.wav')) + list(Path('./audios').rglob('*.flac'))]

    # gradio webui
    gui_title = 'StableTTS'
    gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""

    with gr.Blocks(analytics_enabled=False) as demo:
        with gr.Row():
            with gr.Column():
                gr.Markdown(f"# {gui_title}")
                gr.Markdown(gui_description)

        with gr.Row():
            with gr.Column():
                input_text_gr = gr.Textbox(
                    label="Input Text",
                    info="One or two sentences at a time is better. Up to 200 text characters.",
                    value="Today I want to tell you three stories from my life. That's it. No big deal. Just three stories.",
                )
                ref_audio_gr = gr.Dropdown(
                    label='reference audio',
                    choices=audios,
                    value=audios[0] if audios else None,
                )
                checkpoint_gr = gr.Dropdown(
                    label='checkpoint',
                    choices=tts_checkpoints,
                    value=tts_checkpoints[0] if tts_checkpoints else None,
                )
                step_gr = gr.Slider(
                    label='Step',
                    minimum=1,
                    maximum=40,
                    value=8,
                    step=1,
                )

                tts_button = gr.Button("Send", elem_id="send-btn", visible=True)

            with gr.Column():
                mel_gr = gr.Plot(label="Mel Visual")
                audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)

        tts_button.click(inference, [input_text_gr, ref_audio_gr, checkpoint_gr, step_gr], outputs=[audio_gr, mel_gr])

    demo.queue()
    demo.launch(debug=True, show_api=True)

if __name__ == '__main__':
    main()