import gradio as gr import argparse import json import datetime as dt import numpy as np from scipy.io.wavfile import write from utils import plot_tensor, save_plot import torch import params from model import GradTTSVitE2E from text import text_to_sequence, cmudict from text.symbols import symbols from utils import intersperse from data import TextMelSpeakerMelRefDataset,VCTKMelRefDataset import sys sys.path.append('./hifi-gan/') from env import AttrDict from models import Generator as HiFiGAN HIFIGAN_CONFIG = './checkpts/hifigan-config.json' HIFIGAN_CHECKPT = './checkpts/g.pt' # generator_VCTK generator_v1 g train_filelist_path = params.train_filelist_path valid_filelist_path = params.valid_filelist_path cmudict_path = params.cmudict_path add_blank = params.add_blank n_spks = params.n_spks spk_emb_dim = params.spk_emb_dim log_dir = params.log_dir n_epochs = params.n_epochs batch_size = params.batch_size out_size = params.out_size learning_rate = params.learning_rate random_seed = params.seed nsymbols = len(symbols) + 1 if add_blank else len(symbols) n_enc_channels = params.n_enc_channels filter_channels = params.filter_channels filter_channels_dp = params.filter_channels_dp n_enc_layers = params.n_enc_layers enc_kernel = params.enc_kernel enc_dropout = params.enc_dropout n_heads = params.n_heads window_size = params.window_size n_feats = params.n_feats n_fft = params.n_fft sample_rate = params.sample_rate hop_length = params.hop_length win_length = params.win_length f_min = params.f_min f_max = params.f_max dec_dim = params.dec_dim beta_min = params.beta_min beta_max = params.beta_max pe_scale = params.pe_scale print('Initializing Grad-TTS-Ref...') generator = GradTTSVitE2E(len(symbols)+1, params.n_spks, params.spk_emb_dim, params.n_enc_channels, params.filter_channels, params.filter_channels_dp, params.n_heads, params.n_enc_layers, params.enc_kernel, params.enc_dropout, params.window_size, params.n_feats, params.dec_dim, params.beta_min, params.beta_max, params.pe_scale) generator.load_state_dict(torch.load('logs_best_ckpt/tts_ref_in_context_VCTK_set1/grad_75.pt', map_location=lambda loc, storage: loc)) #_ = generator.cuda().eval() _ = generator.eval() print(f'Number of parameters: {generator.nparams}') print('Initializing HiFi-GAN...') with open(HIFIGAN_CONFIG) as f: h = AttrDict(json.load(f)) vocoder = HiFiGAN(h) vocoder.load_state_dict(torch.load(HIFIGAN_CHECKPT, map_location=lambda loc, storage: loc)['generator']) #_ = vocoder.cuda().eval() _ = vocoder.eval() vocoder.remove_weight_norm() cmu = cmudict.CMUDict('./resources/cmu_dictionary') test_dataset = TextMelSpeakerMelRefDataset(valid_filelist_path, cmudict_path, add_blank, n_fft, n_feats, sample_rate, hop_length, win_length, f_min, f_max) def spec_rr (mu_y): for i in range (0,2): mask_length = np.random.randint(0, mu_y.shape[2] // 3 + 1, size=1)[0] start_idx = np.random.randint(0, mu_y.shape[2]-mask_length, size=1)[0] end_idx = start_idx + mask_length + 1 mu_y = torch.cat((mu_y[:,:,:start_idx] , mu_y[:,:,end_idx:] ), dim=-1) return mu_y def tts (text,y_ref_) : with torch.no_grad(): y_ref_ = test_dataset.get_mel_web(y_ref_) #print(f'Synthesizing {i} text...', end=' ') x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols))).cuda()[None] x_lengths = torch.LongTensor([x.shape[-1]]).cuda() y_ref_lengths_ = y_ref_.shape[-1] y_ref = torch.zeros((1, n_feats, y_ref_lengths_), dtype=torch.float32) y_ref_lengths = torch.zeros((1), dtype=torch.int32) y_ref[0, :, :y_ref_lengths_] = y_ref_ ''' rep_times = 600 // y_ref_lengths_ y_ref1 = y_ref for _ in range (rep_times): y_ref = torch.cat([y_ref,y_ref1],dim=2) y_ref = spec_rr(y_ref) print(y_ref_lengths_, y_ref.shape[-1]) ''' t = dt.datetime.now() y_enc, y_dec, attn = generator.forward(x, x_lengths, 50, y_ref, torch.LongTensor( y_ref.shape[-1]) , temperature=1.5, stoc=False, spk=None, length_scale=0.91) t = (dt.datetime.now() - t).total_seconds() print(f'Grad-TTS RTF: {t * 22050 / (y_dec.shape[-1] * 256)}') audio = (vocoder.forward(y_dec).cpu().squeeze().numpy() * 32768).astype(np.int16) write('out.wav', 22050, audio) return 'out.wav' def tts_stack (text,y_ref1,y_ref2,y_ref3) : with torch.no_grad(): if y_ref1 == None: y_ref1 = torch.zeros((1, n_feats, 1), dtype=torch.float32) else: y_ref1 = test_dataset.get_mel_web(y_ref1) if y_ref2 == None: y_ref2 = torch.zeros((1, n_feats, 1), dtype=torch.float32) else: y_ref2 = test_dataset.get_mel_web(y_ref2) if y_ref3 == None: y_ref3 = torch.zeros((1, n_feats, 1), dtype=torch.float32) else: y_ref3 = test_dataset.get_mel_web(y_ref3) #print(f'Synthesizing {i} text...', end=' ') x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=cmu), len(symbols)))[None] x_lengths = torch.LongTensor([x.shape[-1]]) #print(y_ref1.shape,y_ref2.shape,y_ref3.shape) y_ref = torch.zeros((1, n_feats, y_ref1.shape[-1]+y_ref2.shape[-1]+y_ref3.shape[-1]), dtype=torch.float32) y_ref[0, :, : y_ref1.shape[-1]] = y_ref1 y_ref[0, :, y_ref1.shape[-1] : (y_ref1.shape[-1] + y_ref2.shape[-1])] = y_ref2 y_ref[0, :, (y_ref1.shape[-1] + y_ref2.shape[-1]) : (y_ref1.shape[-1] + y_ref2.shape[-1] + y_ref3.shape[-1])] = y_ref3 ''' rep_times = 600 // y_ref_lengths_ y_ref1 = y_ref for _ in range (rep_times): y_ref = torch.cat([y_ref,y_ref1],dim=2) y_ref = spec_rr(y_ref) print(y_ref_lengths_, y_ref.shape[-1]) ''' y_ref = spec_rr(y_ref) t = dt.datetime.now() y_enc, y_dec, attn = generator.forward(x, x_lengths, 50, y_ref, torch.LongTensor( y_ref.shape[-1]) , temperature=1.5, stoc=False, spk=None, length_scale=0.91) t = (dt.datetime.now() - t).total_seconds() print(f'Grad-TTS RTF: {t * 22050 / (y_dec.shape[-1] * 256)}') audio = (vocoder.forward(y_dec).cpu().squeeze().numpy() * 32768).astype(np.int16) write('out.wav', 22050, audio) return 'out.wav' demo_play = gr.Interface(fn = tts_stack, inputs = [gr.Textbox(max_lines=6, label="Input Text", value="Massey Voice is a zero shot speech synthesis model with low-resource settings, which can generate high-quality audio", info="Up to 200 characters"), gr.Audio(type='filepath' ), gr.Audio(type='filepath' ), gr.Audio(type='filepath' )], outputs = 'audio', title = 'MasseyVoice', description = '''
MasseyVoice