# Common setup utilities shared by pretraining and downstream evaluation scripts.
import argparse | |
import os | |
import random | |
import sys | |
from loguru import logger | |
import numpy as np | |
import torch | |
from virtex.config import Config | |
import virtex.utils.distributed as dist | |
def cycle(dataloader, device, start_iteration: int = 0):
    r"""
    Yield batches from ``dataloader`` endlessly, moving every tensor in each
    batch to ``device``.

    If the underlying sampler is a :class:`torch.utils.data.DistributedSampler`,
    its ``epoch`` is reset before every pass over the data so each pass
    shuffles with a fresh seed. Supplying ``start_iteration`` lets a resumed
    job continue the shuffling sequence naturally instead of restarting it.
    """
    seed = start_iteration
    while True:
        sampler = dataloader.sampler
        if isinstance(sampler, torch.utils.data.DistributedSampler):
            # The "epoch" passed to the sampler is only a shuffling seed; any
            # monotonically increasing integer works, so we use our counter.
            logger.info(f"Beginning new epoch, setting shuffle seed {seed}")
            sampler.set_epoch(seed)

        for batch in dataloader:
            yield {key: tensor.to(device) for key, tensor in batch.items()}
        seed += 1
def common_setup(_C: Config, _A: argparse.Namespace, job_type: str = "pretrain"):
    r"""
    Perform the boilerplate shared by every pretraining / downstream job:

    1. Seed all random number generators and set PyTorch cuDNN flags.
    2. Create the serialization directory and dump the config there.
    3. Reconfigure loguru: one log file per process; stdout only on master.
    4. Log the config, the command line args, and distributed process info.

    .. note::
        Multiple processes for distributed training are assumed to have been
        launched from outside already; helpers from
        :mod:`virtex.utils.distributed` are used to query process info.

    Parameters
    ----------
    _C: virtex.config.Config
        Config object with all the parameters.
    _A: argparse.Namespace
        Command line arguments.
    job_type: str, optional (default = "pretrain")
        Type of job for which setup is to be done; one of ``{"pretrain",
        "downstream"}``.
    """
    # Process info (distributed init is assumed to have happened already).
    RANK = dist.get_rank()
    WORLD_SIZE = dist.get_world_size()

    # Seed everything for reproducibility.
    # Refer: https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(_C.RANDOM_SEED)
    torch.backends.cudnn.deterministic = _C.CUDNN_DETERMINISTIC
    torch.backends.cudnn.benchmark = _C.CUDNN_BENCHMARK
    random.seed(_C.RANDOM_SEED)
    np.random.seed(_C.RANDOM_SEED)

    # Save a copy of the config inside the (possibly new) serialization dir.
    os.makedirs(_A.serialization_dir, exist_ok=True)
    _C.dump(os.path.join(_A.serialization_dir, f"{job_type}_config.yaml"))

    # Drop loguru's default handler (id 0); each process then writes to its
    # own per-rank log file. This mutates loguru's global logger.
    logger.remove(0)
    if WORLD_SIZE > 1:
        logger.add(
            os.path.join(_A.serialization_dir, f"log-rank{RANK}.txt"),
            format="{time} {level} {message}",
        )

    # Only the master process additionally logs to stdout.
    if dist.is_master_process():
        logger.add(
            sys.stdout, format="<g>{time}</g>: <lvl>{message}</lvl>", colorize=True
        )

    # Record process info, full config, and the command line arguments.
    logger.info(f"Rank of current process: {RANK}. World size: {WORLD_SIZE}")
    logger.info(str(_C))
    logger.info("Command line args:")
    for arg in vars(_A):
        logger.info(f"{arg:<20}: {getattr(_A, arg)}")
def common_parser(description: str = "") -> argparse.ArgumentParser:
    r"""
    Create an argument parser with common arguments useful for any pretraining
    or downstream evaluation script.

    Parameters
    ----------
    description: str, optional (default = "")
        Description to be used with the argument parser.

    Returns
    -------
    argparse.ArgumentParser
        A parser object with added arguments.
    """
    parser = argparse.ArgumentParser(description=description)

    # fmt: off
    parser.add_argument(
        "--config", metavar="FILE", help="Path to a pretraining config file."
    )
    parser.add_argument(
        "--config-override", nargs="*", default=[],
        help="A list of key-value pairs to modify pretraining config params.",
    )
    parser.add_argument(
        "--serialization-dir", default="/tmp/virtex",
        help="Path to a directory to serialize checkpoints and save job logs."
    )

    group = parser.add_argument_group("Compute resource management arguments.")
    group.add_argument(
        "--cpu-workers", type=int, default=0,
        help="Number of CPU workers per GPU to use for data loading.",
    )
    group.add_argument(
        "--num-machines", type=int, default=1,
        help="Number of machines used in distributed training."
    )
    group.add_argument(
        "--num-gpus-per-machine", type=int, default=0,
        help="""Number of GPUs per machine with IDs as (0, 1, 2 ...). Set as
        zero for single-process CPU training.""",
    )
    group.add_argument(
        "--machine-rank", type=int, default=0,
        help="""Rank of the machine, integer in [0, num_machines). Default 0
        for training with a single machine.""",
    )
    group.add_argument(
        # Plain string literal: this was an f-string with no placeholders
        # (ruff F541); the value is identical.
        "--dist-url", default="tcp://127.0.0.1:23456",
        help="""URL of the master process in distributed training, it defaults
        to localhost for single-machine training.""",
    )
    # fmt: on
    return parser