#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Decode with trained Parallel WaveGAN Generator."""

import argparse
import logging
import os
import time

import numpy as np
import soundfile as sf
import torch
import yaml

from tqdm import tqdm

from parallel_wavegan.datasets import MelDataset
from parallel_wavegan.datasets import MelSCPDataset
from parallel_wavegan.utils import load_model
from parallel_wavegan.utils import read_hdf5

# Single logging format shared by every verbosity level.
_LOG_FORMAT = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"


def _get_parser():
    """Build the command-line argument parser for the decoding script."""
    parser = argparse.ArgumentParser(
        description="Decode dumped features with trained Parallel WaveGAN Generator "
        "(See detail in parallel_wavegan/bin/decode.py)."
    )
    parser.add_argument(
        "--feats-scp",
        "--scp",
        default=None,
        type=str,
        help="kaldi-style feats.scp file. "
        "you need to specify either feats-scp or dumpdir.",
    )
    parser.add_argument(
        "--dumpdir",
        default=None,
        type=str,
        help="directory including feature files. "
        "you need to specify either feats-scp or dumpdir.",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="directory to save generated speech.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="checkpoint file to be loaded.",
    )
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        help="yaml format configuration file. if not explicitly provided, "
        "it will be searched in the checkpoint directory. (default=None)",
    )
    parser.add_argument(
        "--normalize-before",
        default=False,
        action="store_true",
        help="whether to perform feature normalization before input to the model. "
        "if true, it assumes that the feature is de-normalized. this is useful when "
        "text2mel model and vocoder use different feature statistics.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    return parser


def _setup_logging(verbose):
    """Configure the root logger from the ``--verbose`` level.

    ``verbose > 1`` selects DEBUG, ``verbose > 0`` selects INFO, anything else
    selects WARN and emits a warning that DEBUG/INFO messages are skipped.
    """
    if verbose > 1:
        level = logging.DEBUG
    elif verbose > 0:
        level = logging.INFO
    else:
        level = logging.WARN
    logging.basicConfig(level=level, format=_LOG_FORMAT)
    if verbose <= 0:
        logging.warning("Skip DEBUG/INFO messages")


def _build_dataset(args, config):
    """Create the feature dataset from ``--dumpdir`` or ``--feats-scp``.

    Args:
        args: Parsed command-line arguments; exactly one of ``dumpdir`` and
            ``feats_scp`` is expected to be set (validated by the caller).
        config: Configuration mapping; ``config["format"]`` selects how files
            under ``dumpdir`` are located and loaded.

    Returns:
        A dataset built with ``return_utt_id=True``.

    Raises:
        ValueError: If ``config["format"]`` is neither ``"hdf5"`` nor ``"npy"``
            when decoding from ``dumpdir``.
    """
    if args.dumpdir is None:
        return MelSCPDataset(
            feats_scp=args.feats_scp,
            return_utt_id=True,
        )

    if config["format"] == "hdf5":
        mel_query = "*.h5"
        mel_load_fn = lambda x: read_hdf5(x, "feats")  # NOQA
    elif config["format"] == "npy":
        mel_query = "*-feats.npy"
        mel_load_fn = np.load
    else:
        raise ValueError("Support only hdf5 or npy format.")
    return MelDataset(
        args.dumpdir,
        mel_query=mel_query,
        mel_load_fn=mel_load_fn,
        return_utt_id=True,
    )


def main():
    """Run decoding process.

    Parses the command line, loads the yaml configuration (searched next to
    the checkpoint when ``--config`` is omitted), builds the feature dataset,
    loads the generator, and writes one ``<utt_id>_gen.wav`` (PCM 16 bit) per
    utterance into ``--outdir`` while reporting the real-time factor (RTF).

    Raises:
        ValueError: If both or neither of ``--dumpdir`` / ``--feats-scp`` are
            given, or if the feature format in the config is unsupported.
    """
    args = _get_parser().parse_args()

    # set logger
    _setup_logging(args.verbose)

    # check arguments before touching the filesystem, so that an invalid
    # invocation does not leave an output directory behind
    if (args.feats_scp is None) == (args.dumpdir is None):
        raise ValueError("Please specify either --dumpdir or --feats-scp.")

    # check directory existence (exist_ok avoids a check-then-create race)
    os.makedirs(args.outdir, exist_ok=True)

    # load config
    if args.config is None:
        dirname = os.path.dirname(args.checkpoint)
        args.config = os.path.join(dirname, "config.yml")
    with open(args.config) as f:
        # NOTE(security): yaml.Loader can construct arbitrary Python objects.
        # Only decode with configuration files from a trusted source.
        config = yaml.load(f, Loader=yaml.Loader)
    # command-line values take precedence over the stored configuration
    config.update(vars(args))

    # get dataset
    dataset = _build_dataset(args, config)
    logging.info(f"The number of features to be decoded = {len(dataset)}.")

    # setup model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(args.checkpoint, config)
    logging.info(f"Loaded model parameters from {args.checkpoint}.")
    if args.normalize_before:
        assert hasattr(model, "mean"), "Feature stats are not registered."
        assert hasattr(model, "scale"), "Feature stats are not registered."
    model.remove_weight_norm()
    model = model.eval().to(device)

    # start generation
    total_rtf = 0.0
    num_decoded = 0  # stays 0 when the dataset yields nothing
    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
        for num_decoded, (utt_id, c) in enumerate(pbar, 1):
            # generate
            c = torch.tensor(c, dtype=torch.float).to(device)
            start = time.time()
            y = model.inference(c, normalize_before=args.normalize_before).view(-1)
            # RTF = processing time / duration of the generated audio
            rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
            pbar.set_postfix({"RTF": rtf})
            total_rtf += rtf

            # save as PCM 16 bit wav file
            sf.write(
                os.path.join(config["outdir"], f"{utt_id}_gen.wav"),
                y.cpu().numpy(),
                config["sampling_rate"],
                "PCM_16",
            )

    # nothing was decoded: avoid the undefined counter / division by zero
    if num_decoded == 0:
        logging.warning("No features were found to decode.")
        return

    # report average RTF
    logging.info(
        f"Finished generation of {num_decoded} utterances "
        f"(RTF = {total_rtf / num_decoded:.03f})."
    )


if __name__ == "__main__":
    main()