Source code for transformers.training_args

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import json
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
from .trainer_utils import EvaluationStrategy
from .utils import logging


if is_torch_available():
    import torch

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm


logger = logging.get_logger(__name__)


def default_logdir() -> str:
    """
    Same default as PyTorch
    """
    import socket
    from datetime import datetime

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    return os.path.join("runs", current_time + "_" + socket.gethostname())


@dataclass
class TrainingArguments:
    """
    TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training
    loop itself**.

    Using :class:`~transformers.HfArgumentParser` we can turn this class into `argparse
    <https://docs.python.org/3/library/argparse.html#module-argparse>`__ arguments that can be specified on the
    command line.

    Parameters:
        output_dir (:obj:`str`):
            The output directory where the model predictions and checkpoints will be written.
        overwrite_output_dir (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If :obj:`True`, overwrite the content of the output directory. Use this to continue training if
            :obj:`output_dir` points to a checkpoint directory.
        do_train (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to run training or not. This argument is not directly used by :class:`~transformers.Trainer`,
            it's intended to be used by your training/evaluation scripts instead. See the `example scripts
            <https://github.com/huggingface/transformers/tree/master/examples>`__ for more details.
        do_eval (:obj:`bool`, `optional`):
            Whether to run evaluation on the validation set or not. Will be set to :obj:`True` if
            :obj:`evaluation_strategy` is different from :obj:`"no"`. This argument is not directly used by
            :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
            the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
            details.
        do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to run predictions on the test set or not. This argument is not directly used by
            :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
            the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
            details.
        evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
            The evaluation strategy to adopt during training. Possible values are:

            * :obj:`"no"`: No evaluation is done during training.
            * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
            * :obj:`"epoch"`: Evaluation is done at the end of each epoch.

        prediction_loss_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
            When performing evaluation and generating predictions, only returns the loss.
        per_device_train_batch_size (:obj:`int`, `optional`, defaults to 8):
            The batch size per GPU/TPU core/CPU for training.
        per_device_eval_batch_size (:obj:`int`, `optional`, defaults to 8):
            The batch size per GPU/TPU core/CPU for evaluation.
        gradient_accumulation_steps (:obj:`int`, `optional`, defaults to 1):
            Number of update steps to accumulate the gradients for, before performing a backward/update pass.

            .. warning::

                When using gradient accumulation, one step is counted as one step with backward pass. Therefore,
                logging, evaluation, save will be conducted every ``gradient_accumulation_steps * xxx_step`` training
                examples.
        eval_accumulation_steps (:obj:`int`, `optional`):
            Number of prediction steps to accumulate the output tensors for, before moving the results to the CPU. If
            left unset, the whole predictions are accumulated on GPU/TPU before being moved to the CPU (faster but
            requires more memory).
        learning_rate (:obj:`float`, `optional`, defaults to 5e-5):
            The initial learning rate for Adam.
        weight_decay (:obj:`float`, `optional`, defaults to 0):
            The weight decay to apply (if not zero).
        adam_beta1 (:obj:`float`, `optional`, defaults to 0.9):
            The beta1 hyperparameter for the Adam optimizer.
        adam_beta2 (:obj:`float`, `optional`, defaults to 0.999):
            The beta2 hyperparameter for the Adam optimizer.
        adam_epsilon (:obj:`float`, `optional`, defaults to 1e-8):
            The epsilon hyperparameter for the Adam optimizer.
        max_grad_norm (:obj:`float`, `optional`, defaults to 1.0):
            Maximum gradient norm (for gradient clipping).
        num_train_epochs (:obj:`float`, `optional`, defaults to 3.0):
            Total number of training epochs to perform (if not an integer, will perform the decimal part percents of
            the last epoch before stopping training).
        max_steps (:obj:`int`, `optional`, defaults to -1):
            If set to a positive number, the total number of training steps to perform. Overrides
            :obj:`num_train_epochs`.
        warmup_steps (:obj:`int`, `optional`, defaults to 0):
            Number of steps used for a linear warmup from 0 to :obj:`learning_rate`.
        logging_dir (:obj:`str`, `optional`):
            `TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
            `runs/**CURRENT_DATETIME_HOSTNAME**`.
        logging_first_step (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to log and evaluate the first :obj:`global_step` or not.
        logging_steps (:obj:`int`, `optional`, defaults to 500):
            Number of update steps between two logs.
        save_steps (:obj:`int`, `optional`, defaults to 500):
            Number of update steps between two checkpoint saves.
        save_total_limit (:obj:`int`, `optional`):
            If a value is passed, will limit the total number of checkpoints. Deletes the older checkpoints in
            :obj:`output_dir`.
        no_cuda (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to avoid using CUDA even when it is available.
        seed (:obj:`int`, `optional`, defaults to 42):
            Random seed for initialization.
        fp16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training.
        fp16_opt_level (:obj:`str`, `optional`, defaults to 'O1'):
            For :obj:`fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See
            details on the `Apex documentation <https://nvidia.github.io/apex/amp.html>`__.
        local_rank (:obj:`int`, `optional`, defaults to -1):
            Rank of the process during distributed training.
        tpu_num_cores (:obj:`int`, `optional`):
            When training on TPU, the number of TPU cores (automatically passed by launcher script).
        debug (:obj:`bool`, `optional`, defaults to :obj:`False`):
            When training on TPU, whether to print debug metrics or not.
        dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch
            size) or not.
        eval_steps (:obj:`int`, `optional`):
            Number of update steps between two evaluations if :obj:`evaluation_strategy="steps"`. Will default to the
            same value as :obj:`logging_steps` if not set.
        dataloader_num_workers (:obj:`int`, `optional`, defaults to 0):
            Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in
            the main process.
        past_index (:obj:`int`, `optional`, defaults to -1):
            Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>`
            can make use of the past hidden states for their predictions. If this argument is set to a positive int,
            the ``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the
            model at the next training step under the keyword argument ``mems``.
        run_name (:obj:`str`, `optional`):
            A descriptor for the run. Typically used for `wandb <https://www.wandb.com/>`_ logging.
        disable_tqdm (:obj:`bool`, `optional`):
            Whether or not to disable the tqdm progress bars and table of metrics produced by
            :class:`~transformers.notebook.NotebookTrainingTracker` in Jupyter Notebooks. Will default to :obj:`True`
            if the logging level is set to warn or lower (default), :obj:`False` otherwise.
        remove_unused_columns (:obj:`bool`, `optional`, defaults to :obj:`True`):
            If using :obj:`datasets.Dataset` datasets, whether or not to automatically remove the columns unused by
            the model forward method.

            (Note that this behavior is not implemented for :class:`~transformers.TFTrainer` yet.)
        label_names (:obj:`List[str]`, `optional`):
            The list of keys in your dictionary of inputs that correspond to the labels.

            Will eventually default to :obj:`["labels"]` except if the model used is one of the
            :obj:`XxxForQuestionAnswering` in which case it will default to :obj:`["start_positions",
            "end_positions"]`.
        load_best_model_at_end (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to load the best model found during training at the end of training.

            .. note::

                When set to :obj:`True`, the parameter :obj:`save_steps` will be ignored and the model will be saved
                after each evaluation.
        metric_for_best_model (:obj:`str`, `optional`):
            Use in conjunction with :obj:`load_best_model_at_end` to specify the metric to use to compare two
            different models. Must be the name of a metric returned by the evaluation with or without the prefix
            :obj:`"eval_"`. Will default to :obj:`"loss"` if unspecified and :obj:`load_best_model_at_end=True` (to
            use the evaluation loss).

            If you set this value, :obj:`greater_is_better` will default to :obj:`True`. Don't forget to set it to
            :obj:`False` if your metric is better when lower.
        greater_is_better (:obj:`bool`, `optional`):
            Use in conjunction with :obj:`load_best_model_at_end` and :obj:`metric_for_best_model` to specify if
            better models should have a greater metric or not. Will default to:

            - :obj:`True` if :obj:`metric_for_best_model` is set to a value that isn't :obj:`"loss"` or
              :obj:`"eval_loss"`.
            - :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
        model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If there is more than one device, whether to use model parallelism to distribute the model's modules
            across devices or not.
        ignore_data_skip (:obj:`bool`, `optional`, defaults to :obj:`False`):
            When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
            stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that
            skipping step can take a long time) but will not yield the same results as the interrupted training would
            have.
        fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
            The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
            :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
            other choices will force the requested backend.
        sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in
            distributed training only). This is an experimental feature.
""" output_dir: str = field( metadata={"help": "The output directory where the model predictions and checkpoints will be written."} ) overwrite_output_dir: bool = field( default=False, metadata={ "help": ( "Overwrite the content of the output directory." "Use this to continue training if output_dir points to a checkpoint directory." ) }, ) do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) model_parallel: bool = field( default=False, metadata={ "help": ( "If there are more than one devices, whether to use model parallelism to distribute the " "model's modules across devices." ) }, ) evaluation_strategy: EvaluationStrategy = field( default="no", metadata={"help": "Run evaluation during training at each logging step."}, ) prediction_loss_only: bool = field( default=False, metadata={"help": "When performing evaluation and predictions, only returns the loss."}, ) per_device_train_batch_size: int = field( default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."} ) per_device_eval_batch_size: int = field( default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."} ) per_gpu_train_batch_size: Optional[int] = field( default=None, metadata={ "help": "Deprecated, the use of `--per_device_train_batch_size` is preferred. " "Batch size per GPU/TPU core/CPU for training." }, ) per_gpu_eval_batch_size: Optional[int] = field( default=None, metadata={ "help": "Deprecated, the use of `--per_device_eval_batch_size` is preferred." "Batch size per GPU/TPU core/CPU for evaluation." }, ) gradient_accumulation_steps: int = field( default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, ) eval_accumulation_steps: Optional[int] = field( default=None, metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."}, ) learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."}) weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."}) adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for Adam optimizer"}) adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for Adam optimizer"}) adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for Adam optimizer."}) max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."}) num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."}) max_steps: int = field( default=-1, metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, ) warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."}) logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"}) logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) save_total_limit: Optional[int] = field( default=None, metadata={ "help": ( "Limit the total amount of checkpoints." "Deletes the older checkpoints in the output_dir. 
Default is unlimited checkpoints" ) }, ) no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"}) seed: int = field(default=42, metadata={"help": "random seed for initialization"}) fp16: bool = field( default=False, metadata={"help": "Whether to use 16-bit (mixed) precision (through NVIDIA Apex) instead of 32-bit"}, ) fp16_opt_level: str = field( default="O1", metadata={ "help": ( "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." "See details at https://nvidia.github.io/apex/amp.html" ) }, ) local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"}) tpu_num_cores: Optional[int] = field( default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"} ) tpu_metrics_debug: bool = field( default=False, metadata={"help": "Deprecated, the use of `--debug` is preferred. TPU: Whether to print debug metrics"}, ) debug: bool = field(default=False, metadata={"help": "Whether to print debug metrics on TPU"}) dataloader_drop_last: bool = field( default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} ) eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."}) dataloader_num_workers: int = field( default=0, metadata={ "help": "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process." }, ) past_index: int = field( default=-1, metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."}, ) run_name: Optional[str] = field( default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."} ) disable_tqdm: Optional[bool] = field( default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."} ) remove_unused_columns: Optional[bool] = field( default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."} ) label_names: Optional[List[str]] = field( default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."} ) load_best_model_at_end: Optional[bool] = field( default=False, metadata={"help": "Whether or not to load the best model found during training at the end of training."}, ) metric_for_best_model: Optional[str] = field( default=None, metadata={"help": "The metric to use to compare two different models."} ) greater_is_better: Optional[bool] = field( default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."} ) ignore_data_skip: bool = field( default=False, metadata={ "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data." 
}, ) fp16_backend: str = field( default="auto", metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]}, ) sharded_ddp: bool = field( default=False, metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."}, ) def __post_init__(self): if self.disable_tqdm is None: self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy) if self.do_eval is False and self.evaluation_strategy != EvaluationStrategy.NO: self.do_eval = True if self.eval_steps is None: self.eval_steps = self.logging_steps if self.load_best_model_at_end and self.metric_for_best_model is None: self.metric_for_best_model = "loss" if self.greater_is_better is None and self.metric_for_best_model is not None: self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"] if self.run_name is None: self.run_name = self.output_dir if is_torch_available() and self.device.type != "cuda" and self.fp16: raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.") @property def train_batch_size(self) -> int: """ The actual batch size for training (may differ from :obj:`per_gpu_train_batch_size` in distributed training). """ if self.per_gpu_train_batch_size: logger.warning( "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future " "version. Using `--per_device_train_batch_size` is preferred." ) per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size if not self.model_parallel: train_batch_size = per_device_batch_size * max(1, self.n_gpu) else: train_batch_size = per_device_batch_size return train_batch_size @property def eval_batch_size(self) -> int: """ The actual batch size for evaluation (may differ from :obj:`per_gpu_eval_batch_size` in distributed training). """ if self.per_gpu_eval_batch_size: logger.warning( "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future " "version. Using `--per_device_eval_batch_size` is preferred." ) per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size if not self.model_parallel: eval_batch_size = per_device_batch_size * max(1, self.n_gpu) else: eval_batch_size = per_device_batch_size return eval_batch_size @cached_property @torch_required def _setup_devices(self) -> Tuple["torch.device", int]: logger.info("PyTorch: setting up devices") if self.no_cuda: device = torch.device("cpu") n_gpu = 0 elif is_torch_tpu_available(): device = xm.xla_device() n_gpu = 0 elif self.local_rank == -1: # if n_gpu is > 1 we'll use nn.DataParallel. # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will # trigger an error that a device index is missing. Index 0 takes into account the # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` # will use the first GPU in that env, i.e. GPU#1 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") n_gpu = torch.cuda.device_count() else: # Here, we'll use torch.distributed. 
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs torch.distributed.init_process_group(backend="nccl") device = torch.device("cuda", self.local_rank) n_gpu = 1 if device.type == "cuda": torch.cuda.set_device(device) return device, n_gpu @property @torch_required def device(self) -> "torch.device": """ The device used by this process. """ return self._setup_devices[0] @property @torch_required def n_gpu(self): """ The number of GPUs used by this process. Note: This will only be greater than one when you have multiple GPUs available but are not using distributed training. For distributed training, it will always be 1. """ return self._setup_devices[1] @property @torch_required def parallel_mode(self): """ The current mode used for parallelism if multiple GPUs/TPU cores are available. One of: - :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU). - :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`). - :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each ahving its own process (uses :obj:`torch.nn.DistributedDataParallel`). - :obj:`ParallelMode.TPU`: several TPU cores. """ if is_torch_tpu_available(): return ParallelMode.TPU elif self.local_rank != -1: return ParallelMode.DISTRIBUTED elif self.n_gpu > 1: return ParallelMode.NOT_DISTRIBUTED else: return ParallelMode.NOT_PARALLEL
    def to_dict(self):
        """
        Serializes this instance while replacing `Enum` by their values (for JSON serialization support).
        """
        d = dataclasses.asdict(self)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
        return d
    def to_json_string(self):
        """
        Serializes this instance to a JSON string.
        """
        return json.dumps(self.to_dict(), indent=2)
    def to_sanitized_dict(self) -> Dict[str, Any]:
        """
        Sanitized serialization to use with TensorBoard's hparams.
        """
        d = self.to_dict()
        d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}}

        valid_types = [bool, int, float, str]
        if is_torch_available():
            valid_types.append(torch.Tensor)

        return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
class ParallelMode(Enum):
    NOT_PARALLEL = "not_parallel"
    NOT_DISTRIBUTED = "not_distributed"
    DISTRIBUTED = "distributed"
    TPU = "tpu"
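
# ----------------------------------------------------------------------------------------------
# Illustrative usage sketch (kept in comments so it is clearly not part of the module itself).
# It assumes the `transformers` package is installed; "./my_model" is a made-up example path.
# TrainingArguments can be built directly, or parsed from the command line with HfArgumentParser
# as described in the class docstring above:
#
#     from transformers import HfArgumentParser, TrainingArguments
#
#     args = TrainingArguments(
#         output_dir="./my_model",
#         per_device_train_batch_size=8,
#         num_train_epochs=3.0,
#         evaluation_strategy="steps",   # converted to an EvaluationStrategy in __post_init__
#         eval_steps=500,
#     )
#     print(args.to_json_string())       # JSON view of the configuration, as defined above
#
#     # In a script, the same fields become command-line flags:
#     parser = HfArgumentParser(TrainingArguments)
#     (args,) = parser.parse_args_into_dataclasses()
#
# Note that args.train_batch_size is per_device_train_batch_size * max(1, n_gpu) when
# model_parallel is False (see the train_batch_size property above).
# ----------------------------------------------------------------------------------------------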