# ------------------------------------------
# Diffsound
# code based on https://github.com/cientgu/VQ-Diffusion
# ------------------------------------------
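# Launch helpers for multi-GPU / multi-machine training built on
# torch.multiprocessing and torch.distributed (NCCL backend).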
import socket

import torch
from torch import distributed as dist
from torch import multiprocessing as mp

import distributed.distributed as dist_fn
# import distributed as dist_fn


def find_free_port():
    # Bind to port 0 so the OS assigns an unused port, then close the socket
    # and return the port number for use in the TCP rendezvous address.

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()

    return port


def launch(fn,
           n_gpu_per_machine,
           n_machine=1,
           machine_rank=0,
           dist_url=None,
           args=()):
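    # Spawn one worker process per local GPU and run fn(local_rank, *args) in
    # each; with a single process, fn is called directly as rank 0.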
    world_size = n_machine * n_gpu_per_machine

    if world_size > 1:
        # if "OMP_NUM_THREADS" not in os.environ:
        #     os.environ["OMP_NUM_THREADS"] = "1"
        if dist_url == "auto":
            if n_machine != 1:
                raise ValueError(
                    'dist_url="auto" not supported in multi-machine jobs')
            port = find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        print('dist_url ', dist_url)
        print('n_machine ', n_machine)
        print('args ', args)
        print('world_size ', world_size)
        print('machine_rank ', machine_rank)
        if n_machine > 1 and dist_url.startswith("file://"):
            raise ValueError(
                "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
            )

        mp.spawn(
            distributed_worker,
            nprocs=n_gpu_per_machine,
            args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url,
                  args),
            daemon=False, )
        # nprocs is the per-machine process count (n_gpu_per_machine), not
        # world_size; each machine invokes launch() with its own machine_rank.
    else:
        local_rank = 0
        fn(local_rank, *args)


def distributed_worker(local_rank, fn, world_size, n_gpu_per_machine,
                       machine_rank, dist_url, args):
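    # Entry point for each spawned process: initialize the NCCL process group,
    # synchronize, pin the process to its GPU, set up per-machine process
    # groups, then run the user function.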
    if not torch.cuda.is_available():
        raise OSError("CUDA is not available. Please check your environments")

    global_rank = machine_rank * n_gpu_per_machine + local_rank
    print('local_rank ', local_rank)
    print('global_rank ', global_rank)
    try:
        dist.init_process_group(
            backend="nccl",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank, )

    except Exception as e:
        raise OSError("failed to initialize NCCL process group") from e

    # Block until every process in the group has reached this point.
    dist_fn.synchronize()

    if n_gpu_per_machine > torch.cuda.device_count():
        raise ValueError(
            f"n_gpu_per_machine is larger than the number of available devices ({torch.cuda.device_count()})"
        )

    # Bind this worker process to its local GPU.
    torch.cuda.set_device(local_rank)

    if dist_fn.LOCAL_PROCESS_GROUP is not None:
        raise ValueError("distributed.LOCAL_PROCESS_GROUP is already initialized")

    # Create one process group per machine and store this machine's group in
    # dist_fn.LOCAL_PROCESS_GROUP so intra-machine collectives can use it.

    n_machine = world_size // n_gpu_per_machine
    for i in range(n_machine):
        ranks_on_i = list(
            range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
        pg = dist.new_group(ranks_on_i)

        if i == machine_rank:
            dist_fn.LOCAL_PROCESS_GROUP = pg

    fn(local_rank, *args)
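

def _demo_worker(local_rank, message):
    # Hypothetical stand-in for a training entry point (not part of the
    # original Diffsound code); defined at module level so mp.spawn's child
    # processes can import it.
    print(f"rank {local_rank}: {message}")


if __name__ == "__main__":
    # Minimal usage sketch, an assumption rather than the original entry point:
    # launch() calls fn(local_rank, *args) once per spawned process.
    launch(_demo_worker,
           n_gpu_per_machine=max(torch.cuda.device_count(), 1),
           dist_url="auto",
           args=("hello from the Diffsound launcher",))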