NeMo / nemo /utils /arguments.py
camenduru's picture
thanks to NVIDIA ❤
7934b29
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
from typing import Any, Dict, List, Optional, Union
def add_optimizer_args(
parent_parser: ArgumentParser,
optimizer: str = 'adam',
default_lr: float = None,
default_opt_args: Optional[Union[Dict[str, Any], List[str]]] = None,
) -> ArgumentParser:
"""Extends existing argparse with support for optimizers.
# Example of adding optimizer args to command line :
python train_script.py ... --optimizer "novograd" --lr 0.01 \
--opt_args betas=0.95,0.5 weight_decay=0.001
Args:
parent_parser (ArgumentParser): Custom CLI parser that will be extended.
optimizer (str): Default optimizer required.
default_lr (float): Default learning rate that should be overriden during training.
default_opt_args (list(str)): List of overriding arguments for the instantiated optimizer.
Returns:
ArgumentParser: Parser extended by Optimizers arguments.
"""
if default_opt_args is None:
default_opt_args = []
parser = ArgumentParser(parents=[parent_parser], add_help=True, conflict_handler='resolve')
parser.add_argument('--optimizer', type=str, default=optimizer, help='Name of the optimizer. Defaults to Adam.')
parser.add_argument('--lr', type=float, default=default_lr, help='Learning rate of the optimizer.')
parser.add_argument(
'--opt_args',
default=default_opt_args,
nargs='+',
type=str,
help='Overriding arguments for the optimizer. \n Must follow the pattern : \n name=value separated by spaces.'
'Example: --opt_args weight_decay=0.001 eps=1e-8 betas=0.9,0.999',
)
return parser
def add_scheduler_args(parent_parser: ArgumentParser) -> ArgumentParser:
"""Extends existing argparse with default LR scheduler args.
Args:
parent_parser (ArgumentParser): Custom CLI parser that will be extended.
Returns:
ArgumentParser: Parser extended by LR Scheduler arguments.
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler='resolve')
parser.add_argument("--warmup_steps", type=int, required=False, default=None, help="Number of warmup steps")
parser.add_argument(
"--warmup_ratio",
type=float,
required=False,
default=None,
help="Number of warmup steps as a percentage of total training steps",
)
parser.add_argument("--hold_steps", type=int, required=False, default=None, help="Number of hold LR steps")
parser.add_argument(
"--hold_ratio",
type=float,
required=False,
default=None,
help="Number of hold LR steps as a percentage of total training steps",
)
parser.add_argument("--min_lr", type=float, required=False, default=0.0, help="Minimum learning rate")
parser.add_argument(
"--last_epoch", type=int, required=False, default=-1, help="Last epoch id. -1 indicates training from scratch"
)
return parser
def add_asr_args(parent_parser: ArgumentParser) -> ArgumentParser:
"""Extends existing argparse with default ASR collection args.
Args:
parent_parser (ArgumentParser): Custom CLI parser that will be extended.
Returns:
ArgumentParser: Parser extended by NeMo ASR Collection arguments.
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler='resolve')
parser.add_argument("--asr_model", type=str, required=True, default="bad_quartznet15x5.yaml", help="")
parser.add_argument("--train_dataset", type=str, required=True, default=None, help="training dataset path")
parser.add_argument("--eval_dataset", type=str, required=True, help="evaluation dataset path")
return parser
def add_nlp_args(parent_parser: ArgumentParser) -> ArgumentParser:
"""Extends existing argparse with default NLP collection args.
Args:
parent_parser (ArgumentParser): Custom CLI parser that will be extended.
Returns:
ArgumentParser: Parser extended by NeMo NLP Collection arguments.
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False, conflict_handler='resolve')
parser.add_argument(
"--data_dir", type=str, required=False, help="data directory to training or/and evaluation dataset"
)
parser.add_argument(
"--config_file", type=str, required=False, default=None, help="Huggingface model configuration file"
)
parser.add_argument(
"--pretrained_model_name", default='bert-base-uncased', type=str, required=False, help="pretrained model name"
)
parser.add_argument(
"--tokenizer_name", default='nemobert', type=str, choices=['sentencepiece', 'nemobert'], help="Tokenizer type"
)
parser.add_argument("--tokenizer_model", default=None, type=str, help="Tokenizer file for sentence piece")
parser.add_argument("--do_lower_case", action='store_true', required=False, help="lower case data")
return parser