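r"""
Common utilities shared by the pretraining and downstream evaluation scripts:
an infinite batch generator over a dataloader, one-time job setup (random
seeds, serialization directory, loggers) and a shared argument parser.
"""
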
import argparse
import os
import random
import sys

from loguru import logger
import numpy as np
import torch

from virtex.config import Config
import virtex.utils.distributed as dist


def cycle(
    dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    start_iteration: int = 0,
):
    r"""
    A generator to yield batches of data from the dataloader infinitely.

    Internally, it sets the ``epoch`` of the dataloader's sampler so that the
    examples are reshuffled on every pass. Optionally provide the starting
    iteration so that, when resuming, the shuffling seed differs from earlier
    epochs and training continues naturally.
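
    Examples
    --------
    A minimal usage sketch; ``dataloader`` and ``max_iterations`` are
    illustrative placeholders, not names defined in this module::

        >>> batch_iterator = cycle(dataloader, torch.device("cpu"))
        >>> for iteration in range(max_iterations):
        ...     batch = next(batch_iterator)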
    """
    iteration = start_iteration

    while True:
        if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
            # Set the `epoch` of DistributedSampler as the current iteration.
            # This is just a seed for deterministic shuffling after every
            # epoch, so it need not necessarily be the actual "epoch".
            logger.info(f"Beginning new epoch, setting shuffle seed {iteration}")
            dataloader.sampler.set_epoch(iteration)

        for batch in dataloader:
            for key in batch:
                batch[key] = batch[key].to(device)
            yield batch
            iteration += 1


def common_setup(_C: Config, _A: argparse.Namespace, job_type: str = "pretrain"):
    r"""
    Set up common steps at the start of every pretraining or downstream
    evaluation job, collected here to avoid code duplication. Basic steps:

    1. Fix random seeds and other PyTorch flags.
    2. Set up a serialization directory and loggers.
    3. Log important details such as the config and process info (useful
        during distributed training).
    4. Save a copy of config to serialization directory.

    .. note::

        It is assumed that multiple processes for distributed training have
        already been launched from outside. Functions from the
        :mod:`virtex.utils.distributed` module are used to get process info.
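
    Examples
    --------
    A minimal sketch of a typical call from a training script, assuming
    :class:`~virtex.config.Config` is constructed from a config file path and
    a list of override params::

        >>> _A = common_parser("Pretrain a model.").parse_args()
        >>> _C = Config(_A.config, _A.config_override)
        >>> common_setup(_C, _A, job_type="pretrain")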

    Parameters
    ----------
    _C: virtex.config.Config
        Config object with all the parameters.
    _A: argparse.Namespace
        Command line arguments.
    job_type: str, optional (default = "pretrain")
        Type of job for which setup is to be done; one of ``{"pretrain",
        "downstream"}``.
    """

    # Get process rank and world size (assuming distributed is initialized).
    RANK = dist.get_rank()
    WORLD_SIZE = dist.get_world_size()

    # For reproducibility - refer https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(_C.RANDOM_SEED)
    torch.backends.cudnn.deterministic = _C.CUDNN_DETERMINISTIC
    torch.backends.cudnn.benchmark = _C.CUDNN_BENCHMARK
    random.seed(_C.RANDOM_SEED)
    np.random.seed(_C.RANDOM_SEED)

    # Create serialization directory and save config in it.
    os.makedirs(_A.serialization_dir, exist_ok=True)
    _C.dump(os.path.join(_A.serialization_dir, f"{job_type}_config.yaml"))

    # Remove default logger, create a logger for each process which writes to a
    # separate log-file. This makes changes in global scope.
    logger.remove(0)
    if WORLD_SIZE > 1:
        logger.add(
            os.path.join(_A.serialization_dir, f"log-rank{RANK}.txt"),
            format="{time} {level} {message}",
        )

    # Add a logger for stdout only for the master process.
    if dist.is_master_process():
        logger.add(
            sys.stdout, format="<g>{time}</g>: <lvl>{message}</lvl>", colorize=True
        )

    # Print process info, config and args.
    logger.info(f"Rank of current process: {RANK}. World size: {WORLD_SIZE}")
    logger.info(str(_C))

    logger.info("Command line args:")
    for arg in vars(_A):
        logger.info("{:<20}: {}".format(arg, getattr(_A, arg)))


def common_parser(description: str = "") -> argparse.ArgumentParser:
    r"""
    Create an argument parser with some common arguments useful for any
    pretraining or downstream evaluation script.

    Parameters
    ----------
    description: str, optional (default = "")
        Description to be used with the argument parser.

    Returns
    -------
    argparse.ArgumentParser
        A parser object with added arguments.
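
    Examples
    --------
    A minimal sketch; the extra ``--checkpoint-path`` argument here is purely
    illustrative and not added by this function::

        >>> parser = common_parser(description="Pretrain a VirTex model.")
        >>> parser.add_argument("--checkpoint-path", help="Checkpoint to resume from.")
        >>> _A = parser.parse_args()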
    """
    parser = argparse.ArgumentParser(description=description)

    # fmt: off
    parser.add_argument(
        "--config", metavar="FILE", help="Path to a pretraining config file."
    )
    parser.add_argument(
        "--config-override", nargs="*", default=[],
        help="A list of key-value pairs to modify pretraining config params.",
    )
    parser.add_argument(
        "--serialization-dir", default="/tmp/virtex",
        help="Path to a directory to serialize checkpoints and save job logs."
    )

    group = parser.add_argument_group("Compute resource management arguments.")
    group.add_argument(
        "--cpu-workers", type=int, default=0,
        help="Number of CPU workers per GPU to use for data loading.",
    )
    group.add_argument(
        "--num-machines", type=int, default=1,
        help="Number of machines used in distributed training."
    )
    group.add_argument(
        "--num-gpus-per-machine", type=int, default=0,
        help="""Number of GPUs per machine with IDs as (0, 1, 2 ...). Set as
        zero for single-process CPU training.""",
    )
    group.add_argument(
        "--machine-rank", type=int, default=0,
        help="""Rank of the machine, integer in [0, num_machines). Default 0
        for training with a single machine.""",
    )
    group.add_argument(
        "--dist-url", default="tcp://127.0.0.1:23456",
        help="""URL of the master process in distributed training, it defaults
        to localhost for single-machine training.""",
    )
    # fmt: on

    return parser