Spaces:
Runtime error
Runtime error
from __future__ import absolute_import, division, print_function, unicode_literals | |
from typing import Tuple | |
import sys | |
from argparse import ArgumentParser | |
import torch | |
import numpy as np | |
import os | |
import json | |
import torch | |
sys.path.append(os.path.join(os.path.dirname(__file__), "../src/glow_tts")) | |
from scipy.io.wavfile import write | |
from hifi.env import AttrDict | |
from hifi.models import Generator | |
from text import text_to_sequence | |
import commons | |
import models | |
import utils | |
def check_directory(dir):
    """Exit the program with an error message unless ``dir`` is an existing directory.

    Args:
        dir: path to check. (The name shadows the ``dir`` builtin but is kept
            so that keyword callers keep working.)

    Raises:
        SystemExit: if ``dir`` does not exist or exists but is not a directory.
    """
    if not os.path.exists(dir):
        sys.exit("Error: {} directory does not exist".format(dir))
    # A regular file passed where a model directory is expected would otherwise
    # slip through here and fail later with a far less helpful error.
    if not os.path.isdir(dir):
        sys.exit("Error: {} is not a directory".format(dir))
class TextToMel:
    """Glow-TTS front end: turns input text into a mel-spectrogram tensor.

    Hyper-parameters are read with ``utils.get_hparams_from_dir``, and weights
    come from the checkpoint chosen by ``utils.latest_checkpoint_path``. Both
    are looked up in ``glow_model_dir``.
    """

    def __init__(self, glow_model_dir, device="cuda"):
        """Load the Glow-TTS model from ``glow_model_dir``.

        Args:
            glow_model_dir: model directory; the process exits (via
                ``check_directory``) if it is missing.
            device: torch device string the model and its inputs are placed
                on, e.g. ``"cuda"``, ``"cuda:1"`` or ``"cpu"``.
        """
        self.glow_model_dir = glow_model_dir
        check_directory(self.glow_model_dir)
        self.device = device
        self.hps, self.glow_tts_model = self.load_glow_tts()

    @staticmethod
    def _symbols(hps):
        """Return the symbol inventory: punctuation symbols followed by characters."""
        return list(hps.data.punc) + list(hps.data.chars)

    def load_glow_tts(self):
        """Build the FlowGenerator, load its checkpoint and switch it to eval mode.

        Returns:
            ``(hps, model)``: the hyper-parameter object and the loaded model,
            placed on ``self.device``.
        """
        hps = utils.get_hparams_from_dir(self.glow_model_dir)
        checkpoint_path = utils.latest_checkpoint_path(self.glow_model_dir)
        symbols = self._symbols(hps)
        # With "add_blank" an extra token id (len(symbols)) is interspersed into
        # the input sequence (see generate_mel), so the vocabulary grows by one
        # (True counts as 1).
        n_vocab = len(symbols) + getattr(hps.data, "add_blank", False)
        glow_tts_model = models.FlowGenerator(
            n_vocab,
            out_channels=hps.data.n_mel_channels,
            **hps.model
        )
        # Move to whatever device was requested ("cuda", "cuda:N", "cpu", ...);
        # previously only the exact string "cuda" was honoured.
        glow_tts_model.to(self.device)
        utils.load_checkpoint(checkpoint_path, glow_tts_model)
        glow_tts_model.decoder.store_inverse()
        glow_tts_model.eval()
        return hps, glow_tts_model

    def generate_mel(self, text, noise_scale=0.667, length_scale=1.0):
        """Synthesize a mel-spectrogram for ``text``.

        Args:
            text: input string, encoded with ``text_to_sequence`` using the
                cleaners from ``hps.data.text_cleaners``.
            noise_scale: forwarded to the Glow-TTS model's generation call.
            length_scale: forwarded to the Glow-TTS model's generation call.

        Returns:
            The first element of the model's first output (a torch tensor,
            presumably on ``self.device``). The attention output is discarded.
        """
        symbols = self._symbols(self.hps)
        cleaner = self.hps.data.text_cleaners
        if getattr(self.hps.data, "add_blank", False):
            text_norm = text_to_sequence(text, symbols, cleaner)
            text_norm = commons.intersperse(text_norm, len(symbols))
        else:
            # If "add_blank" was not used during training, padding the utterance
            # with a space at each end improves quality.
            text_norm = text_to_sequence(" " + text.strip() + " ", symbols, cleaner)

        # A batch of one token-id sequence: shape (1, seq_len), dtype int64.
        x_tst = torch.tensor(text_norm, dtype=torch.long, device=self.device)[None, :]
        x_tst_lengths = torch.tensor([x_tst.shape[1]], device=self.device)
        with torch.no_grad():
            (y_gen_tst, *_), *_, (_attn_gen, *_) = self.glow_tts_model(
                x_tst,
                x_tst_lengths,
                gen=True,
                noise_scale=noise_scale,
                length_scale=length_scale,
            )
        # Drop the input tensors before releasing the CUDA cache so their
        # memory can actually be returned.
        del x_tst, x_tst_lengths
        torch.cuda.empty_cache()
        return y_gen_tst
