from __future__ import absolute_import, division, print_function, unicode_literals from typing import Tuple import sys from argparse import ArgumentParser import torch import numpy as np import os import json import torch sys.path.append(os.path.join(os.path.dirname(__file__), "../src/glow_tts")) from scipy.io.wavfile import write from hifi.env import AttrDict from hifi.models import Generator from text import text_to_sequence import commons import models import utils def check_directory(dir): if not os.path.exists(dir): sys.exit("Error: {} directory does not exist".format(dir)) class TextToMel: def __init__(self, glow_model_dir, device="cuda"): self.glow_model_dir = glow_model_dir check_directory(self.glow_model_dir) self.device = device self.hps, self.glow_tts_model = self.load_glow_tts() pass def load_glow_tts(self): hps = utils.get_hparams_from_dir(self.glow_model_dir) checkpoint_path = utils.latest_checkpoint_path(self.glow_model_dir) symbols = list(hps.data.punc) + list(hps.data.chars) glow_tts_model = models.FlowGenerator( len(symbols) + getattr(hps.data, "add_blank", False), out_channels=hps.data.n_mel_channels, **hps.model ) # .to(self.device) if self.device == "cuda": glow_tts_model.to("cuda") utils.load_checkpoint(checkpoint_path, glow_tts_model) glow_tts_model.decoder.store_inverse() _ = glow_tts_model.eval() return hps, glow_tts_model def generate_mel(self, text, noise_scale=0.667, length_scale=1.0): symbols = list(self.hps.data.punc) + list(self.hps.data.chars) cleaner = self.hps.data.text_cleaners if getattr(self.hps.data, "add_blank", False): text_norm = text_to_sequence(text, symbols, cleaner) text_norm = commons.intersperse(text_norm, len(symbols)) else: # If not using "add_blank" option during training, adding spaces at the beginning and the end of utterance improves quality text = " " + text.strip() + " " text_norm = text_to_sequence(text, symbols, cleaner) sequence = np.array(text_norm)[None, :] del symbols del cleaner del text del text_norm if self.device == "cuda": x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long() x_tst_lengths = torch.tensor([x_tst.shape[1]]).cuda() else: x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).long() x_tst_lengths = torch.tensor([x_tst.shape[1]]) with torch.no_grad(): (y_gen_tst, *_), *_, (attn_gen, *_) = self.glow_tts_model( x_tst, x_tst_lengths, gen=True, noise_scale=noise_scale, length_scale=length_scale, ) del x_tst del x_tst_lengths torch.cuda.empty_cache() return y_gen_tst #return y_gen_tst.cpu().detach().numpy() class MelToWav: def __init__(self, hifi_model_dir, device="cuda"): self.hifi_model_dir = hifi_model_dir check_directory(self.hifi_model_dir) self.device = device self.h, self.hifi_gan_generator = self.load_hifi_gan() pass def load_hifi_gan(self): checkpoint_path = utils.latest_checkpoint_path(self.hifi_model_dir, regex="g_*") config_file = os.path.join(self.hifi_model_dir, "config.json") data = open(config_file).read() json_config = json.loads(data) h = AttrDict(json_config) torch.manual_seed(h.seed) generator = Generator(h).to(self.device) assert os.path.isfile(checkpoint_path) print("Loading '{}'".format(checkpoint_path)) state_dict_g = torch.load(checkpoint_path, map_location=self.device) print("Complete.") generator.load_state_dict(state_dict_g["generator"]) generator.eval() generator.remove_weight_norm() return h, generator def generate_wav(self, mel): #mel = torch.FloatTensor(mel).to(self.device) y_g_hat = self.hifi_gan_generator(mel.to(self.device)) # passing through vocoder audio = y_g_hat.squeeze() audio = audio * 32768.0 audio = audio.cpu().detach().numpy().astype("int16") del y_g_hat del mel torch.cuda.empty_cache() return audio, self.h.sampling_rate if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("-m", "--model", required=True, type=str) parser.add_argument("-g", "--gan", required=True, type=str) parser.add_argument("-d", "--device", type=str, default="cpu") parser.add_argument("-t", "--text", type=str, required=True) parser.add_argument("-w", "--wav", type=str, required=True) args = parser.parse_args() text_to_mel = TextToMel(glow_model_dir=args.model, device=args.device) mel_to_wav = MelToWav(hifi_model_dir=args.gan, device=args.device) mel = text_to_mel.generate_mel(args.text) audio, sr = mel_to_wav.generate_wav(mel) write(filename=args.wav, rate=sr, data=audio) pass