""" Utility functions """ import pickle from pathlib import Path import pax import toml import yaml from tacotron import Tacotron def load_tacotron_config(config_file=Path("tacotron.toml")): """ Load the project configurations """ return toml.load(config_file)["tacotron"] def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path): """ load checkpoint from disk """ with open(path, "rb") as f: dic = pickle.load(f) if net is not None: net = net.load_state_dict(dic["model_state_dict"]) if optim is not None: optim = optim.load_state_dict(dic["optim_state_dict"]) return dic["step"], net, optim def create_tacotron_model(config): """ return a random initialized Tacotron model """ return Tacotron( mel_dim=config["MEL_DIM"], attn_bias=config["ATTN_BIAS"], rr=config["RR"], max_rr=config["MAX_RR"], mel_min=config["MEL_MIN"], sigmoid_noise=config["SIGMOID_NOISE"], pad_token=config["PAD_TOKEN"], prenet_dim=config["PRENET_DIM"], attn_hidden_dim=config["ATTN_HIDDEN_DIM"], attn_rnn_dim=config["ATTN_RNN_DIM"], rnn_dim=config["RNN_DIM"], postnet_dim=config["POSTNET_DIM"], text_dim=config["TEXT_DIM"], ) def load_wavegru_config(config_file): """ Load project configurations """ with open(config_file, "r", encoding="utf-8") as f: return yaml.safe_load(f) def load_wavegru_ckpt(net, optim, ckpt_file): """ load training checkpoint from file """ with open(ckpt_file, "rb") as f: dic = pickle.load(f) if net is not None: net = net.load_state_dict(dic["net_state_dict"]) if optim is not None: optim = optim.load_state_dict(dic["optim_state_dict"]) return dic["step"], net, optim