# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.


import logging
import os
from datetime import timedelta
from typing import List

import torch
import torch.distributed as dist
import torch.multiprocessing

logger = logging.getLogger(__name__)


def is_dist_initialized() -> bool:
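    """Returns True if torch.distributed is available and the process group has been initialized."""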
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank() -> int:
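    """Returns the global rank of the current process, or 0 when not running distributed."""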
    if not is_dist_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
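    """Returns the node-local rank from the ``LOCAL_RANK`` environment variable, or 0 when not running distributed."""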
    if not is_dist_initialized():
        return 0
    return int(os.environ["LOCAL_RANK"])


def get_world_size() -> int:
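    """Returns the total number of processes in the group, or 1 when not running distributed."""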
    if not is_dist_initialized():
        return 1
    return dist.get_world_size()


def is_main_process() -> bool:
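    """Returns True only for the global rank-0 process."""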
    return get_rank() == 0


def init_distributed(loggers: List[logging.Logger]) -> None:
    """Initializes the distributed backend"""
    torch.multiprocessing.set_start_method("spawn")
    if "RANK" not in os.environ:
        logger.error(
            "Cannot initialize the distributed context: RANK, WORLD_SIZE and "
            "LOCAL_RANK are not set in the environment (e.g. by torchrun)."
        )
        return
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    logger.info(
        f"Rank={rank} local rank={local_rank}, world_size={world_size}, is_master={rank == 0}"
    )
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=180),
    )
    logger.info(f"Setting cuda:{local_rank} as main device")
    if not is_main_process():
        for to_mute in loggers:
            to_mute.setLevel(logging.ERROR)
    torch.cuda.set_device(local_rank)
    dist.barrier()
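

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # it assumes the script is launched with a torchrun-style launcher so that
    # RANK, WORLD_SIZE and LOCAL_RANK are set, and that each process can see a
    # CUDA device. Passing `logger` to init_distributed also demonstrates that
    # non-main ranks are muted down to ERROR level.
    logging.basicConfig(level=logging.INFO)
    init_distributed([logger])
    logger.info(
        f"Initialized rank {get_rank()}/{get_world_size()} "
        f"on device cuda:{get_local_rank()}"
    )
    if is_dist_initialized():
        dist.destroy_process_group()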