class MelToWav:
    """HiFi-GAN vocoder wrapper: turns a mel-spectrogram tensor into int16 PCM audio."""

    def __init__(self, hifi_model_dir, device="cuda"):
        """Load the HiFi-GAN generator from ``hifi_model_dir``.

        Args:
            hifi_model_dir: directory holding ``config.json`` and the ``g_*``
                generator checkpoints; the process exits (via
                ``check_directory``) if it is missing.
            device: torch device string the generator and inputs are placed on.
        """
        self.hifi_model_dir = hifi_model_dir
        check_directory(self.hifi_model_dir)
        self.device = device
        self.h, self.hifi_gan_generator = self.load_hifi_gan()

    def load_hifi_gan(self):
        """Read the vocoder config, build the generator and load its weights.

        Returns:
            ``(h, generator)``: the config as an ``AttrDict`` and the generator
            in eval mode with weight norm removed.

        Raises:
            FileNotFoundError: if the selected checkpoint path is not a file.
        """
        checkpoint_path = utils.latest_checkpoint_path(self.hifi_model_dir, regex="g_*")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                "HiFi-GAN checkpoint not found: {}".format(checkpoint_path)
            )
        config_file = os.path.join(self.hifi_model_dir, "config.json")
        # Context manager closes the handle (the original leaked it).
        with open(config_file, encoding="utf-8") as config_fh:
            json_config = json.load(config_fh)
        h = AttrDict(json_config)
        # Seeds torch's global RNG from the config's "seed" entry.
        torch.manual_seed(h.seed)
        generator = Generator(h).to(self.device)
        print("Loading '{}'".format(checkpoint_path))
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict_g = torch.load(checkpoint_path, map_location=self.device)
        print("Complete.")
        generator.load_state_dict(state_dict_g["generator"])
        generator.eval()
        generator.remove_weight_norm()
        return h, generator

    def generate_wav(self, mel):
        """Vocode ``mel`` into 16-bit PCM samples.

        Args:
            mel: mel-spectrogram tensor (e.g. the output of
                ``TextToMel.generate_mel``); it is moved to ``self.device``.

        Returns:
            ``(audio, sampling_rate)``: a squeezed int16 numpy array and the
            sampling rate from the vocoder config.
        """
        # Inference only: no autograd graph, which saves memory.
        with torch.no_grad():
            y_g_hat = self.hifi_gan_generator(mel.to(self.device))
            # Scale to the int16 range and clamp: a sample of exactly +1.0 would
            # otherwise become 32768 and overflow when cast to int16.
            audio = (y_g_hat.squeeze() * 32768.0).clamp(-32768.0, 32767.0)
            audio = audio.cpu().detach().numpy().astype("int16")
        del y_g_hat
        torch.cuda.empty_cache()
        return audio, self.h.sampling_rate
if __name__ == "__main__":
    # Command-line entry point. Text from --text goes through the Glow-TTS
    # model in --model to get a mel-spectrogram. The HiFi-GAN vocoder in
    # --gan turns that into audio, which is written as a wav file to --wav.
    cli = ArgumentParser()
    cli.add_argument("-m", "--model", required=True, type=str)
    cli.add_argument("-g", "--gan", required=True, type=str)
    cli.add_argument("-d", "--device", type=str, default="cpu")
    cli.add_argument("-t", "--text", type=str, required=True)
    cli.add_argument("-w", "--wav", type=str, required=True)
    opts = cli.parse_args()

    # Both models are loaded before any synthesis happens.
    acoustic = TextToMel(glow_model_dir=opts.model, device=opts.device)
    vocoder = MelToWav(hifi_model_dir=opts.gan, device=opts.device)

    samples, sample_rate = vocoder.generate_wav(acoustic.generate_mel(opts.text))
    write(filename=opts.wav, rate=sample_rate, data=samples)