#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a new model on one or across multiple GPUs.
"""

import argparse
import logging
# `logging.config` is a submodule and is not loaded by `import logging`;
# main() calls logging.config.dictConfig, so import it explicitly.
import logging.config
import math
import os
import sys
from typing import Dict, Optional, Any, List, Tuple, Callable

# We need to setup root logger before importing any fairseq libraries.
logging.basicConfig(
    format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s',
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.train")

import numpy as np
import torch
from fairseq import (
    # checkpoint_utils,
    options,
    quantization_utils,
    tasks,
    utils,
)
from fairseq.data import iterators
from fairseq.data.plasma_utils import PlasmaStore
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
# from fairseq.trainer import Trainer
from omegaconf import DictConfig, OmegaConf

# `checkpoint_utils` and `Trainer` come from the top-level `utils` and
# `trainer` modules rather than from fairseq (see the commented-out
# fairseq imports above).
from utils import checkpoint_utils
from trainer import Trainer


def main(cfg: FairseqConfig) -> None:
    """Build task, model, criterion and trainer from ``cfg`` and run training.

    Args:
        cfg: Training configuration. An ``argparse.Namespace`` is also
            accepted and is converted with ``convert_namespace_to_omegaconf``.

    The function returns early (after logging the error) when asynchronous
    checkpoint writing is requested but ``iopath`` cannot be imported.
    """
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp
        # (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    if cfg.common.log_file is not None:
        handler = logging.FileHandler(filename=cfg.common.log_file)
        logger.addHandler(handler)

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            # Report through the module logger (not the root `logging`
            # module) so the message also reaches the --log-file handler
            # attached to `logger` above.
            logger.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`"
            )
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        with fsdp_enable_wrap(cfg.distributed_training):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)
    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    # Parameters are split by their `expert` attribute: those without it (or
    # with a falsy value) are counted as shared, the rest as expert params.
    logger.info(
        "num. shared model params: {:,} (num. trained: {:,})".format(
            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad)
        )
    )

    logger.info(
        "num. expert model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad),
        )
    )

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    # data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
    if cfg.dataset.combine_valid_subsets:
        task.load_dataset("valid", combine=True, epoch=1)
    else:
        for valid_sub_split in cfg.dataset.valid_subset.split(","):
            task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.info(
        "training on {} devices (GPUs/TPUs)".format(
            cfg.distributed_training.distributed_world_size
        )
    )
    logger.info(
        "max tokens per device = {} and max sentences per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )
    if cfg.common.tpu:
        import torch_xla.core.xla_model as xm
        xm.rendezvous("load_checkpoint")  # wait for all workers

    # A falsy max_epoch (None / 0) means "no epoch limit".
    max_epoch = cfg.optimization.max_epoch or math.inf
    # Only re-initialise the LR schedule when the epoch count is finite:
    # with no epoch limit the product below would be inf (or nan), which is
    # not a usable total number of updates.
    if max_epoch > 0 and not math.isinf(max_epoch):
        # ceil(len(epoch_itr) / world_size)
        num_iter_per_epoch = (len(epoch_itr) + cfg.distributed_training.distributed_world_size - 1) \
            // cfg.distributed_training.distributed_world_size
        trainer.lr_reinit(num_iter_per_epoch * max_epoch, trainer.get_num_updates())
    lr = trainer.get_lr()

    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
            )
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=True,
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish."
        )
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")


def should_stop_early(cfg: DictConfig, valid_loss: Optional[float]) -> bool:
    """Decide whether training should stop because validation stopped improving.

    The best value seen so far and the number of consecutive non-improving
    calls are kept as attributes on this function object (``best`` and
    ``num_runs``), so the state persists across calls within the process.

    Args:
        cfg: Config providing ``checkpoint.patience`` and
            ``checkpoint.maximize_best_checkpoint_metric``.
        valid_loss: Latest validation value, or ``None`` if no validation
            was run.

    Returns:
        ``True`` once ``checkpoint.patience`` consecutive calls have failed
        to improve on the best value; ``False`` otherwise, including when
        ``valid_loss`` is ``None`` or patience is not positive.
    """
    # skip check if no validation was done in the current epoch
    if valid_loss is None:
        return False
    if cfg.checkpoint.patience <= 0:
        return False

    def is_better(a, b):
        return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b

    prev_best = getattr(should_stop_early, "best", None)
    if prev_best is None or is_better(valid_loss, prev_best):
        # New best (or very first value): remember it and reset the counter.
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    else:
        should_stop_early.num_runs += 1
        if should_stop_early.num_runs >= cfg.checkpoint.patience:
            logger.info(
                "early stop since valid performance hasn't improved for last {} runs".format(
                    cfg.checkpoint.patience
                )
            )
            return True
        else:
            return False


@metrics.aggregate("train")
def train(
    cfg: DictConfig, trainer: Trainer, task: tasks.FairseqTask, epoch_itr
) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses.

    Args:
        cfg: Training configuration.
        trainer: Trainer whose ``train_step`` is called on each group of
            samples.
        task: Task, forwarded to ``validate_and_save``.
        epoch_itr: Epoch iterator positioned at the epoch to train.

    Returns:
        ``(valid_losses, should_stop)`` as produced by the last call to
        ``validate_and_save``; ``([None], False)`` if the epoch yielded no
        batches at all.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    # update_freq is a per-epoch list; epochs past its end reuse the last entry.
    update_freq = (
        cfg.optimization.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(cfg.optimization.update_freq)
        else cfg.optimization.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_file=cfg.common.log_file,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            cfg.common.tensorboard_logdir
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
        azureml_logging=(
            cfg.common.azureml_logging
            if distributed_utils.is_master(cfg.distributed_training)
            else False
        ),
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = cfg.dataset.valid_subset.split(",")
    # Defaults for an epoch that yields no batches; without them the final
    # `return` would raise UnboundLocalError. `[None]` is the same
    # "no validation ran" value that validate_and_save uses.
    valid_losses = [None]
    should_stop = False
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop


def _flatten_config(cfg: DictConfig):
    """Convert ``cfg`` to a plain container, folding a legacy Namespace into "args".

    Top-level values that are ``argparse.Namespace`` instances are removed;
    if any were found, the last one encountered is stored as a dict under
    the ``"args"`` key.
    """
    config = OmegaConf.to_container(cfg)
    # remove any legacy Namespaces and replace with a single "args"
    namespace = None
    for k, v in list(config.items()):
        if isinstance(v, argparse.Namespace):
            namespace = v
            del config[k]
    if namespace is not None:
        config["args"] = vars(namespace)
    return config


def validate_and_save(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    valid_subsets: List[str],
    end_of_epoch: bool,
) -> Tuple[List[Optional[float]], bool]:
    """Run validation and/or write a checkpoint if the current step calls for it.

    Args:
        cfg: Training configuration (optimization, checkpoint and dataset
            intervals are read from it).
        trainer: Trainer providing the update count and cumulative time.
        task: Task, forwarded to ``validate``.
        epoch_itr: Current epoch iterator.
        valid_subsets: Names of the subsets to validate on.
        end_of_epoch: Whether the epoch's last batch has just been processed.

    Returns:
        ``(valid_losses, should_stop)``. ``valid_losses`` is ``[None]`` when
        no validation ran. ``should_stop`` is true when ``max_update`` or
        ``stop_time_hours`` was reached, or early stopping triggered.
    """
    num_updates = trainer.get_num_updates()
    max_update = cfg.optimization.max_update or math.inf

    # Stopping conditions (and an additional one based on validation loss later
    # on)
    should_stop = False
    if num_updates >= max_update:
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"num_updates: {num_updates} >= max_update: {max_update}"
        )

    training_time_hours = trainer.cumulative_training_time() / (60 * 60)
    if (
        cfg.optimization.stop_time_hours > 0
        and training_time_hours > cfg.optimization.stop_time_hours
    ):
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"cumulative_training_time: {training_time_hours} > "
            f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
        )

    do_save = (
        (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
        or should_stop
        or (
            cfg.checkpoint.save_interval_updates > 0
            and num_updates > 0
            and num_updates % cfg.checkpoint.save_interval_updates == 0
            and num_updates >= cfg.dataset.validate_after_updates
        )
    )
    # The trailing `and` clauses apply to the whole parenthesised group:
    # validation is skipped entirely when disabled or before
    # validate_after_updates, even if training is about to stop.
    do_validate = (
        (not end_of_epoch and do_save)  # validate during mid-epoch saves
        or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
        or should_stop
        or (
            cfg.dataset.validate_interval_updates > 0
            and num_updates > 0
            and num_updates % cfg.dataset.validate_interval_updates == 0
        )
    ) and not cfg.dataset.disable_validation and num_updates >= cfg.dataset.validate_after_updates

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)

    should_stop |= should_stop_early(cfg, valid_losses[0])

    # Save checkpoint
    if do_save or should_stop:
        checkpoint_utils.save_checkpoint(
            cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
        )

    return valid_losses, should_stop


def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
    """Add a ``"wall"`` entry (rounded elapsed time of the default wall meter).

    ``stats`` is modified in place and also returned.
    """
    stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
    return stats


def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch_itr,
    subsets: List[str],
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses.

    Args:
        cfg: Training configuration.
        trainer: Trainer whose ``valid_step`` is run on every batch.
        task: Task; its optional ``get_valid_stats`` and ``post_validate``
            hooks are used when present.
        epoch_itr: Training epoch iterator (only its ``epoch`` is read).
        subsets: Names of the subsets to evaluate, in order.

    Returns:
        One entry per subset: the value of
        ``cfg.checkpoint.best_checkpoint_metric`` in that subset's stats.
    """
    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False, set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            itr = utils.tpu_data_loader(itr)
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(
                cfg.common.tensorboard_logdir
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                # NOTE(review): `i` starts at 0 and the loop only breaks once
                # `i > max_valid_steps`, so max_valid_steps + 1 batches are
                # evaluated - confirm whether `>=` was intended.
                if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps:
                    break
                trainer.valid_step(sample)

        # log validation stats
        if hasattr(task, 'get_valid_stats'):
            stats = task.get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        else:
            stats = agg.get_smoothed_values()
        stats = get_valid_stats(cfg, trainer, stats)

        if hasattr(task, "post_validate"):
            task.post_validate(trainer.get_model(), stats, agg)

        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses


def get_valid_stats(
    cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]
) -> Dict[str, Any]:
    """Add ``num_updates`` and, when available, the running best metric to ``stats``.

    The best value so far is read from the ``best`` attribute of
    ``checkpoint_utils.save_checkpoint``; if that attribute does not exist,
    no ``best_<metric>`` entry is added. ``stats`` is modified in place and
    also returned.
    """
    stats["num_updates"] = trainer.get_num_updates()
    if hasattr(checkpoint_utils.save_checkpoint, "best"):
        key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
        best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
        stats[key] = best_function(
            checkpoint_utils.save_checkpoint.best,
            stats[cfg.checkpoint.best_checkpoint_metric],
        )
    return stats


def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    """Parse command-line arguments and launch ``main`` via ``distributed_utils.call_main``.

    Args:
        modify_parser: Optional callback forwarded to
            ``options.parse_args_and_arch``.
    """
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)

    cfg = convert_namespace_to_omegaconf(args)

    if cfg.common.use_plasma_view:
        server = PlasmaStore(path=cfg.common.plasma_path)
        logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}")

    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(cfg, main)
    else:
        distributed_utils.call_main(cfg, main)

    # if cfg.common.use_plasma_view:
    #     server.server.kill()


if __name__ == "__main__":
    cli_main()