# espnet/bin/tts_train.py
#!/usr/bin/env python3
# Copyright 2018 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Text-to-speech model training script."""
import logging
import os
import random
import subprocess
import sys
import configargparse
import numpy as np
from espnet import __version__
from espnet.nets.tts_interface import TTSInterface
from espnet.utils.cli_utils import strtobool
from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES
# NOTE: you need this func to generate our sphinx doc
def get_parser():
    """Get parser of training arguments.

    Returns:
        configargparse.ArgumentParser: Argument parser covering general,
        task, network-architecture, minibatch, optimization, and
        finetuning options for TTS model training.

    """

    def _csv_module_list(s):
        # Split a comma-separated module list, dropping empty entries so
        # inputs like "enc.,,dec." do not produce "" module names.
        return [mod for mod in s.split(",") if mod != ""]

    parser = configargparse.ArgumentParser(
        description="Train a new text-to-speech (TTS) model on one CPU, "
        "one or multiple GPUs",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    # general configuration
    parser.add("--config", is_config_file=True, help="config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="second config file path that overwrites the settings in `--config`.",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="third config file path that overwrites "
        "the settings in `--config` and `--config2`.",
    )
    parser.add_argument(
        "--ngpu",
        default=None,
        type=int,
        help="Number of GPUs. If not given, use all visible devices",
    )
    parser.add_argument(
        "--backend",
        default="pytorch",
        type=str,
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--outdir", type=str, required=True, help="Output directory")
    parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument(
        "--resume",
        "-r",
        default="",
        type=str,
        nargs="?",
        help="Resume the training from snapshot",
    )
    parser.add_argument(
        "--minibatches",
        "-N",
        type=int,
        # NOTE: was the string "-1"; argparse coerced it via type=int, but an
        # int default is the correct, unambiguous form.
        default=-1,
        help="Process only N minibatches (for debug)",
    )
    parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
    parser.add_argument(
        "--tensorboard-dir",
        default=None,
        type=str,
        nargs="?",
        help="Tensorboard log directory path",
    )
    parser.add_argument(
        "--eval-interval-epochs", default=1, type=int, help="Evaluation interval epochs"
    )
    parser.add_argument(
        "--save-interval-epochs", default=1, type=int, help="Save interval epochs"
    )
    parser.add_argument(
        "--report-interval-iters",
        default=100,
        type=int,
        help="Report interval iterations",
    )
    # task related
    parser.add_argument(
        "--train-json", type=str, required=True, help="Filename of training json"
    )
    parser.add_argument(
        "--valid-json", type=str, required=True, help="Filename of validation json"
    )
    # network architecture
    parser.add_argument(
        "--model-module",
        type=str,
        default="espnet.nets.pytorch_backend.e2e_tts_tacotron2:Tacotron2",
        help="model defined module",
    )
    # minibatch related
    parser.add_argument(
        "--sortagrad",
        default=0,
        type=int,
        nargs="?",
        help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs",
    )
    parser.add_argument(
        "--batch-sort-key",
        default="shuffle",
        type=str,
        choices=["shuffle", "output", "input"],
        nargs="?",
        help='Batch sorting key. "shuffle" only work with --batch-count "seq".',
    )
    parser.add_argument(
        "--batch-count",
        default="auto",
        choices=BATCH_COUNT_CHOICES,
        help="How to count batch_size. "
        "The default (auto) will find how to count by args.",
    )
    parser.add_argument(
        "--batch-size",
        "--batch-seqs",
        "-b",
        default=0,
        type=int,
        help="Maximum seqs in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--batch-bins",
        default=0,
        type=int,
        help="Maximum bins in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--batch-frames-in",
        default=0,
        type=int,
        help="Maximum input frames in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--batch-frames-out",
        default=0,
        type=int,
        help="Maximum output frames in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--batch-frames-inout",
        default=0,
        type=int,
        help="Maximum input+output frames in a minibatch (0 to disable)",
    )
    parser.add_argument(
        "--maxlen-in",
        "--batch-seq-maxlen-in",
        default=100,
        type=int,
        metavar="ML",
        help="When --batch-count=seq, "
        "batch size is reduced if the input sequence length > ML.",
    )
    parser.add_argument(
        "--maxlen-out",
        "--batch-seq-maxlen-out",
        default=200,
        type=int,
        metavar="ML",
        help="When --batch-count=seq, "
        "batch size is reduced if the output sequence length > ML",
    )
    parser.add_argument(
        "--num-iter-processes",
        default=0,
        type=int,
        help="Number of processes of iterator",
    )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    parser.add_argument(
        "--use-speaker-embedding",
        default=False,
        type=strtobool,
        help="Whether to use speaker embedding",
    )
    parser.add_argument(
        "--use-second-target",
        default=False,
        type=strtobool,
        help="Whether to use second target",
    )
    # optimization related
    parser.add_argument(
        "--opt", default="adam", type=str, choices=["adam", "noam"], help="Optimizer"
    )
    parser.add_argument(
        "--accum-grad", default=1, type=int, help="Number of gradient accumuration"
    )
    parser.add_argument(
        "--lr", default=1e-3, type=float, help="Learning rate for optimizer"
    )
    parser.add_argument("--eps", default=1e-6, type=float, help="Epsilon for optimizer")
    parser.add_argument(
        "--weight-decay",
        default=1e-6,
        type=float,
        help="Weight decay coefficient for optimizer",
    )
    parser.add_argument(
        "--epochs", "-e", default=30, type=int, help="Number of maximum epochs"
    )
    parser.add_argument(
        "--early-stop-criterion",
        default="validation/main/loss",
        type=str,
        nargs="?",
        help="Value to monitor to trigger an early stopping of the training",
    )
    parser.add_argument(
        "--patience",
        default=3,
        type=int,
        nargs="?",
        help="Number of epochs to wait "
        "without improvement before stopping the training",
    )
    parser.add_argument(
        "--grad-clip", default=1, type=float, help="Gradient norm threshold to clip"
    )
    parser.add_argument(
        "--num-save-attention",
        default=5,
        type=int,
        help="Number of samples of attention to be saved",
    )
    parser.add_argument(
        "--keep-all-data-on-mem",
        default=False,
        type=strtobool,
        help="Whether to keep all data on memory",
    )
    # finetuning related
    # NOTE: the previous inline lambdas filtered on the *whole* string
    # (`if s != ""`) instead of each element, so "enc.,,dec." produced an
    # empty module name; _csv_module_list filters per element.
    parser.add_argument(
        "--enc-init",
        default=None,
        type=str,
        help="Pre-trained TTS model path to initialize encoder.",
    )
    parser.add_argument(
        "--enc-init-mods",
        default="enc.",
        type=_csv_module_list,
        help="List of encoder modules to initialize, separated by a comma.",
    )
    parser.add_argument(
        "--dec-init",
        default=None,
        type=str,
        help="Pre-trained TTS model path to initialize decoder.",
    )
    parser.add_argument(
        "--dec-init-mods",
        default="dec.",
        type=_csv_module_list,
        help="List of decoder modules to initialize, separated by a comma.",
    )
    parser.add_argument(
        "--freeze-mods",
        default=None,
        type=_csv_module_list,
        help="List of modules to freeze (not to train), separated by a comma.",
    )
    return parser
def main(cmd_args):
    """Run training.

    Args:
        cmd_args (list): Command-line argument strings (excluding the
            program name), e.g. ``sys.argv[1:]``.

    Raises:
        NotImplementedError: If a backend other than pytorch is requested.

    """
    parser = get_parser()
    # First pass: only --model-module is needed so the model class can
    # register its own arguments before the full parse.
    args, _ = parser.parse_known_args(cmd_args)

    from espnet.utils.dynamic_import import dynamic_import

    model_class = dynamic_import(args.model_module)
    assert issubclass(model_class, TTSInterface)
    model_class.add_arguments(parser)
    args = parser.parse_args(cmd_args)

    # add version info in args
    args.version = __version__

    # logging info
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=log_format)
    else:
        logging.basicConfig(level=logging.WARN, format=log_format)
        logging.warning("Skip DEBUG/INFO messages")

    # If --ngpu is not given,
    #   1. if CUDA_VISIBLE_DEVICES is set, all visible devices
    #   2. if nvidia-smi exists, use all devices
    #   3. else ngpu=0
    if args.ngpu is None:
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is not None:
            ngpu = len(cvd.split(","))
        else:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
            try:
                # check=True so a failing nvidia-smi raises CalledProcessError
                # (without it the except arm below was unreachable).
                p = subprocess.run(
                    ["nvidia-smi", "-L"],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    check=True,
                )
            except (subprocess.CalledProcessError, FileNotFoundError):
                ngpu = 0
            else:
                # nvidia-smi -L lists one "GPU <i>: ..." line per device on
                # stdout; the previous code read stderr, which is empty on
                # success and therefore always detected 0 GPUs.
                ngpu = sum(
                    1
                    for line in p.stdout.decode().splitlines()
                    if line.startswith("GPU")
                )
        args.ngpu = ngpu
    else:
        ngpu = args.ngpu
    logging.info(f"ngpu: {ngpu}")

    # set random seed
    logging.info("random seed = %d" % args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.backend == "pytorch":
        from espnet.tts.pytorch_backend.tts import train

        train(args)
    else:
        raise NotImplementedError("Only pytorch is supported.")
# Script entry point: pass the CLI arguments (without the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])