|
from abc import ABC |
|
from abc import abstractmethod |
|
import argparse |
|
from dataclasses import dataclass |
|
from distutils.version import LooseVersion |
|
import functools |
|
import logging |
|
import os |
|
from pathlib import Path |
|
import sys |
|
from typing import Any |
|
from typing import Callable |
|
from typing import Dict |
|
from typing import List |
|
from typing import Optional |
|
from typing import Sequence |
|
from typing import Tuple |
|
from typing import Union |
|
|
|
import humanfriendly |
|
import numpy as np |
|
import torch |
|
import torch.multiprocessing |
|
import torch.nn |
|
import torch.optim |
|
from torch.utils.data import DataLoader |
|
from typeguard import check_argument_types |
|
from typeguard import check_return_type |
|
import wandb |
|
import yaml |
|
|
|
from espnet import __version__ |
|
from espnet.utils.cli_utils import get_commandline_args |
|
from espnet2.iterators.abs_iter_factory import AbsIterFactory |
|
from espnet2.iterators.chunk_iter_factory import ChunkIterFactory |
|
from espnet2.iterators.multiple_iter_factory import MultipleIterFactory |
|
from espnet2.iterators.sequence_iter_factory import SequenceIterFactory |
|
from espnet2.main_funcs.collect_stats import collect_stats |
|
from espnet2.optimizers.sgd import SGD |
|
from espnet2.samplers.build_batch_sampler import BATCH_TYPES |
|
from espnet2.samplers.build_batch_sampler import build_batch_sampler |
|
from espnet2.samplers.unsorted_batch_sampler import UnsortedBatchSampler |
|
from espnet2.schedulers.noam_lr import NoamLR |
|
from espnet2.schedulers.warmup_lr import WarmupLR |
|
from espnet2.torch_utils.load_pretrained_model import load_pretrained_model |
|
from espnet2.torch_utils.model_summary import model_summary |
|
from espnet2.torch_utils.pytorch_version import pytorch_cudnn_version |
|
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed |
|
from espnet2.train.abs_espnet_model import AbsESPnetModel |
|
from espnet2.train.class_choices import ClassChoices |
|
from espnet2.train.dataset import AbsDataset |
|
from espnet2.train.dataset import DATA_TYPES |
|
from espnet2.train.dataset import ESPnetDataset |
|
from espnet2.train.distributed_utils import DistributedOption |
|
from espnet2.train.distributed_utils import free_port |
|
from espnet2.train.distributed_utils import get_master_port |
|
from espnet2.train.distributed_utils import get_node_rank |
|
from espnet2.train.distributed_utils import get_num_nodes |
|
from espnet2.train.distributed_utils import resolve_distributed_mode |
|
from espnet2.train.iterable_dataset import IterableESPnetDataset |
|
from espnet2.train.trainer import Trainer |
|
from espnet2.utils.build_dataclass import build_dataclass |
|
from espnet2.utils import config_argparse |
|
from espnet2.utils.get_default_kwargs import get_default_kwargs |
|
from espnet2.utils.nested_dict_action import NestedDictAction |
|
from espnet2.utils.types import humanfriendly_parse_size_or_none |
|
from espnet2.utils.types import int_or_none |
|
from espnet2.utils.types import str2bool |
|
from espnet2.utils.types import str2triple_str |
|
from espnet2.utils.types import str_or_int |
|
from espnet2.utils.types import str_or_none |
|
from espnet2.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
|
|
|
if LooseVersion(torch.__version__) >= LooseVersion("1.5.0"): |
|
from torch.multiprocessing.spawn import ProcessContext |
|
else: |
|
from torch.multiprocessing.spawn import SpawnContext as ProcessContext |
|
|
|
|
|
optim_classes = dict( |
|
adam=torch.optim.Adam, |
|
sgd=SGD, |
|
adadelta=torch.optim.Adadelta, |
|
adagrad=torch.optim.Adagrad, |
|
adamax=torch.optim.Adamax, |
|
asgd=torch.optim.ASGD, |
|
lbfgs=torch.optim.LBFGS, |
|
rmsprop=torch.optim.RMSprop, |
|
rprop=torch.optim.Rprop, |
|
) |
|
if LooseVersion(torch.__version__) >= LooseVersion("1.2.0"): |
|
optim_classes["adamw"] = torch.optim.AdamW |
|
try: |
|
import torch_optimizer |
|
|
|
optim_classes.update( |
|
        accsgd=torch_optimizer.AccSGD,
|
adabound=torch_optimizer.AdaBound, |
|
adamod=torch_optimizer.AdaMod, |
|
diffgrad=torch_optimizer.DiffGrad, |
|
lamb=torch_optimizer.Lamb, |
|
novograd=torch_optimizer.NovoGrad, |
|
pid=torch_optimizer.PID, |
|
|
|
|
|
qhm=torch_optimizer.QHM, |
|
radam=torch_optimizer.RAdam, |
|
sgdw=torch_optimizer.SGDW, |
|
yogi=torch_optimizer.Yogi, |
|
) |
|
del torch_optimizer |
|
except ImportError: |
|
pass |
|
try: |
|
import apex |
|
|
|
optim_classes.update( |
|
fusedadam=apex.optimizers.FusedAdam, |
|
fusedlamb=apex.optimizers.FusedLAMB, |
|
fusednovograd=apex.optimizers.FusedNovoGrad, |
|
fusedsgd=apex.optimizers.FusedSGD, |
|
) |
|
del apex |
|
except ImportError: |
|
pass |
|
try: |
|
import fairscale |
|
except ImportError: |
|
fairscale = None |
|
|
|
|
|
scheduler_classes = dict( |
|
ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau, |
|
lambdalr=torch.optim.lr_scheduler.LambdaLR, |
|
steplr=torch.optim.lr_scheduler.StepLR, |
|
multisteplr=torch.optim.lr_scheduler.MultiStepLR, |
|
exponentiallr=torch.optim.lr_scheduler.ExponentialLR, |
|
CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR, |
|
) |
|
if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): |
|
scheduler_classes.update( |
|
noamlr=NoamLR, |
|
warmuplr=WarmupLR, |
|
) |
|
if LooseVersion(torch.__version__) >= LooseVersion("1.3.0"): |
|
CosineAnnealingWarmRestarts = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts |
|
scheduler_classes.update( |
|
cycliclr=torch.optim.lr_scheduler.CyclicLR, |
|
onecyclelr=torch.optim.lr_scheduler.OneCycleLR, |
|
CosineAnnealingWarmRestarts=CosineAnnealingWarmRestarts, |
|
) |
|
|
|
optim_classes = {k.lower(): v for k, v in optim_classes.items()} |
|
scheduler_classes = {k.lower(): v for k, v in scheduler_classes.items()} |
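
# Illustrative lookup only: the names given via "--optim" / "--scheduler" are
# lower-cased and resolved against these tables, e.g.
#   optim_classes["adam"]          # -> torch.optim.Adam
#   scheduler_classes["warmuplr"]  # -> WarmupLR (registered for torch>=1.1.0)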
|
|
|
|
|
@dataclass |
|
class IteratorOptions: |
|
preprocess_fn: callable |
|
collate_fn: callable |
|
data_path_and_name_and_type: list |
|
shape_files: list |
|
batch_size: int |
|
batch_bins: int |
|
batch_type: str |
|
max_cache_size: float |
|
max_cache_fd: int |
|
distributed: bool |
|
num_batches: Optional[int] |
|
num_iters_per_epoch: Optional[int] |
|
train: bool |
|
|
|
|
|
class AbsTask(ABC): |
|
|
|
|
|
|
|
|
|
num_optimizers: int = 1 |
|
trainer = Trainer |
|
class_choices_list: List[ClassChoices] = [] |
|
|
|
def __init__(self): |
|
raise RuntimeError("This class can't be instantiated.") |
|
|
|
@classmethod |
|
@abstractmethod |
|
def add_task_arguments(cls, parser: argparse.ArgumentParser): |
|
pass |
|
|
|
@classmethod |
|
@abstractmethod |
|
def build_collate_fn( |
|
cls, args: argparse.Namespace, train: bool |
|
    ) -> Callable[[Sequence[Dict[str, np.ndarray]]], Dict[str, torch.Tensor]]:
        """Return "collate_fn", which is a callable object passed to DataLoader.
|
|
|
>>> from torch.utils.data import DataLoader |
|
>>> loader = DataLoader(collate_fn=cls.build_collate_fn(args, train=True), ...) |
|
|
|
In many cases, you can use our common collate_fn. |
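
        As an illustrative sketch only (assuming every item is a dict of
        variable-length numpy arrays sharing the same keys; the common
        collate_fn in espnet2 additionally returns length tensors and
        handles per-key padding values):

        >>> def naive_collate(data):
        ...     return {k: torch.nn.utils.rnn.pad_sequence(
        ...                 [torch.from_numpy(d[k]) for d in data],
        ...                 batch_first=True)
        ...             for k in data[0]}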
|
""" |
|
raise NotImplementedError |
|
|
|
@classmethod |
|
@abstractmethod |
|
def build_preprocess_fn( |
|
cls, args: argparse.Namespace, train: bool |
|
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: |
|
raise NotImplementedError |
|
|
|
@classmethod |
|
@abstractmethod |
|
def required_data_names( |
|
cls, train: bool = True, inference: bool = False |
|
    ) -> Tuple[str, ...]:
        """Define the data names required by the Task.
|
|
|
This function is used by |
|
>>> cls.check_task_requirements() |
|
        If your model is defined as follows,
|
|
|
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel |
|
>>> class Model(AbsESPnetModel): |
|
... def forward(self, input, output, opt=None): pass |
|
|
|
        then "required_data_names" should be
|
|
|
>>> required_data_names = ('input', 'output') |
|
""" |
|
raise NotImplementedError |
|
|
|
@classmethod |
|
@abstractmethod |
|
def optional_data_names( |
|
cls, train: bool = True, inference: bool = False |
|
    ) -> Tuple[str, ...]:
        """Define the optional data names for the Task.
|
|
|
This function is used by |
|
>>> cls.check_task_requirements() |
|
If your model is defined as follows, |
|
|
|
>>> from espnet2.train.abs_espnet_model import AbsESPnetModel |
|
>>> class Model(AbsESPnetModel): |
|
... def forward(self, input, output, opt=None): pass |
|
|
|
        then "optional_data_names" should be
|
|
|
>>> optional_data_names = ('opt',) |
|
""" |
|
raise NotImplementedError |
|
|
|
@classmethod |
|
@abstractmethod |
|
def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel: |
|
raise NotImplementedError |
|
|
|
@classmethod |
|
def get_parser(cls) -> config_argparse.ArgumentParser: |
|
assert check_argument_types() |
|
|
|
class ArgumentDefaultsRawTextHelpFormatter( |
|
argparse.RawTextHelpFormatter, |
|
argparse.ArgumentDefaultsHelpFormatter, |
|
): |
|
pass |
|
|
|
parser = config_argparse.ArgumentParser( |
|
description="base parser", |
|
formatter_class=ArgumentDefaultsRawTextHelpFormatter, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
parser.set_defaults(required=["output_dir"]) |
|
|
|
group = parser.add_argument_group("Common configuration") |
|
|
|
group.add_argument( |
|
"--print_config", |
|
action="store_true", |
|
help="Print the config file and exit", |
|
) |
|
group.add_argument( |
|
"--log_level", |
|
type=lambda x: x.upper(), |
|
default="INFO", |
|
choices=("ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), |
|
help="The verbose level of logging", |
|
) |
|
group.add_argument( |
|
"--dry_run", |
|
type=str2bool, |
|
default=False, |
|
help="Perform process without training", |
|
) |
|
group.add_argument( |
|
"--iterator_type", |
|
type=str, |
|
choices=["sequence", "chunk", "task", "none"], |
|
default="sequence", |
|
help="Specify iterator type", |
|
) |
|
|
|
group.add_argument("--output_dir", type=str_or_none, default=None) |
|
group.add_argument( |
|
"--ngpu", |
|
type=int, |
|
default=0, |
|
help="The number of gpus. 0 indicates CPU mode", |
|
) |
|
group.add_argument("--seed", type=int, default=0, help="Random seed") |
|
group.add_argument( |
|
"--num_workers", |
|
type=int, |
|
default=1, |
|
help="The number of workers used for DataLoader", |
|
) |
|
group.add_argument( |
|
"--num_att_plot", |
|
type=int, |
|
default=3, |
|
            help="The number of images used to plot the attention outputs. "
            "This option makes sense only for attention-based models",
|
) |
|
|
|
group = parser.add_argument_group("distributed training related") |
|
group.add_argument( |
|
"--dist_backend", |
|
default="nccl", |
|
type=str, |
|
help="distributed backend", |
|
) |
|
group.add_argument( |
|
"--dist_init_method", |
|
type=str, |
|
default="env://", |
|
            help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
            '"WORLD_SIZE", and "RANK" are referenced.',
|
) |
|
group.add_argument( |
|
"--dist_world_size", |
|
default=None, |
|
type=int_or_none, |
|
help="number of nodes for distributed training", |
|
) |
|
group.add_argument( |
|
"--dist_rank", |
|
type=int_or_none, |
|
default=None, |
|
help="node rank for distributed training", |
|
) |
|
group.add_argument( |
|
|
|
"--local_rank", |
|
type=int_or_none, |
|
default=None, |
|
help="local rank for distributed training. This option is used if " |
|
"--multiprocessing_distributed=false", |
|
) |
|
group.add_argument( |
|
"--dist_master_addr", |
|
default=None, |
|
type=str_or_none, |
|
help="The master address for distributed training. " |
|
"This value is used when dist_init_method == 'env://'", |
|
) |
|
group.add_argument( |
|
"--dist_master_port", |
|
default=None, |
|
type=int_or_none, |
|
            help="The master port for distributed training. "
|
"This value is used when dist_init_method == 'env://'", |
|
) |
|
group.add_argument( |
|
"--dist_launcher", |
|
default=None, |
|
type=str_or_none, |
|
choices=["slurm", "mpi", None], |
|
help="The launcher type for distributed training", |
|
) |
|
group.add_argument( |
|
"--multiprocessing_distributed", |
|
default=False, |
|
type=str2bool, |
|
help="Use multi-processing distributed training to launch " |
|
"N processes per node, which has N GPUs. This is the " |
|
"fastest way to use PyTorch for either single node or " |
|
"multi node data parallel training", |
|
) |
|
group.add_argument( |
|
"--unused_parameters", |
|
type=str2bool, |
|
default=False, |
|
help="Whether to use the find_unused_parameters in " |
|
"torch.nn.parallel.DistributedDataParallel ", |
|
) |
|
group.add_argument( |
|
"--sharded_ddp", |
|
default=False, |
|
type=str2bool, |
|
help="Enable sharded training provided by fairscale", |
|
) |
|
|
|
group = parser.add_argument_group("cudnn mode related") |
|
group.add_argument( |
|
"--cudnn_enabled", |
|
type=str2bool, |
|
default=torch.backends.cudnn.enabled, |
|
help="Enable CUDNN", |
|
) |
|
group.add_argument( |
|
"--cudnn_benchmark", |
|
type=str2bool, |
|
default=torch.backends.cudnn.benchmark, |
|
help="Enable cudnn-benchmark mode", |
|
) |
|
group.add_argument( |
|
"--cudnn_deterministic", |
|
type=str2bool, |
|
default=True, |
|
help="Enable cudnn-deterministic mode", |
|
) |
|
|
|
group = parser.add_argument_group("collect stats mode related") |
|
group.add_argument( |
|
"--collect_stats", |
|
type=str2bool, |
|
default=False, |
|
            help='Run in "collect stats" mode',
|
) |
|
group.add_argument( |
|
"--write_collected_feats", |
|
type=str2bool, |
|
default=False, |
|
            help='Write the output features from the model in "collect stats" mode',
|
) |
|
|
|
group = parser.add_argument_group("Trainer related") |
|
group.add_argument( |
|
"--max_epoch", |
|
type=int, |
|
default=40, |
|
            help="The maximum number of epochs to train",
|
) |
|
group.add_argument( |
|
"--patience", |
|
type=int_or_none, |
|
default=None, |
|
help="Number of epochs to wait without improvement " |
|
"before stopping the training", |
|
) |
|
group.add_argument( |
|
"--val_scheduler_criterion", |
|
type=str, |
|
nargs=2, |
|
default=("valid", "loss"), |
|
help="The criterion used for the value given to the lr scheduler. " |
|
            'Give a pair referring to the phase, "train" or "valid", '
            'and the criterion name. The mode specifying "min" or "max" can '
|
"be changed by --scheduler_conf", |
|
) |
|
group.add_argument( |
|
"--early_stopping_criterion", |
|
type=str, |
|
nargs=3, |
|
default=("valid", "loss", "min"), |
|
            help="The criterion used for judging early stopping. "
            'Give a triple of the phase, "train" or "valid", '
            'the criterion name, and the mode, "min" or "max", e.g. "valid acc max".',
|
) |
|
group.add_argument( |
|
"--best_model_criterion", |
|
type=str2triple_str, |
|
nargs="+", |
|
default=[ |
|
("train", "loss", "min"), |
|
("valid", "loss", "min"), |
|
("train", "acc", "max"), |
|
("valid", "acc", "max"), |
|
], |
|
            help="The criterion used for judging the best model. "
            'Give comma-separated triples of the phase, "train" or "valid", '
            'the criterion name, and the mode, "min" or "max", e.g. "valid,acc,max".',
|
) |
|
group.add_argument( |
|
"--keep_nbest_models", |
|
type=int, |
|
nargs="+", |
|
default=[10], |
|
help="Remove previous snapshots excluding the n-best scored epochs", |
|
) |
|
group.add_argument( |
|
"--grad_clip", |
|
type=float, |
|
default=5.0, |
|
help="Gradient norm threshold to clip", |
|
) |
|
group.add_argument( |
|
"--grad_clip_type", |
|
type=float, |
|
default=2.0, |
|
            help="The type of the p-norm used for gradient clipping. Can be inf",
|
) |
|
group.add_argument( |
|
"--grad_noise", |
|
type=str2bool, |
|
default=False, |
|
            help="Whether to inject noise into the "
            "gradients during training",
|
) |
|
group.add_argument( |
|
"--accum_grad", |
|
type=int, |
|
default=1, |
|
            help="The number of gradient accumulation steps",
|
) |
|
group.add_argument( |
|
"--no_forward_run", |
|
type=str2bool, |
|
default=False, |
|
            help="Only iterate over data loading, without "
            "model forwarding and training",
|
) |
|
group.add_argument( |
|
"--resume", |
|
type=str2bool, |
|
default=False, |
|
            help="Enable resuming if a checkpoint exists",
|
) |
|
group.add_argument( |
|
"--train_dtype", |
|
default="float32", |
|
choices=["float16", "float32", "float64"], |
|
help="Data type for training.", |
|
) |
|
group.add_argument( |
|
"--use_amp", |
|
type=str2bool, |
|
default=False, |
|
help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6", |
|
) |
|
group.add_argument( |
|
"--log_interval", |
|
type=int_or_none, |
|
default=None, |
|
            help="Show the logs every specified number of iterations in each epoch "
            "during the training phase. If None is given, it is decided "
            "automatically according to the number of training samples.",
|
) |
|
group.add_argument( |
|
"--use_tensorboard", |
|
type=str2bool, |
|
default=True, |
|
help="Enable tensorboard logging", |
|
) |
|
group.add_argument( |
|
"--use_wandb", |
|
type=str2bool, |
|
default=False, |
|
help="Enable wandb logging", |
|
) |
|
group.add_argument( |
|
"--wandb_project", |
|
type=str, |
|
default=None, |
|
help="Specify wandb project", |
|
) |
|
group.add_argument( |
|
"--wandb_id", |
|
type=str, |
|
default=None, |
|
help="Specify wandb id", |
|
) |
|
group.add_argument( |
|
"--detect_anomaly", |
|
type=str2bool, |
|
default=False, |
|
help="Set torch.autograd.set_detect_anomaly", |
|
) |
|
|
|
group = parser.add_argument_group("Pretraining model related") |
|
group.add_argument("--pretrain_path", help="This option is obsoleted") |
|
group.add_argument( |
|
"--init_param", |
|
type=str, |
|
default=[], |
|
nargs="*", |
|
            help="Specify the file path used for initialization of parameters. "
            "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
            "where file_path is the model file path, "
            "src_key specifies the key of model states to be used in the model file, "
            "dst_key specifies the attribute of the model to be initialized, "
            "and exclude_keys excludes keys of model states from the initialization.\n"
            "e.g.\n"
            "  # Load all parameters\n"
            "  --init_param some/where/model.pth\n"
            "  # Load only decoder parameters\n"
            "  --init_param some/where/model.pth:decoder:decoder\n"
            "  # Load only decoder parameters excluding decoder.embed\n"
            "  --init_param some/where/model.pth:decoder:decoder:decoder.embed\n",
|
) |
|
group.add_argument( |
|
"--freeze_param", |
|
type=str, |
|
default=[], |
|
nargs="*", |
|
help="Freeze parameters", |
|
) |
|
|
|
group = parser.add_argument_group("BatchSampler related") |
|
group.add_argument( |
|
"--num_iters_per_epoch", |
|
type=int_or_none, |
|
default=None, |
|
help="Restrict the number of iterations for training per epoch", |
|
) |
|
group.add_argument( |
|
"--batch_size", |
|
type=int, |
|
default=20, |
|
help="The mini-batch size used for training. Used if batch_type='unsorted'," |
|
" 'sorted', or 'folded'.", |
|
) |
|
group.add_argument( |
|
"--valid_batch_size", |
|
type=int_or_none, |
|
default=None, |
|
help="If not given, the value of --batch_size is used", |
|
) |
|
group.add_argument( |
|
"--batch_bins", |
|
type=int, |
|
default=1000000, |
|
help="The number of batch bins. Used if batch_type='length' or 'numel'", |
|
) |
|
group.add_argument( |
|
"--valid_batch_bins", |
|
type=int_or_none, |
|
default=None, |
|
help="If not given, the value of --batch_bins is used", |
|
) |
|
|
|
group.add_argument("--train_shape_file", type=str, action="append", default=[]) |
|
group.add_argument("--valid_shape_file", type=str, action="append", default=[]) |
|
|
|
group = parser.add_argument_group("Sequence iterator related") |
|
_batch_type_help = "" |
|
for key, value in BATCH_TYPES.items(): |
|
_batch_type_help += f'"{key}":\n{value}\n' |
|
group.add_argument( |
|
"--batch_type", |
|
type=str, |
|
default="folded", |
|
choices=list(BATCH_TYPES), |
|
help=_batch_type_help, |
|
) |
|
group.add_argument( |
|
"--valid_batch_type", |
|
type=str_or_none, |
|
default=None, |
|
choices=list(BATCH_TYPES) + [None], |
|
help="If not given, the value of --batch_type is used", |
|
) |
|
group.add_argument("--fold_length", type=int, action="append", default=[]) |
|
group.add_argument( |
|
"--sort_in_batch", |
|
type=str, |
|
default="descending", |
|
choices=["descending", "ascending"], |
|
            help="Sort the samples in each mini-batch by the sample "
|
'lengths. To enable this, "shape_file" must have the length information.', |
|
) |
|
group.add_argument( |
|
"--sort_batch", |
|
type=str, |
|
default="descending", |
|
choices=["descending", "ascending"], |
|
help="Sort mini-batches by the sample lengths", |
|
) |
|
group.add_argument( |
|
"--multiple_iterator", |
|
type=str2bool, |
|
default=False, |
|
help="Use multiple iterator mode", |
|
) |
|
|
|
group = parser.add_argument_group("Chunk iterator related") |
|
group.add_argument( |
|
"--chunk_length", |
|
type=str_or_int, |
|
default=500, |
|
            help="Specify chunk length. e.g. '300', '300,400,500', or '300-400'. "
            "If multiple numbers separated by commas are given, "
            "one of them is selected randomly for each sample. "
            "If two numbers are given with '-', it indicates the range of the choices. "
            "Note that if the sequence length is shorter than all of the chunk lengths, "
            "the sample is discarded.",
|
) |
|
group.add_argument( |
|
"--chunk_shift_ratio", |
|
type=float, |
|
default=0.5, |
|
            help="Specify the shift width of chunks. If it is less than 1, "
            "the chunks overlap, and if it is larger than 1, there are gaps "
            "between the chunks.",
|
) |
|
group.add_argument( |
|
"--num_cache_chunks", |
|
type=int, |
|
default=1024, |
|
            help="Shuffle within the specified number of chunks and generate mini-batches. "
            "The larger this value is, the more randomness can be obtained.",
|
) |
|
|
|
group = parser.add_argument_group("Dataset related") |
|
        _data_path_and_name_and_type_help = (
            "Give three words separated by commas. It's used for the training data. "
|
"e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. " |
|
"The first value, some/path/a.scp, indicates the file path, " |
|
"and the second, foo, is the key name used for the mini-batch data, " |
|
"and the last, sound, decides the file type. " |
|
"This option is repeatable, so you can input any number of features " |
|
"for your task. Supported file types are as follows:\n\n" |
|
) |
|
for key, dic in DATA_TYPES.items(): |
|
_data_path_and_name_and_type_help += f'"{key}":\n{dic["help"]}\n\n' |
|
|
|
group.add_argument( |
|
"--train_data_path_and_name_and_type", |
|
type=str2triple_str, |
|
action="append", |
|
default=[], |
|
help=_data_path_and_name_and_type_help, |
|
) |
|
group.add_argument( |
|
"--valid_data_path_and_name_and_type", |
|
type=str2triple_str, |
|
action="append", |
|
default=[], |
|
) |
|
group.add_argument( |
|
"--allow_variable_data_keys", |
|
type=str2bool, |
|
default=False, |
|
            help="Allow arbitrary keys for the mini-batch, ignoring "
            "the task requirements",
|
) |
|
group.add_argument( |
|
"--max_cache_size", |
|
type=humanfriendly.parse_size, |
|
default=0.0, |
|
help="The maximum cache size for data loader. e.g. 10MB, 20GB.", |
|
) |
|
group.add_argument( |
|
"--max_cache_fd", |
|
type=int, |
|
default=32, |
|
            help="The maximum number of file descriptors to be kept "
            "open for ark files. "
|
"This feature is only valid when data type is 'kaldi_ark'.", |
|
) |
|
group.add_argument( |
|
"--valid_max_cache_size", |
|
type=humanfriendly_parse_size_or_none, |
|
default=None, |
|
            help="The maximum cache size for validation data loader. e.g. 10MB, 20GB. "
            "If None, 5 percent of --max_cache_size is used",
|
) |
|
|
|
group = parser.add_argument_group("Optimizer related") |
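
        # One set of optimizer/scheduler options is generated per optimizer;
        # when num_optimizers > 1 the additional sets are suffixed with the
        # index, e.g. --optim2, --optim2_conf, --scheduler2, --scheduler2_conf.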
|
for i in range(1, cls.num_optimizers + 1): |
|
suf = "" if i == 1 else str(i) |
|
group.add_argument( |
|
f"--optim{suf}", |
|
type=lambda x: x.lower(), |
|
default="adadelta", |
|
choices=list(optim_classes), |
|
help="The optimizer type", |
|
) |
|
group.add_argument( |
|
f"--optim{suf}_conf", |
|
action=NestedDictAction, |
|
default=dict(), |
|
help="The keyword arguments for optimizer", |
|
) |
|
group.add_argument( |
|
f"--scheduler{suf}", |
|
type=lambda x: str_or_none(x.lower()), |
|
default=None, |
|
choices=list(scheduler_classes) + [None], |
|
help="The lr scheduler type", |
|
) |
|
group.add_argument( |
|
f"--scheduler{suf}_conf", |
|
action=NestedDictAction, |
|
default=dict(), |
|
help="The keyword arguments for lr scheduler", |
|
) |
|
|
|
cls.trainer.add_arguments(parser) |
|
cls.add_task_arguments(parser) |
|
|
|
assert check_return_type(parser) |
|
return parser |
|
|
|
@classmethod |
|
def build_optimizers( |
|
cls, |
|
args: argparse.Namespace, |
|
model: torch.nn.Module, |
|
) -> List[torch.optim.Optimizer]: |
|
if cls.num_optimizers != 1: |
|
raise RuntimeError( |
|
"build_optimizers() must be overridden if num_optimizers != 1" |
|
) |
|
|
|
optim_class = optim_classes.get(args.optim) |
|
if optim_class is None: |
|
raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}") |
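
        # With --sharded_ddp the optimizer state is sharded across ranks via
        # fairscale's OSS wrapper; otherwise the optimizer class is applied
        # directly to the model parameters.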
|
if args.sharded_ddp: |
|
if fairscale is None: |
|
                raise RuntimeError("fairscale is required: do 'pip install fairscale'")
|
optim = fairscale.optim.oss.OSS( |
|
params=model.parameters(), optim=optim_class, **args.optim_conf |
|
) |
|
else: |
|
optim = optim_class(model.parameters(), **args.optim_conf) |
|
|
|
optimizers = [optim] |
|
return optimizers |
|
|
|
@classmethod |
|
def exclude_opts(cls) -> Tuple[str, ...]: |
|
"""The options not to be shown by --print_config""" |
|
return "required", "print_config", "config", "ngpu" |
|
|
|
@classmethod |
|
def get_default_config(cls) -> Dict[str, Any]: |
|
"""Return the configuration as dict. |
|
|
|
This method is used by print_config() |
|
""" |
|
|
|
def get_class_type(name: str, classes: dict): |
|
_cls = classes.get(name) |
|
if _cls is None: |
|
raise ValueError(f"must be one of {list(classes)}: {name}") |
|
return _cls |
|
|
|
|
|
assert check_argument_types() |
|
parser = cls.get_parser() |
|
args, _ = parser.parse_known_args() |
|
config = vars(args) |
|
|
|
for k in AbsTask.exclude_opts(): |
|
config.pop(k) |
|
|
|
for i in range(1, cls.num_optimizers + 1): |
|
suf = "" if i == 1 else str(i) |
|
name = config[f"optim{suf}"] |
|
optim_class = get_class_type(name, optim_classes) |
|
conf = get_default_kwargs(optim_class) |
|
|
|
conf.update(config[f"optim{suf}_conf"]) |
|
|
|
config[f"optim{suf}_conf"] = conf |
|
|
|
name = config[f"scheduler{suf}"] |
|
if name is not None: |
|
scheduler_class = get_class_type(name, scheduler_classes) |
|
conf = get_default_kwargs(scheduler_class) |
|
|
|
conf.update(config[f"scheduler{suf}_conf"]) |
|
|
|
config[f"scheduler{suf}_conf"] = conf |
|
|
|
for class_choices in cls.class_choices_list: |
|
if getattr(args, class_choices.name) is not None: |
|
class_obj = class_choices.get_class(getattr(args, class_choices.name)) |
|
conf = get_default_kwargs(class_obj) |
|
name = class_choices.name |
|
|
|
conf.update(config[f"{name}_conf"]) |
|
|
|
config[f"{name}_conf"] = conf |
|
return config |
|
|
|
@classmethod |
|
def check_required_command_args(cls, args: argparse.Namespace): |
|
assert check_argument_types() |
|
for k in vars(args): |
|
if "-" in k: |
|
raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")') |
|
|
|
required = ", ".join( |
|
f"--{a}" for a in args.required if getattr(args, a) is None |
|
) |
|
|
|
if len(required) != 0: |
|
parser = cls.get_parser() |
|
parser.print_help(file=sys.stderr) |
|
p = Path(sys.argv[0]).name |
|
print(file=sys.stderr) |
|
print( |
|
f"{p}: error: the following arguments are required: " f"{required}", |
|
file=sys.stderr, |
|
) |
|
sys.exit(2) |
|
|
|
@classmethod |
|
def check_task_requirements( |
|
cls, |
|
dataset: Union[AbsDataset, IterableESPnetDataset], |
|
allow_variable_data_keys: bool, |
|
train: bool, |
|
inference: bool = False, |
|
    ) -> None:
        """Check if the dataset satisfies the requirements of the current Task"""
|
assert check_argument_types() |
|
mes = ( |
|
f"If you intend to use an additional input, modify " |
|
f'"{cls.__name__}.required_data_names()" or ' |
|
f'"{cls.__name__}.optional_data_names()". ' |
|
f"Otherwise you need to set --allow_variable_data_keys true " |
|
) |
|
|
|
for k in cls.required_data_names(train, inference): |
|
if not dataset.has_name(k): |
|
raise RuntimeError( |
|
f'"{cls.required_data_names(train, inference)}" are required for' |
|
                    f' {cls.__name__}, but "{dataset.names()}" were given.\n{mes}'
|
) |
|
if not allow_variable_data_keys: |
|
task_keys = cls.required_data_names( |
|
train, inference |
|
) + cls.optional_data_names(train, inference) |
|
for k in dataset.names(): |
|
if k not in task_keys: |
|
raise RuntimeError( |
|
f"The data-name must be one of {task_keys} " |
|
f'for {cls.__name__}: "{k}" is not allowed.\n{mes}' |
|
) |
|
|
|
@classmethod |
|
def print_config(cls, file=sys.stdout) -> None: |
|
assert check_argument_types() |
|
|
|
config = cls.get_default_config() |
|
file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False)) |
|
|
|
@classmethod |
|
def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None): |
|
assert check_argument_types() |
|
print(get_commandline_args(), file=sys.stderr) |
|
if args is None: |
|
parser = cls.get_parser() |
|
args = parser.parse_args(cmd) |
|
args.version = __version__ |
|
if args.pretrain_path is not None: |
|
raise RuntimeError("--pretrain_path is deprecated. Use --init_param") |
|
if args.print_config: |
|
cls.print_config() |
|
sys.exit(0) |
|
cls.check_required_command_args(args) |
|
|
|
|
|
resolve_distributed_mode(args) |
|
if not args.distributed or not args.multiprocessing_distributed: |
|
cls.main_worker(args) |
|
|
|
else: |
|
assert args.ngpu > 1, args.ngpu |
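
            # Spawn one training process per GPU on this node; each child
            # process runs main_worker() with ngpu=1 and its own dist_rank
            # (see local_args below).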
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
num_nodes = get_num_nodes(args.dist_world_size, args.dist_launcher) |
|
if num_nodes == 1: |
|
args.dist_master_addr = "localhost" |
|
args.dist_rank = 0 |
|
|
|
if ( |
|
args.dist_init_method == "env://" |
|
and get_master_port(args.dist_master_port) is None |
|
): |
|
|
|
args.dist_master_port = free_port() |
|
|
|
|
|
args.dist_world_size = args.ngpu * num_nodes |
|
node_rank = get_node_rank(args.dist_rank, args.dist_launcher) |
|
|
|
|
|
|
|
error_queues = [] |
|
processes = [] |
|
mp = torch.multiprocessing.get_context("spawn") |
|
for i in range(args.ngpu): |
|
|
|
local_args = argparse.Namespace(**vars(args)) |
|
|
|
local_args.local_rank = i |
|
local_args.dist_rank = args.ngpu * node_rank + i |
|
local_args.ngpu = 1 |
|
|
|
process = mp.Process( |
|
target=cls.main_worker, |
|
args=(local_args,), |
|
daemon=False, |
|
) |
|
process.start() |
|
processes.append(process) |
|
error_queues.append(mp.SimpleQueue()) |
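
            # ProcessContext from torch.multiprocessing is reused below to
            # monitor the spawned workers; join() returns True once every
            # process has exited.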
|
|
|
while not ProcessContext(processes, error_queues).join(): |
|
pass |
|
|
|
@classmethod |
|
def main_worker(cls, args: argparse.Namespace): |
|
assert check_argument_types() |
|
|
|
|
|
distributed_option = build_dataclass(DistributedOption, args) |
|
|
|
distributed_option.init_options() |
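
        # Only rank 0 (or the non-distributed case) logs at the requested
        # level; the other ranks are restricted to the ERROR level below to
        # avoid duplicated log lines.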
|
|
|
|
|
if not distributed_option.distributed or distributed_option.dist_rank == 0: |
|
if not distributed_option.distributed: |
|
_rank = "" |
|
else: |
|
_rank = ( |
|
f":{distributed_option.dist_rank}/" |
|
f"{distributed_option.dist_world_size}" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
level=args.log_level, |
|
format=f"[{os.uname()[1].split('.')[0]}{_rank}]" |
|
f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
|
) |
|
else: |
|
|
|
logging.basicConfig( |
|
level="ERROR", |
|
format=f"[{os.uname()[1].split('.')[0]}" |
|
f":{distributed_option.dist_rank}/{distributed_option.dist_world_size}]" |
|
f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
|
) |
|
|
|
distributed_option.init_torch_distributed() |
|
|
|
|
|
set_all_random_seed(args.seed) |
|
torch.backends.cudnn.enabled = args.cudnn_enabled |
|
torch.backends.cudnn.benchmark = args.cudnn_benchmark |
|
torch.backends.cudnn.deterministic = args.cudnn_deterministic |
|
if args.detect_anomaly: |
|
logging.info("Invoking torch.autograd.set_detect_anomaly(True)") |
|
torch.autograd.set_detect_anomaly(args.detect_anomaly) |
|
|
|
|
|
model = cls.build_model(args=args) |
|
if not isinstance(model, AbsESPnetModel): |
|
raise RuntimeError( |
|
f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
|
) |
|
model = model.to( |
|
dtype=getattr(torch, args.train_dtype), |
|
device="cuda" if args.ngpu > 0 else "cpu", |
|
) |
|
for t in args.freeze_param: |
|
for k, p in model.named_parameters(): |
|
if k.startswith(t + ".") or k == t: |
|
logging.info(f"Setting {k}.requires_grad = False") |
|
p.requires_grad = False |
|
|
|
|
|
optimizers = cls.build_optimizers(args, model=model) |
|
|
|
|
|
schedulers = [] |
|
for i, optim in enumerate(optimizers, 1): |
|
suf = "" if i == 1 else str(i) |
|
name = getattr(args, f"scheduler{suf}") |
|
conf = getattr(args, f"scheduler{suf}_conf") |
|
if name is not None: |
|
cls_ = scheduler_classes.get(name) |
|
if cls_ is None: |
|
raise ValueError( |
|
f"must be one of {list(scheduler_classes)}: {name}" |
|
) |
|
scheduler = cls_(optim, **conf) |
|
else: |
|
scheduler = None |
|
|
|
schedulers.append(scheduler) |
|
|
|
logging.info(pytorch_cudnn_version()) |
|
logging.info(model_summary(model)) |
|
for i, (o, s) in enumerate(zip(optimizers, schedulers), 1): |
|
suf = "" if i == 1 else str(i) |
|
logging.info(f"Optimizer{suf}:\n{o}") |
|
logging.info(f"Scheduler{suf}: {s}") |
|
|
|
|
|
|
|
|
|
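        # Dump the resolved configuration (on rank 0 only) so that it can be
        # reused later for resuming, inference, or fine-tuning.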
output_dir = Path(args.output_dir) |
|
if not distributed_option.distributed or distributed_option.dist_rank == 0: |
|
output_dir.mkdir(parents=True, exist_ok=True) |
|
with (output_dir / "config.yaml").open("w", encoding="utf-8") as f: |
|
logging.info( |
|
f'Saving the configuration in {output_dir / "config.yaml"}' |
|
) |
|
yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False) |
|
|
|
|
|
for p in args.init_param: |
|
logging.info(f"Loading pretrained params from {p}") |
|
load_pretrained_model( |
|
model=model, |
|
init_param=p, |
|
|
|
|
|
map_location=f"cuda:{torch.cuda.current_device()}" |
|
if args.ngpu > 0 |
|
else "cpu", |
|
) |
|
|
|
if args.dry_run: |
|
pass |
|
elif args.collect_stats: |
|
|
|
|
|
|
|
logging.info(args) |
|
|
|
if args.valid_batch_size is None: |
|
args.valid_batch_size = args.batch_size |
|
|
|
if len(args.train_shape_file) != 0: |
|
train_key_file = args.train_shape_file[0] |
|
else: |
|
train_key_file = None |
|
if len(args.valid_shape_file) != 0: |
|
valid_key_file = args.valid_shape_file[0] |
|
else: |
|
valid_key_file = None |
|
|
|
collect_stats( |
|
model=model, |
|
train_iter=cls.build_streaming_iterator( |
|
data_path_and_name_and_type=args.train_data_path_and_name_and_type, |
|
key_file=train_key_file, |
|
batch_size=args.batch_size, |
|
dtype=args.train_dtype, |
|
num_workers=args.num_workers, |
|
allow_variable_data_keys=args.allow_variable_data_keys, |
|
ngpu=args.ngpu, |
|
preprocess_fn=cls.build_preprocess_fn(args, train=False), |
|
collate_fn=cls.build_collate_fn(args, train=False), |
|
), |
|
valid_iter=cls.build_streaming_iterator( |
|
data_path_and_name_and_type=args.valid_data_path_and_name_and_type, |
|
key_file=valid_key_file, |
|
batch_size=args.valid_batch_size, |
|
dtype=args.train_dtype, |
|
num_workers=args.num_workers, |
|
allow_variable_data_keys=args.allow_variable_data_keys, |
|
ngpu=args.ngpu, |
|
preprocess_fn=cls.build_preprocess_fn(args, train=False), |
|
collate_fn=cls.build_collate_fn(args, train=False), |
|
), |
|
output_dir=output_dir, |
|
ngpu=args.ngpu, |
|
log_interval=args.log_interval, |
|
write_collected_feats=args.write_collected_feats, |
|
) |
|
else: |
|
|
|
|
|
if args.multiple_iterator: |
|
train_iter_factory = cls.build_multiple_iter_factory( |
|
args=args, |
|
distributed_option=distributed_option, |
|
mode="train", |
|
) |
|
else: |
|
train_iter_factory = cls.build_iter_factory( |
|
args=args, |
|
distributed_option=distributed_option, |
|
mode="train", |
|
) |
|
valid_iter_factory = cls.build_iter_factory( |
|
args=args, |
|
distributed_option=distributed_option, |
|
mode="valid", |
|
) |
|
if args.num_att_plot != 0: |
|
plot_attention_iter_factory = cls.build_iter_factory( |
|
args=args, |
|
distributed_option=distributed_option, |
|
mode="plot_att", |
|
) |
|
else: |
|
plot_attention_iter_factory = None |
|
|
|
|
|
if args.use_wandb: |
|
if ( |
|
not distributed_option.distributed |
|
or distributed_option.dist_rank == 0 |
|
): |
|
if args.wandb_project is None: |
|
project = ( |
|
"ESPnet_" |
|
+ cls.__name__ |
|
+ str(Path(".").resolve()).replace("/", "_") |
|
) |
|
else: |
|
project = args.wandb_project |
|
if args.wandb_id is None: |
|
wandb_id = str(output_dir).replace("/", "_") |
|
else: |
|
wandb_id = args.wandb_id |
|
|
|
wandb.init( |
|
project=project, |
|
dir=output_dir, |
|
id=wandb_id, |
|
resume="allow", |
|
) |
|
wandb.config.update(args) |
|
else: |
|
|
|
|
|
|
|
args.use_wandb = False |
|
|
|
|
|
|
|
trainer_options = cls.trainer.build_options(args) |
|
cls.trainer.run( |
|
model=model, |
|
optimizers=optimizers, |
|
schedulers=schedulers, |
|
train_iter_factory=train_iter_factory, |
|
valid_iter_factory=valid_iter_factory, |
|
plot_attention_iter_factory=plot_attention_iter_factory, |
|
trainer_options=trainer_options, |
|
distributed_option=distributed_option, |
|
) |
|
|
|
@classmethod |
|
def build_iter_options( |
|
cls, |
|
args: argparse.Namespace, |
|
distributed_option: DistributedOption, |
|
mode: str, |
|
): |
|
if mode == "train": |
|
preprocess_fn = cls.build_preprocess_fn(args, train=True) |
|
collate_fn = cls.build_collate_fn(args, train=True) |
|
data_path_and_name_and_type = args.train_data_path_and_name_and_type |
|
shape_files = args.train_shape_file |
|
batch_size = args.batch_size |
|
batch_bins = args.batch_bins |
|
batch_type = args.batch_type |
|
max_cache_size = args.max_cache_size |
|
max_cache_fd = args.max_cache_fd |
|
distributed = distributed_option.distributed |
|
num_batches = None |
|
num_iters_per_epoch = args.num_iters_per_epoch |
|
train = True |
|
|
|
elif mode == "valid": |
|
preprocess_fn = cls.build_preprocess_fn(args, train=False) |
|
collate_fn = cls.build_collate_fn(args, train=False) |
|
data_path_and_name_and_type = args.valid_data_path_and_name_and_type |
|
shape_files = args.valid_shape_file |
|
|
|
if args.valid_batch_type is None: |
|
batch_type = args.batch_type |
|
else: |
|
batch_type = args.valid_batch_type |
|
if args.valid_batch_size is None: |
|
batch_size = args.batch_size |
|
else: |
|
batch_size = args.valid_batch_size |
|
if args.valid_batch_bins is None: |
|
batch_bins = args.batch_bins |
|
else: |
|
batch_bins = args.valid_batch_bins |
|
if args.valid_max_cache_size is None: |
|
|
|
max_cache_size = 0.05 * args.max_cache_size |
|
else: |
|
max_cache_size = args.valid_max_cache_size |
|
max_cache_fd = args.max_cache_fd |
|
distributed = distributed_option.distributed |
|
num_batches = None |
|
num_iters_per_epoch = None |
|
train = False |
|
|
|
elif mode == "plot_att": |
|
preprocess_fn = cls.build_preprocess_fn(args, train=False) |
|
collate_fn = cls.build_collate_fn(args, train=False) |
|
data_path_and_name_and_type = args.valid_data_path_and_name_and_type |
|
shape_files = args.valid_shape_file |
|
batch_type = "unsorted" |
|
batch_size = 1 |
|
batch_bins = 0 |
|
num_batches = args.num_att_plot |
|
max_cache_fd = args.max_cache_fd |
|
|
|
max_cache_size = np.inf if args.max_cache_size != 0.0 else 0.0 |
|
|
|
distributed = False |
|
num_iters_per_epoch = None |
|
train = False |
|
else: |
|
raise NotImplementedError(f"mode={mode}") |
|
|
|
return IteratorOptions( |
|
preprocess_fn=preprocess_fn, |
|
collate_fn=collate_fn, |
|
data_path_and_name_and_type=data_path_and_name_and_type, |
|
shape_files=shape_files, |
|
batch_type=batch_type, |
|
batch_size=batch_size, |
|
batch_bins=batch_bins, |
|
num_batches=num_batches, |
|
max_cache_size=max_cache_size, |
|
max_cache_fd=max_cache_fd, |
|
distributed=distributed, |
|
num_iters_per_epoch=num_iters_per_epoch, |
|
train=train, |
|
) |
|
|
|
@classmethod |
|
def build_iter_factory( |
|
cls, |
|
args: argparse.Namespace, |
|
distributed_option: DistributedOption, |
|
mode: str, |
|
kwargs: dict = None, |
|
) -> AbsIterFactory: |
|
"""Build a factory object of mini-batch iterator. |
|
|
|
        This object is invoked at every epoch to build the iterator for that epoch
        as follows:
|
|
|
>>> iter_factory = cls.build_iter_factory(...) |
|
>>> for epoch in range(1, max_epoch): |
|
        ...     for keys, batch in iter_factory.build_iter(epoch):
|
... model(**batch) |
|
|
|
        The mini-batches for each epoch are fully controlled by this class.
        Note that the random seed used for shuffling is decided as "seed + epoch" and
        the generated mini-batches can be reproduced when resuming.
|
|
|
        Note that the definition of "epoch" doesn't always mean
        running through the whole training corpus once.
        The "--num_iters_per_epoch" option restricts the number of iterations for each
        epoch, and the remaining samples of the original epoch are carried over to the
        next epoch. e.g. if the number of mini-batches equals 4, the following two are
        the same:

        - 1 epoch without "--num_iters_per_epoch"
        - 4 epochs with "--num_iters_per_epoch" == 4
|
|
|
""" |
|
assert check_argument_types() |
|
iter_options = cls.build_iter_options(args, distributed_option, mode) |
|
|
|
|
|
if kwargs is not None: |
|
for k, v in kwargs.items(): |
|
setattr(iter_options, k, v) |
|
|
|
if args.iterator_type == "sequence": |
|
return cls.build_sequence_iter_factory( |
|
args=args, |
|
iter_options=iter_options, |
|
mode=mode, |
|
) |
|
elif args.iterator_type == "chunk": |
|
return cls.build_chunk_iter_factory( |
|
args=args, |
|
iter_options=iter_options, |
|
mode=mode, |
|
) |
|
elif args.iterator_type == "task": |
|
return cls.build_task_iter_factory( |
|
args=args, |
|
iter_options=iter_options, |
|
mode=mode, |
|
) |
|
else: |
|
raise RuntimeError(f"Not supported: iterator_type={args.iterator_type}") |
|
|
|
@classmethod |
|
def build_sequence_iter_factory( |
|
cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str |
|
) -> AbsIterFactory: |
|
assert check_argument_types() |
|
|
|
dataset = ESPnetDataset( |
|
iter_options.data_path_and_name_and_type, |
|
float_dtype=args.train_dtype, |
|
preprocess=iter_options.preprocess_fn, |
|
max_cache_size=iter_options.max_cache_size, |
|
max_cache_fd=iter_options.max_cache_fd, |
|
) |
|
cls.check_task_requirements( |
|
dataset, args.allow_variable_data_keys, train=iter_options.train |
|
) |
|
|
|
if Path( |
|
Path(iter_options.data_path_and_name_and_type[0][0]).parent, "utt2category" |
|
).exists(): |
|
utt2category_file = str( |
|
Path( |
|
Path(iter_options.data_path_and_name_and_type[0][0]).parent, |
|
"utt2category", |
|
) |
|
) |
|
else: |
|
utt2category_file = None |
|
batch_sampler = build_batch_sampler( |
|
type=iter_options.batch_type, |
|
shape_files=iter_options.shape_files, |
|
fold_lengths=args.fold_length, |
|
batch_size=iter_options.batch_size, |
|
batch_bins=iter_options.batch_bins, |
|
sort_in_batch=args.sort_in_batch, |
|
sort_batch=args.sort_batch, |
|
drop_last=False, |
|
min_batch_size=torch.distributed.get_world_size() |
|
if iter_options.distributed |
|
else 1, |
|
utt2category_file=utt2category_file, |
|
) |
|
|
|
batches = list(batch_sampler) |
|
if iter_options.num_batches is not None: |
|
batches = batches[: iter_options.num_batches] |
|
|
|
bs_list = [len(batch) for batch in batches] |
|
|
|
logging.info(f"[{mode}] dataset:\n{dataset}") |
|
logging.info(f"[{mode}] Batch sampler: {batch_sampler}") |
|
logging.info( |
|
f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, " |
|
f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}" |
|
) |
|
|
|
if iter_options.distributed: |
|
world_size = torch.distributed.get_world_size() |
|
rank = torch.distributed.get_rank() |
|
for batch in batches: |
|
if len(batch) < world_size: |
|
                    raise RuntimeError(
                        f"The batch size must be equal to or greater than world_size: "
|
f"{len(batch)} < {world_size}" |
|
) |
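
            # Split every mini-batch across ranks so that each worker
            # processes a disjoint subset of roughly len(batch) / world_size
            # samples per step.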
|
batches = [batch[rank::world_size] for batch in batches] |
|
|
|
return SequenceIterFactory( |
|
dataset=dataset, |
|
batches=batches, |
|
seed=args.seed, |
|
num_iters_per_epoch=iter_options.num_iters_per_epoch, |
|
shuffle=iter_options.train, |
|
num_workers=args.num_workers, |
|
collate_fn=iter_options.collate_fn, |
|
pin_memory=args.ngpu > 0, |
|
) |
|
|
|
@classmethod |
|
def build_chunk_iter_factory( |
|
cls, |
|
args: argparse.Namespace, |
|
iter_options: IteratorOptions, |
|
mode: str, |
|
) -> AbsIterFactory: |
|
assert check_argument_types() |
|
|
|
dataset = ESPnetDataset( |
|
iter_options.data_path_and_name_and_type, |
|
float_dtype=args.train_dtype, |
|
preprocess=iter_options.preprocess_fn, |
|
max_cache_size=iter_options.max_cache_size, |
|
max_cache_fd=iter_options.max_cache_fd, |
|
) |
|
cls.check_task_requirements( |
|
dataset, args.allow_variable_data_keys, train=iter_options.train |
|
) |
|
|
|
if len(iter_options.shape_files) == 0: |
|
key_file = iter_options.data_path_and_name_and_type[0][0] |
|
else: |
|
key_file = iter_options.shape_files[0] |
|
|
|
batch_sampler = UnsortedBatchSampler(batch_size=1, key_file=key_file) |
|
batches = list(batch_sampler) |
|
if iter_options.num_batches is not None: |
|
batches = batches[: iter_options.num_batches] |
|
logging.info(f"[{mode}] dataset:\n{dataset}") |
|
|
|
if iter_options.distributed: |
|
world_size = torch.distributed.get_world_size() |
|
rank = torch.distributed.get_rank() |
|
if len(batches) < world_size: |
|
raise RuntimeError("Number of samples is smaller than world_size") |
|
if iter_options.batch_size < world_size: |
|
                raise RuntimeError("batch_size must be equal to or greater than world_size")
|
|
|
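            # Split the batch size across ranks, giving the remainder of
            # batch_size / world_size to the lowest ranks so that the per-rank
            # sizes sum to the original batch_size.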
if rank < iter_options.batch_size % world_size: |
|
batch_size = iter_options.batch_size // world_size + 1 |
|
else: |
|
batch_size = iter_options.batch_size // world_size |
|
num_cache_chunks = args.num_cache_chunks // world_size |
|
|
|
|
|
|
|
|
|
|
|
batches = batches[rank::world_size] |
|
else: |
|
batch_size = iter_options.batch_size |
|
num_cache_chunks = args.num_cache_chunks |
|
|
|
return ChunkIterFactory( |
|
dataset=dataset, |
|
batches=batches, |
|
seed=args.seed, |
|
batch_size=batch_size, |
|
|
|
|
|
|
|
num_samples_per_epoch=iter_options.num_iters_per_epoch, |
|
shuffle=iter_options.train, |
|
num_workers=args.num_workers, |
|
collate_fn=iter_options.collate_fn, |
|
pin_memory=args.ngpu > 0, |
|
chunk_length=args.chunk_length, |
|
chunk_shift_ratio=args.chunk_shift_ratio, |
|
num_cache_chunks=num_cache_chunks, |
|
) |
|
|
|
|
|
@classmethod |
|
def build_task_iter_factory( |
|
cls, |
|
args: argparse.Namespace, |
|
iter_options: IteratorOptions, |
|
mode: str, |
|
) -> AbsIterFactory: |
|
"""Build task specific iterator factory |
|
|
|
Example: |
|
|
|
>>> class YourTask(AbsTask): |
|
... @classmethod |
|
... def add_task_arguments(cls, parser: argparse.ArgumentParser): |
|
... parser.set_defaults(iterator_type="task") |
|
... |
|
... @classmethod |
|
... def build_task_iter_factory( |
|
... cls, |
|
... args: argparse.Namespace, |
|
... iter_options: IteratorOptions, |
|
... mode: str, |
|
... ): |
|
... return FooIterFactory(...) |
|
... |
|
... @classmethod |
|
... def build_iter_options( |
|
        ...         cls,
        ...         args: argparse.Namespace,
|
... distributed_option: DistributedOption, |
|
... mode: str |
|
... ): |
|
... # if you need to customize options object |
|
""" |
|
raise NotImplementedError |
|
|
|
@classmethod |
|
def build_multiple_iter_factory( |
|
cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str |
|
): |
|
assert check_argument_types() |
|
iter_options = cls.build_iter_options(args, distributed_option, mode) |
|
assert len(iter_options.data_path_and_name_and_type) > 0, len( |
|
iter_options.data_path_and_name_and_type |
|
) |
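
        # Each data file and shape file is expected to have been split in
        # advance into the same number of pieces, named "split.0" ...
        # "split.{num_splits - 1}", with a "num_splits" file recording the count.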
|
|
|
|
|
num_splits = None |
|
for path in [ |
|
path for path, _, _ in iter_options.data_path_and_name_and_type |
|
] + list(iter_options.shape_files): |
|
if not Path(path).is_dir(): |
|
raise RuntimeError(f"{path} is not a directory") |
|
p = Path(path) / "num_splits" |
|
if not p.exists(): |
|
raise FileNotFoundError(f"{p} is not found") |
|
with p.open() as f: |
|
_num_splits = int(f.read()) |
|
if num_splits is not None and num_splits != _num_splits: |
|
                    raise RuntimeError(
                        f"The number of splits is mismatched: "
|
f"{iter_options.data_path_and_name_and_type[0][0]} and {path}" |
|
) |
|
num_splits = _num_splits |
|
|
|
for i in range(num_splits): |
|
p = Path(path) / f"split.{i}" |
|
if not p.exists(): |
|
raise FileNotFoundError(f"{p} is not found") |
|
|
|
|
|
data_path_and_name_and_type_list = [ |
|
[ |
|
(str(Path(p) / f"split.{i}"), n, t) |
|
for p, n, t in iter_options.data_path_and_name_and_type |
|
] |
|
for i in range(num_splits) |
|
] |
|
shape_files_list = [ |
|
[str(Path(s) / f"split.{i}") for s in iter_options.shape_files] |
|
for i in range(num_splits) |
|
] |
|
num_iters_per_epoch_list = [ |
|
(iter_options.num_iters_per_epoch + i) // num_splits |
|
if iter_options.num_iters_per_epoch is not None |
|
else None |
|
for i in range(num_splits) |
|
] |
|
max_cache_size = iter_options.max_cache_size / num_splits |
|
|
|
|
|
build_funcs = [ |
|
functools.partial( |
|
cls.build_iter_factory, |
|
args, |
|
distributed_option, |
|
mode, |
|
kwargs=dict( |
|
data_path_and_name_and_type=_data_path_and_name_and_type, |
|
shape_files=_shape_files, |
|
num_iters_per_epoch=_num_iters_per_epoch, |
|
max_cache_size=max_cache_size, |
|
), |
|
) |
|
for ( |
|
_data_path_and_name_and_type, |
|
_shape_files, |
|
_num_iters_per_epoch, |
|
) in zip( |
|
data_path_and_name_and_type_list, |
|
shape_files_list, |
|
num_iters_per_epoch_list, |
|
) |
|
] |
|
|
|
|
|
return MultipleIterFactory( |
|
build_funcs=build_funcs, shuffle=iter_options.train, seed=args.seed |
|
) |
|
|
|
@classmethod |
|
def build_streaming_iterator( |
|
cls, |
|
data_path_and_name_and_type, |
|
preprocess_fn, |
|
collate_fn, |
|
key_file: str = None, |
|
batch_size: int = 1, |
|
dtype: str = np.float32, |
|
num_workers: int = 1, |
|
allow_variable_data_keys: bool = False, |
|
ngpu: int = 0, |
|
inference: bool = False, |
|
) -> DataLoader: |
|
"""Build DataLoader using iterable dataset""" |
|
assert check_argument_types() |
|
|
|
if collate_fn is not None: |
|
kwargs = dict(collate_fn=collate_fn) |
|
else: |
|
kwargs = {} |
|
|
|
|
|
if LooseVersion(torch.__version__) >= LooseVersion("1.2"): |
|
dataset = IterableESPnetDataset( |
|
data_path_and_name_and_type, |
|
float_dtype=dtype, |
|
preprocess=preprocess_fn, |
|
key_file=key_file, |
|
) |
|
if dataset.apply_utt2category: |
|
kwargs.update(batch_size=1) |
|
else: |
|
kwargs.update(batch_size=batch_size) |
|
else: |
|
dataset = ESPnetDataset( |
|
data_path_and_name_and_type, |
|
float_dtype=dtype, |
|
preprocess=preprocess_fn, |
|
) |
|
if key_file is None: |
|
key_file = data_path_and_name_and_type[0][0] |
|
batch_sampler = UnsortedBatchSampler( |
|
batch_size=batch_size, |
|
key_file=key_file, |
|
drop_last=False, |
|
) |
|
kwargs.update(batch_sampler=batch_sampler) |
|
|
|
cls.check_task_requirements( |
|
dataset, allow_variable_data_keys, train=False, inference=inference |
|
) |
|
|
|
return DataLoader( |
|
dataset=dataset, |
|
pin_memory=ngpu > 0, |
|
num_workers=num_workers, |
|
**kwargs, |
|
) |
|
|
|
|
|
@classmethod |
|
def build_model_from_file( |
|
cls, |
|
config_file: Union[Path, str], |
|
model_file: Union[Path, str] = None, |
|
device: str = "cpu", |
|
) -> Tuple[AbsESPnetModel, argparse.Namespace]: |
|
"""This method is used for inference or fine-tuning. |
|
|
|
Args: |
|
config_file: The yaml file saved when training. |
|
model_file: The model file saved when training. |
|
device: |
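
        Example (with illustrative paths; use any concrete task subclass,
        here called SomeTask):

        >>> model, train_args = SomeTask.build_model_from_file(
        ...     "exp/train/config.yaml", "exp/train/valid.acc.best.pth", "cpu")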
|
|
|
""" |
|
assert check_argument_types() |
|
config_file = Path(config_file) |
|
|
|
with config_file.open("r", encoding="utf-8") as f: |
|
args = yaml.safe_load(f) |
|
args = argparse.Namespace(**args) |
|
model = cls.build_model(args) |
|
if not isinstance(model, AbsESPnetModel): |
|
raise RuntimeError( |
|
f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" |
|
) |
|
model.to(device) |
|
if model_file is not None: |
|
if device == "cuda": |
|
|
|
|
|
device = f"cuda:{torch.cuda.current_device()}" |
|
model.load_state_dict(torch.load(model_file, map_location=device), strict=False) |
|
|
|
return model, args |
|
|