import dataclasses
import os
import socket
from typing import Optional

import torch
import torch.distributed


@dataclasses.dataclass
class DistributedOption:
    # Enable distributed training
    distributed: bool = False
    # torch.distributed.Backend: "nccl", "mpi", "gloo", or "tcp"
    dist_backend: str = "nccl"
    # if init_method="env://", the environment variables
    # "MASTER_PORT", "MASTER_ADDR", "WORLD_SIZE", and "RANK" are referenced.
    dist_init_method: str = "env://"
    dist_world_size: Optional[int] = None
    dist_rank: Optional[int] = None
    local_rank: Optional[int] = None
    ngpu: int = 0
    dist_master_addr: Optional[str] = None
    dist_master_port: Optional[int] = None
    dist_launcher: Optional[str] = None
    multiprocessing_distributed: bool = True

    def init_options(self):
        if self.distributed:
            if self.dist_init_method == "env://":
                if get_master_addr(self.dist_master_addr, self.dist_launcher) is None:
                    raise RuntimeError(
                        "--dist_master_addr or MASTER_ADDR must be set "
                        "if --dist_init_method == 'env://'"
                    )
                if get_master_port(self.dist_master_port) is None:
                    raise RuntimeError(
                        "--dist_master_port or MASTER_PORT must be set "
                        "if --dist_init_port == 'env://'"
                    )

            # About priority order:
            # If --dist_* is specified:
            #    use the value of --dist_* and overwrite the environment
            #    variable just in case.
            # elif the environment variable is set:
            #    use the value of the environment variable and set it to self.
            self.dist_rank = get_rank(self.dist_rank, self.dist_launcher)
            self.dist_world_size = get_world_size(
                self.dist_world_size, self.dist_launcher
            )
            self.local_rank = get_local_rank(self.local_rank, self.dist_launcher)

            if self.local_rank is not None:
                if self.ngpu > 1:
                    raise RuntimeError(f"Assuming 1GPU in this case: ngpu={self.ngpu}")
                if "CUDA_VISIBLE_DEVICES" in os.environ:
                    cvd = os.environ["CUDA_VISIBLE_DEVICES"]
                    if self.local_rank >= len(cvd.split(",")):
                        raise RuntimeError(
                            f"LOCAL_RANK={self.local_rank} is bigger "
                            f"than the number of visible devices: {cvd}"
                        )

            if (
                self.dist_rank is not None
                and self.dist_world_size is not None
                and self.dist_rank >= self.dist_world_size
            ):
                raise RuntimeError(
                    f"RANK >= WORLD_SIZE: {self.dist_rank} >= {self.dist_world_size}"
                )

            if self.dist_init_method == "env://":
                self.dist_master_addr = get_master_addr(
                    self.dist_master_addr, self.dist_launcher
                )
                self.dist_master_port = get_master_port(self.dist_master_port)
                if (
                    self.dist_master_addr is not None
                    and self.dist_master_port is not None
                ):
                    self.dist_init_method = (
                        f"tcp://{self.dist_master_addr}:{self.dist_master_port}"
                    )

    def init_torch_distributed(self):
        if self.distributed:
            # See:
            # https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html
            os.environ.setdefault("NCCL_DEBUG", "INFO")

            # See:
            # https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
            os.environ.setdefault("NCCL_BLOCKING_WAIT", "1")

            torch.distributed.init_process_group(
                backend=self.dist_backend,
                init_method=self.dist_init_method,
                world_size=self.dist_world_size,
                rank=self.dist_rank,
            )

            # About distributed model:
            # if self.local_rank is not None and ngpu == 1
            #    => Distributed with n-Process and n-GPU
            # if self.local_rank is None and ngpu >= 1
            #    => Distributed with 1-Process and n-GPU
            if self.local_rank is not None and self.ngpu > 0:
                torch.cuda.set_device(self.local_rank)
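
# A minimal usage sketch (illustrative only; the actual training entry point is
# defined elsewhere, and MASTER_ADDR/MASTER_PORT etc. are assumed to be already
# set in the environment): build the option, resolve the settings, then
# initialize torch.distributed before constructing the model.
#
#     option = DistributedOption(distributed=True, dist_backend="nccl", ngpu=1)
#     option.init_options()
#     option.init_torch_distributed()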


def resolve_distributed_mode(args):
    # Note that args.distributed is set only by this function
    # and ArgumentParser doesn't have such an option.

    if args.multiprocessing_distributed:
        num_nodes = get_num_nodes(args.dist_world_size, args.dist_launcher)
        # a. multi-node
        if num_nodes > 1:
            args.distributed = True
        # b. single-node and multi-gpu with multiprocessing_distributed mode
        elif args.ngpu > 1:
            args.distributed = True
        # c. single-node and single-gpu
        else:
            args.distributed = False

        if args.ngpu <= 1:
            # Disable multiprocessing_distributed mode for 1 process per node or CPU-only mode
            args.multiprocessing_distributed = False
        if args.ngpu == 1:
            # If the number of GPUs equals 1 in multiprocessing_distributed mode,
            # LOCAL_RANK is always 0
            args.local_rank = 0

        if num_nodes > 1 and get_node_rank(args.dist_rank, args.dist_launcher) is None:
            raise RuntimeError(
                "--dist_rank or RANK must be set "
                "if --multiprocessing_distributed == true"
            )

        # Note that RANK, LOCAL_RANK, and WORLD_SIZE are automatically set,
        # so we don't need to check them here
    else:
        # d. multiprocess and multi-gpu with external launcher
        #    e.g. torch.distributed.launch
        if get_world_size(args.dist_world_size, args.dist_launcher) > 1:
            args.distributed = True
        # e. single-process
        else:
            args.distributed = False

        if args.distributed and args.ngpu > 0:
            if get_local_rank(args.local_rank, args.dist_launcher) is None:
                raise RuntimeError(
                    "--local_rank or LOCAL_RANK must be set "
                    "if --multiprocessing_distributed == false"
                )
        if args.distributed:
            if get_node_rank(args.dist_rank, args.dist_launcher) is None:
                raise RuntimeError(
                    "--dist_rank or RANK must be set "
                    "if --multiprocessing_distributed == false"
                )
    if args.distributed and args.dist_launcher == "slurm" and not is_in_slurm_step():
        raise RuntimeError("Launch by 'srun' command if --dist_launcher='slurm'")


def is_in_slurm_job() -> bool:
    return "SLURM_PROCID" in os.environ and "SLURM_NTASKS" in os.environ


def is_in_slurm_step() -> bool:
    return (
        is_in_slurm_job()
        and "SLURM_STEP_NUM_NODES" in os.environ
        and "SLURM_STEP_NODELIST" in os.environ
    )


def _int_or_none(x: Optional[str]) -> Optional[int]:
    if x is None:
        return x
    return int(x)


def free_port():
    """Find free port using bind().

    There are some interval between finding this port and using it
    and the other process might catch the port by that time.
    Thus it is not guaranteed that the port is really empty.

    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
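
# Example (illustrative): pick a port for MASTER_PORT before spawning workers.
# The port is not reserved after this call, so another process could still take
# it before it is actually used.
#
#     os.environ.setdefault("MASTER_PORT", str(free_port()))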


def get_rank(prior=None, launcher: Optional[str] = None) -> Optional[int]:
    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")
            prior = os.environ["SLURM_PROCID"]
        elif launcher == "mpi":
            raise RuntimeError(
                "launcher=mpi is used for 'multiprocessing-distributed' mode"
            )
        elif launcher is not None:
            raise RuntimeError(f"launcher='{launcher}' is not supported")

    if prior is not None:
        return int(prior)
    else:
        # prior is None and RANK is None -> RANK = None
        return _int_or_none(os.environ.get("RANK"))


def get_world_size(prior=None, launcher: Optional[str] = None) -> int:
    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")
            prior = int(os.environ["SLURM_NTASKS"])
        elif launcher == "mpi":
            raise RuntimeError(
                "launcher=mpi is used for 'multiprocessing-distributed' mode"
            )
        elif launcher is not None:
            raise RuntimeError(f"launcher='{launcher}' is not supported")

    if prior is not None:
        return int(prior)
    else:
        # prior is None and WORLD_SIZE is None -> WORLD_SIZE = 1
        return int(os.environ.get("WORLD_SIZE", "1"))


def get_local_rank(prior=None, launcher: Optional[str] = None) -> Optional[int]:
    # LOCAL_RANK is the same as the GPU device id

    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")

            prior = int(os.environ["SLURM_LOCALID"])
        elif launcher == "mpi":
            raise RuntimeError(
                "launcher=mpi is used for 'multiprocessing-distributed' mode"
            )
        elif launcher is not None:
            raise RuntimeError(f"launcher='{launcher}' is not supported")

    if prior is not None:
        return int(prior)

    elif "LOCAL_RANK" in os.environ:
        return int(os.environ["LOCAL_RANK"])

    elif "CUDA_VISIBLE_DEVICES" in os.environ:
        # There are two possibilities:
        # - "CUDA_VISIBLE_DEVICES" is set to multiple GPU ids, e.g. "0,1,2"
        #   => This intends to specify exactly which devices are to be used,
        #      and local_rank information is possibly insufficient.
        # - "CUDA_VISIBLE_DEVICES" is set to a single id, e.g. "1"
        #   => This could be used as LOCAL_RANK
        cvd = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
        if len(cvd) == 1 and "LOCAL_RANK" not in os.environ:
            # If CUDA_VISIBLE_DEVICES is set and LOCAL_RANK is not set,
            # then use it as LOCAL_RANK.

            # Unset CUDA_VISIBLE_DEVICES
            # because the other devices must be visible for communication
            return int(os.environ.pop("CUDA_VISIBLE_DEVICES"))
        else:
            return None
    else:
        return None
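
# Illustrative behaviour of get_local_rank (assuming LOCAL_RANK is not set): with
# CUDA_VISIBLE_DEVICES="3" it returns 3 and removes CUDA_VISIBLE_DEVICES from the
# environment; with CUDA_VISIBLE_DEVICES="0,1,2" it returns None.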


def get_master_addr(prior=None, launcher: Optional[str] = None) -> Optional[str]:
    if prior is None:
        if launcher == "slurm":
            if not is_in_slurm_step():
                raise RuntimeError("This process seems not to be launched by 'srun'")

            # e.g. nodelist = foo[1-10],bar[3-8] or foo4,bar[2-10]
            # The parsing below yields "foo1" and "foo4" respectively.
            nodelist = os.environ["SLURM_STEP_NODELIST"]
            prior = nodelist.split(",")[0].split("-")[0].replace("[", "")

    if prior is not None:
        return str(prior)
    else:
        return os.environ.get("MASTER_ADDR")


def get_master_port(prior=None) -> Optional[int]:
    if prior is not None:
        return prior
    else:
        return _int_or_none(os.environ.get("MASTER_PORT"))


def get_node_rank(prior=None, launcher: Optional[str] = None) -> Optional[int]:
    """Get Node Rank.

    Use for "multiprocessing distributed" mode.
    The initial RANK equals to the Node id in this case and
    the real Rank is set as (nGPU * NodeID) + LOCAL_RANK in torch.distributed.

    """
    if prior is not None:
        return prior
    elif launcher == "slurm":
        if not is_in_slurm_step():
            raise RuntimeError("This process seems not to be launched by 'srun'")

        # Assume ntasks_per_node == 1
        if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]:
            raise RuntimeError(
                "Run with --ntasks_per_node=1 if mutliprocessing_distributed=true"
            )
        return int(os.environ["SLURM_NODEID"])
    elif launcher == "mpi":
        # Use mpi4py only for initialization, not for communication
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        # Assume ntasks_per_node == 1 (We can't check whether it is or not)
        return comm.Get_rank()
    elif launcher is not None:
        raise RuntimeError(f"launcher='{launcher}' is not supported")
    else:
        return _int_or_none(os.environ.get("RANK"))


def get_num_nodes(prior=None, launcher: Optional[str] = None) -> Optional[int]:
    """Get the number of nodes.

    Use for "multiprocessing distributed" mode.
    RANK equals to the Node id in this case and
    the real Rank is set as (nGPU * NodeID) + LOCAL_RANK in torch.distributed.

    """
    if prior is not None:
        return prior
    elif launcher == "slurm":
        if not is_in_slurm_step():
            raise RuntimeError("This process seems not to be launched by 'srun'")

        # Assume ntasks_per_node == 1
        if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]:
            raise RuntimeError(
                "Run with --ntasks_per_node=1 if mutliprocessing_distributed=true"
            )
        return int(os.environ["SLURM_STEP_NUM_NODES"])
    elif launcher == "mpi":
        # Use mpi4py only for initialization, not for communication
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        # Assume ntasks_per_node == 1 (We can't check whether it is or not)
        return comm.Get_size()
    elif launcher is not None:
        raise RuntimeError(f"launcher='{launcher}' is not supported")
    else:
        # prior is None and WORLD_SIZE is None -> NUM_NODES = 1
        return int(os.environ.get("WORLD_SIZE", 1))