import argparse
import os
import random
import sys

from loguru import logger
import numpy as np
import torch

from virtex.config import Config
import virtex.utils.distributed as dist


def cycle(dataloader, device, start_iteration: int = 0):
    r"""
    A generator to yield batches of data from dataloader infinitely.

    Internally, it sets the ``epoch`` for the dataloader sampler to shuffle
    the examples. One may optionally provide the starting iteration to make
    sure the shuffling seed is different and continues naturally.
    """
    iteration = start_iteration

    while True:
        if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
            # Set the `epoch` of DistributedSampler as current iteration. This
            # is a way of deterministic shuffling after every epoch, so it is
            # just a seed and need not necessarily be the "epoch".
            logger.info(f"Beginning new epoch, setting shuffle seed {iteration}")
            dataloader.sampler.set_epoch(iteration)

        for batch in dataloader:
            for key in batch:
                batch[key] = batch[key].to(device)
            yield batch
            iteration += 1
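
# Illustrative usage of `cycle` (a sketch, not part of this module): assuming
# `train_dataloader` wraps a dataset with a DistributedSampler, `device` is a
# torch.device, and `max_iterations` comes from the config, a training script
# would typically consume batches like this:
#
#     batch_iterator = cycle(train_dataloader, device, start_iteration)
#     for iteration in range(start_iteration + 1, max_iterations + 1):
#         batch = next(batch_iterator)
#         # ... forward pass, loss computation, optimizer step ...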


def common_setup(_C: Config, _A: argparse.Namespace, job_type: str = "pretrain"):
    r"""
    Set up common stuff at the start of every pretraining or downstream
    evaluation job, all collected here to avoid code duplication. Basic steps:

    1. Fix random seeds and other PyTorch flags.
    2. Set up a serialization directory and loggers.
    3. Log important stuff such as config, process info (useful during
       distributed training).
    4. Save a copy of config to serialization directory.

    .. note::

        It is assumed that multiple processes for distributed training have
        already been launched from outside. Functions from
        :mod:`virtex.utils.distributed` module are used to get process info.

    Parameters
    ----------
    _C: virtex.config.Config
        Config object with all the parameters.
    _A: argparse.Namespace
        Command line arguments.
    job_type: str, optional (default = "pretrain")
        Type of job for which setup is to be done; one of ``{"pretrain",
        "downstream"}``.
    """
    # Get process rank and world size (assuming distributed is initialized).
    RANK = dist.get_rank()
    WORLD_SIZE = dist.get_world_size()

    # For reproducibility - refer https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(_C.RANDOM_SEED)
    torch.backends.cudnn.deterministic = _C.CUDNN_DETERMINISTIC
    torch.backends.cudnn.benchmark = _C.CUDNN_BENCHMARK
    random.seed(_C.RANDOM_SEED)
    np.random.seed(_C.RANDOM_SEED)

    # Create serialization directory and save config in it.
    os.makedirs(_A.serialization_dir, exist_ok=True)
    _C.dump(os.path.join(_A.serialization_dir, f"{job_type}_config.yaml"))

    # Remove default logger, create a logger for each process which writes to a
    # separate log-file. This makes changes in global scope.
    logger.remove(0)
    if dist.get_world_size() > 1:
        logger.add(
            os.path.join(_A.serialization_dir, f"log-rank{RANK}.txt"),
            format="{time} {level} {message}",
        )

    # Add a logger for stdout only for the master process.
    if dist.is_master_process():
        logger.add(
            sys.stdout, format="{time}: {message}", colorize=True
        )

    # Print process info, config and args.
    logger.info(f"Rank of current process: {RANK}. World size: {WORLD_SIZE}")
    logger.info(str(_C))
    logger.info("Command line args:")
    for arg in vars(_A):
        logger.info("{:<20}: {}".format(arg, getattr(_A, arg)))


def common_parser(description: str = "") -> argparse.ArgumentParser:
    r"""
    Create an argument parser with some common arguments useful for any
    pretraining or downstream evaluation script.

    Parameters
    ----------
    description: str, optional (default = "")
        Description to be used with the argument parser.

    Returns
    -------
    argparse.ArgumentParser
        A parser object with added arguments.
    """
    parser = argparse.ArgumentParser(description=description)

    # fmt: off
    parser.add_argument(
        "--config", metavar="FILE", help="Path to a pretraining config file."
    )
    parser.add_argument(
        "--config-override", nargs="*", default=[],
        help="A list of key-value pairs to modify pretraining config params.",
    )
    parser.add_argument(
        "--serialization-dir", default="/tmp/virtex",
        help="Path to a directory to serialize checkpoints and save job logs."
    )

    group = parser.add_argument_group("Compute resource management arguments.")
    group.add_argument(
        "--cpu-workers", type=int, default=0,
        help="Number of CPU workers per GPU to use for data loading.",
    )
    group.add_argument(
        "--num-machines", type=int, default=1,
        help="Number of machines used in distributed training."
    )
    group.add_argument(
        "--num-gpus-per-machine", type=int, default=0,
        help="""Number of GPUs per machine with IDs as (0, 1, 2 ...). Set as
        zero for single-process CPU training.""",
    )
    group.add_argument(
        "--machine-rank", type=int, default=0,
        help="""Rank of the machine, integer in [0, num_machines). Default 0
        for training with a single machine.""",
    )
    group.add_argument(
        "--dist-url", default="tcp://127.0.0.1:23456",
        help="""URL of the master process in distributed training; it defaults
        to localhost for single-machine training.""",
    )
    # fmt: on
    return parser
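
# A script would typically extend the returned parser with its own arguments
# before parsing (an illustrative sketch; ``--checkpoint-path`` below is a
# hypothetical flag, not defined by this module):
#
#     parser = common_parser(description="Evaluate VirTex on a downstream task.")
#     parser.add_argument("--checkpoint-path", help="Path to a checkpoint to evaluate.")
#     _A = parser.parse_args()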