kdexd's picture
Black + isort, remove unused virtx files.
8d0e872
import argparse
import os
import random
import sys
from loguru import logger
import numpy as np
import torch
from virtex.config import Config
import virtex.utils.distributed as dist
def cycle(dataloader, device, start_iteration: int = 0):
    r"""
    A generator to yield batches of data from dataloader infinitely.

    Internally, it sets the ``epoch`` for dataloader sampler to shuffle the
    examples. One may optionally provide the starting iteration to make sure
    the shuffling seed is different and continues naturally.

    Parameters
    ----------
    dataloader:
        Iterable of batches exposing a ``sampler`` attribute. Every batch must
        be a mutable mapping whose values provide a ``.to(device)`` method.
    device:
        Target passed to ``.to()`` for every value of every batch.
    start_iteration: int, optional (default = 0)
        Iteration count to start from; it is incremented once per yielded
        batch and used as the shuffling seed at the start of every pass.

    Yields
    ------
    dict
        The batch produced by ``dataloader``, updated in-place so that every
        value has been moved with ``.to(device)``.

    Raises
    ------
    ValueError
        If a complete pass over ``dataloader`` produces no batch. Without this
        check the generator would loop forever without ever yielding.
    """
    iteration = start_iteration

    while True:
        if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
            # Set the `epoch` of DistributedSampler as current iteration. This
            # is a way of deterministic shuffling after every epoch, so it is
            # just a seed and need not necessarily be the "epoch".
            logger.info(f"Beginning new epoch, setting shuffle seed {iteration}")
            dataloader.sampler.set_epoch(iteration)

        # Remember where this pass started to detect a dataloader that
        # produces nothing.
        pass_start = iteration

        for batch in dataloader:
            # Re-assigning existing keys does not change the dict size, so it
            # is safe to do while iterating over the keys.
            for key in batch:
                batch[key] = batch[key].to(device)
            yield batch
            iteration += 1

        if iteration == pass_start:
            raise ValueError(
                "Dataloader did not yield any batch; cannot cycle over it."
            )
def common_setup(_C: Config, _A: argparse.Namespace, job_type: str = "pretrain"):
    r"""
    Setup common stuff at the start of every pretraining or downstream
    evaluation job, all listed here to avoid code duplication. Basic steps:

    1. Fix random seeds and other PyTorch flags.
    2. Set up a serialization directory and loggers.
    3. Log important stuff such as config, process info (useful during
       distributed training).
    4. Save a copy of config to serialization directory.

    .. note::

        It is assumed that multiple processes for distributed training have
        already been launched from outside. Functions from
        :mod:`virtex.utils.distributed` module are used to get process info.

    .. note::

        This function reconfigures the global ``loguru`` logger: handler ``0``
        is removed if present, a per-rank log file is added only when world
        size is greater than one, and a stdout handler is added only for the
        master process.

    Parameters
    ----------
    _C: virtex.config.Config
        Config object with all the parameters.
    _A: argparse.Namespace
        Command line arguments. Its ``serialization_dir`` attribute is used as
        the output directory.
    job_type: str, optional (default = "pretrain")
        Type of job for which setup is to be done; one of ``{"pretrain",
        "downstream"}``. It is used as a prefix of the saved config file name.
    """

    # Get process rank and world size (assuming distributed is initialized).
    RANK = dist.get_rank()
    WORLD_SIZE = dist.get_world_size()

    # For reproducibility - refer https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(_C.RANDOM_SEED)
    torch.backends.cudnn.deterministic = _C.CUDNN_DETERMINISTIC
    torch.backends.cudnn.benchmark = _C.CUDNN_BENCHMARK
    random.seed(_C.RANDOM_SEED)
    np.random.seed(_C.RANDOM_SEED)

    # Create serialization directory and save config in it.
    os.makedirs(_A.serialization_dir, exist_ok=True)
    _C.dump(os.path.join(_A.serialization_dir, f"{job_type}_config.yaml"))

    # Remove default logger, create a logger for each process which writes to a
    # separate log-file. This makes changes in global scope.
    try:
        logger.remove(0)
    except ValueError:
        # Handler 0 is not registered anymore (for example, this setup already
        # ran once in the current process); there is nothing left to remove.
        pass

    if WORLD_SIZE > 1:
        logger.add(
            os.path.join(_A.serialization_dir, f"log-rank{RANK}.txt"),
            format="{time} {level} {message}",
        )

    # Add a logger for stdout only for the master process.
    if dist.is_master_process():
        logger.add(
            sys.stdout, format="<g>{time}</g>: <lvl>{message}</lvl>", colorize=True
        )

    # Print process info, config and args.
    logger.info(f"Rank of current process: {RANK}. World size: {WORLD_SIZE}")
    logger.info(str(_C))

    logger.info("Command line args:")
    for arg in vars(_A):
        logger.info("{:<20}: {}".format(arg, getattr(_A, arg)))
def common_parser(description: str = "") -> argparse.ArgumentParser:
    r"""
    Build the argument parser shared by pretraining and downstream evaluation
    scripts.

    The parser holds three general options (``--config``,
    ``--config-override``, ``--serialization-dir``) followed by an argument
    group with compute resource options (``--cpu-workers``,
    ``--num-machines``, ``--num-gpus-per-machine``, ``--machine-rank``,
    ``--dist-url``).

    Parameters
    ----------
    description: str, optional (default = "")
        Description to be used with the argument parser.

    Returns
    -------
    argparse.ArgumentParser
        A parser object with added arguments.
    """
    # Each entry is ``(flag, keyword arguments for add_argument)``. The tables
    # are rebuilt on every call so that parsers never share default objects.
    general_options = [
        (
            "--config",
            dict(metavar="FILE", help="Path to a pretraining config file."),
        ),
        (
            "--config-override",
            dict(
                nargs="*",
                default=[],
                help="A list of key-value pairs to modify pretraining config params.",
            ),
        ),
        (
            "--serialization-dir",
            dict(
                default="/tmp/virtex",
                help="Path to a directory to serialize checkpoints and save job logs.",
            ),
        ),
    ]
    compute_options = [
        (
            "--cpu-workers",
            dict(
                type=int,
                default=0,
                help="Number of CPU workers per GPU to use for data loading.",
            ),
        ),
        (
            "--num-machines",
            dict(
                type=int,
                default=1,
                help="Number of machines used in distributed training.",
            ),
        ),
        (
            "--num-gpus-per-machine",
            dict(
                type=int,
                default=0,
                help=(
                    "Number of GPUs per machine with IDs as (0, 1, 2 ...). Set as\n"
                    "zero for single-process CPU training."
                ),
            ),
        ),
        (
            "--machine-rank",
            dict(
                type=int,
                default=0,
                help=(
                    "Rank of the machine, integer in [0, num_machines). Default 0\n"
                    "for training with a single machine."
                ),
            ),
        ),
        (
            "--dist-url",
            dict(
                default="tcp://127.0.0.1:23456",
                help=(
                    "URL of the master process in distributed training, it defaults\n"
                    "to localhost for single-machine training."
                ),
            ),
        ),
    ]

    parser = argparse.ArgumentParser(description=description)
    for flag, options in general_options:
        parser.add_argument(flag, **options)

    compute_group = parser.add_argument_group(
        "Compute resource management arguments."
    )
    for flag, options in compute_options:
        compute_group.add_argument(flag, **options)

    return parser