import argparse
import os
import random
import sys

from loguru import logger
import numpy as np
import torch

from virtex.config import Config
import virtex.utils.distributed as dist


def cycle(dataloader, device, start_iteration: int = 0):
    r"""
    A generator to yield batches of data from dataloader infinitely.

    Internally, it sets the ``epoch`` for dataloader sampler to shuffle the
    examples. One may optionally provide the starting iteration to make sure
    the shuffling seed is different and continues naturally.

    Parameters
    ----------
    dataloader: torch.utils.data.DataLoader
        Dataloader whose batches are mappings (indexable by key) of values
        that support ``.to(device)``.
    device: torch.device
        Target passed to ``.to()`` for every value in each batch.
    start_iteration: int, optional (default = 0)
        Value the iteration counter starts from. When the sampler is a
        ``DistributedSampler``, the counter is used as its shuffle seed.

    Yields
    ------
    The next batch from ``dataloader``. Its values are replaced in place by
    their ``.to(device)`` copies.
    """
    iteration = start_iteration

    while True:
        if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
            # Set the `epoch` of DistributedSampler as current iteration. This
            # is a way of deterministic shuffling after every epoch, so it is
            # just a seed and need not necessarily be the "epoch".
            logger.info(f"Beginning new epoch, setting shuffle seed {iteration}")
            dataloader.sampler.set_epoch(iteration)

        for batch in dataloader:
            for key in batch:
                batch[key] = batch[key].to(device)
            yield batch
            # Count per batch, so the next epoch's seed differs from this one's.
            iteration += 1


def common_setup(_C: Config, _A: argparse.Namespace, job_type: str = "pretrain"):
    r"""
    Setup common stuff at the start of every pretraining or downstream
    evaluation job, all listed here to avoid code duplication. Basic steps:

    1. Fix random seeds and other PyTorch flags.
    2. Set up a serialization directory and loggers.
    3. Log important stuff such as config, process info (useful during
       distributed training).
    4. Save a copy of config to serialization directory.

    .. note::

        It is assumed that multiple processes for distributed training have
        already been launched from outside. Functions from
        :mod:`virtex.utils.distributed` module are used to get process info.

    Parameters
    ----------
    _C: virtex.config.Config
        Config object with all the parameters. Reads ``RANDOM_SEED``,
        ``CUDNN_DETERMINISTIC`` and ``CUDNN_BENCHMARK``.
    _A: argparse.Namespace
        Command line arguments. ``serialization_dir`` must be present.
    job_type: str, optional (default = "pretrain")
        Type of job for which setup is to be done; one of
        ``{"pretrain", "downstream"}``. Used as the prefix of the saved config
        file name (``{job_type}_config.yaml``).
    """

    # Get process rank and world size (assuming distributed is initialized).
    RANK = dist.get_rank()
    WORLD_SIZE = dist.get_world_size()

    # For reproducibility - refer https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(_C.RANDOM_SEED)
    torch.backends.cudnn.deterministic = _C.CUDNN_DETERMINISTIC
    torch.backends.cudnn.benchmark = _C.CUDNN_BENCHMARK
    random.seed(_C.RANDOM_SEED)
    np.random.seed(_C.RANDOM_SEED)

    # Create serialization directory and save config in it.
    os.makedirs(_A.serialization_dir, exist_ok=True)
    _C.dump(os.path.join(_A.serialization_dir, f"{job_type}_config.yaml"))

    # Remove default logger, create a logger for each process which writes to a
    # separate log-file. This makes changes in global scope.
    # Handler id 0 is loguru's default stderr sink.
    logger.remove(0)
    if WORLD_SIZE > 1:
        logger.add(
            os.path.join(_A.serialization_dir, f"log-rank{RANK}.txt"),
            format="{time} {level} {message}",
        )

    # Add a logger for stdout only for the master process.
    if dist.is_master_process():
        logger.add(sys.stdout, format="{time}: {message}", colorize=True)

    # Print process info, config and args.
    logger.info(f"Rank of current process: {RANK}. World size: {WORLD_SIZE}")
    logger.info(str(_C))

    logger.info("Command line args:")
    for arg in vars(_A):
        logger.info("{:<20}: {}".format(arg, getattr(_A, arg)))


def common_parser(description: str = "") -> argparse.ArgumentParser:
    r"""
    Create an argument parser with some common arguments useful for any
    pretraining or downstream evaluation scripts.

    Parameters
    ----------
    description: str, optional (default = "")
        Description to be used with the argument parser.

    Returns
    -------
    argparse.ArgumentParser
        A parser object with added arguments.
    """
    parser = argparse.ArgumentParser(description=description)

    # fmt: off
    parser.add_argument(
        "--config", metavar="FILE", help="Path to a pretraining config file."
    )
    parser.add_argument(
        "--config-override", nargs="*", default=[],
        help="A list of key-value pairs to modify pretraining config params.",
    )
    parser.add_argument(
        "--serialization-dir", default="/tmp/virtex",
        help="Path to a directory to serialize checkpoints and save job logs."
    )

    group = parser.add_argument_group("Compute resource management arguments.")
    group.add_argument(
        "--cpu-workers", type=int, default=0,
        help="Number of CPU workers per GPU to use for data loading.",
    )
    group.add_argument(
        "--num-machines", type=int, default=1,
        help="Number of machines used in distributed training."
    )
    group.add_argument(
        "--num-gpus-per-machine", type=int, default=0,
        help="""Number of GPUs per machine with IDs as (0, 1, 2 ...). Set as
        zero for single-process CPU training.""",
    )
    group.add_argument(
        "--machine-rank", type=int, default=0,
        help="""Rank of the machine, integer in [0, num_machines). Default 0
        for training with a single machine.""",
    )
    # A bare "tcp://" has no host or port and cannot be used as a
    # torch.distributed init method. Default to localhost as the help text
    # promises; the port number is an arbitrary choice.
    group.add_argument(
        "--dist-url", default="tcp://127.0.0.1:23456",
        help="""URL of the master process in distributed training, it defaults
        to localhost for single-machine training.""",
    )
    # fmt: on
    return parser