fffiloni committed on
Commit
8a72396
1 Parent(s): 126e8dc

Create distributed.py

Files changed (1):
  1. utils/distributed.py (+180, -0)
utils/distributed.py ADDED
@@ -0,0 +1,180 @@
+import os
+import time
+import torch
+import pickle
+import torch.distributed as dist
+
+
+def init_distributed(opt):
+    opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available()
+    if 'OMPI_COMM_WORLD_SIZE' not in os.environ:
+        # application was started without MPI
+        # default to single node with single process
+        opt['env_info'] = 'no MPI'
+        opt['world_size'] = 1
+        opt['local_size'] = 1
+        opt['rank'] = 0
+        opt['local_rank'] = 0
+        opt['master_address'] = '127.0.0.1'
+        opt['master_port'] = '8673'
+    else:
+        # application was started with MPI
+        # get MPI parameters
+        opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+        opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'])
+        opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK'])
+        opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+
+    # set up device
+    if not opt['CUDA']:
+        assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend'
+        opt['device'] = torch.device("cpu")
+    else:
+        torch.cuda.set_device(opt['local_rank'])
+        opt['device'] = torch.device("cuda", opt['local_rank'])
+    return opt
+
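A minimal usage sketch (the opt values, the toy model, and the import path utils.distributed are illustrative, not part of the commit). Note that init_distributed only reads the OpenMPI environment variables and selects a device; it does not call dist.init_process_group itself:

import torch.nn as nn
from utils.distributed import init_distributed

# Hypothetical call site: launched directly (no MPI) or via mpirun,
# which sets the OMPI_COMM_WORLD_* variables read above.
opt = init_distributed({'CUDA': True})
model = nn.Linear(8, 2).to(opt['device'])   # toy model, illustration only
print(f"rank {opt['rank']}/{opt['world_size']} using {opt['device']}")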
+def is_main_process():
+    rank = 0
+    if 'OMPI_COMM_WORLD_SIZE' in os.environ:
+        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+
+    return rank == 0
+
+def get_world_size():
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size()
+
+def get_rank():
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank()
+
+
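These helpers fall back to single-process defaults when torch.distributed is unavailable or uninitialized, so call sites stay launcher-agnostic. A small sketch of the usual rank-guarded logging pattern (the function name and message are illustrative):

from utils.distributed import is_main_process, get_rank, get_world_size

def log_once(msg):
    # Only the main process prints; other ranks stay silent.
    if is_main_process():
        print(f"[rank {get_rank()}/{get_world_size()}] {msg}")

log_once("starting training")   # printed exactly once per job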
+def synchronize():
+    """
+    Helper function to synchronize (barrier) among all processes when
+    using distributed training
+    """
+    if not dist.is_available():
+        return
+    if not dist.is_initialized():
+        return
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
+    if world_size == 1:
+        return
+
+    def _send_and_wait(r):
+        if rank == r:
+            tensor = torch.tensor(0, device="cuda")
+        else:
+            tensor = torch.tensor(1, device="cuda")
+        dist.broadcast(tensor, r)
+        while tensor.item() == 1:
+            time.sleep(1)
+
+    _send_and_wait(0)
+    # now sync on the main process
+    _send_and_wait(1)
+
+
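synchronize() builds a barrier out of two blocking broadcasts of a flag tensor instead of dist.barrier(): every rank other than r holds a 1 until rank r's 0 arrives, first for rank 0 and then for rank 1. A sketch of a typical call site, where rank 0 writes a checkpoint and the other ranks wait before reading it (the file path and payload are hypothetical):

import torch
from utils.distributed import is_main_process, synchronize

if is_main_process():
    torch.save({'step': 1000}, 'ckpt.pth')   # illustrative checkpoint
synchronize()                                # everyone waits for rank 0
state = torch.load('ckpt.pth', map_location='cpu')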
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to("cuda")
+
+    # obtain Tensor size of each rank
+    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
+    size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
+    if local_size != max_size:
+        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
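Because all_gather pickles the payload and pads the byte tensors to a common length, it works for dicts, lists, or any other picklable object, at the cost of a CPU-GPU round trip. A sketch of collecting per-rank evaluation results (the prediction dict is illustrative):

from utils.distributed import all_gather, get_rank, is_main_process

local_preds = {get_rank(): [0.1, 0.9]}   # hypothetical per-rank results
gathered = all_gather(local_preds)       # one entry per rank, on every rank
if is_main_process():
    merged = {}
    for part in gathered:
        merged.update(part)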
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that process with rank
+    0 has the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.reduce(values, dst=0)
+        if dist.get_rank() == 0 and average:
+            # only main process gets accumulated, so only divide by
+            # world_size in this case
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
+
+
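A sketch of reducing a loss dict for logging; every rank must pass the same keys, and the values should be same-shape CUDA tensors since the stack/reduce runs over NCCL (the loss names and numbers are illustrative):

import torch
from utils.distributed import reduce_dict, is_main_process

loss_dict = {'loss_cls': torch.tensor(0.7, device='cuda'),
             'loss_box': torch.tensor(1.2, device='cuda')}
reduced = reduce_dict(loss_dict)   # rank 0 gets the averaged values; other ranks ignore the result
if is_main_process():
    print({k: v.item() for k, v in reduced.items()})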
+def broadcast_data(data):
+    if not torch.distributed.is_initialized():
+        return data
+    rank = dist.get_rank()
+    if rank == 0:
+        data_tensor = torch.tensor(data + [0], device="cuda")
+    else:
+        data_tensor = torch.tensor(data + [1], device="cuda")
+    torch.distributed.broadcast(data_tensor, 0)
+    while data_tensor.cpu().numpy()[-1] == 1:
+        time.sleep(1)
+
+    return data_tensor.cpu().numpy().tolist()[:-1]
+
+
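broadcast_data appends a 0/1 sentinel so that non-zero ranks can detect when rank 0's values have overwritten their own, then strips it before returning; every rank must therefore pass a list of numbers of the same length. A sketch (the seed values are illustrative):

from utils.distributed import broadcast_data

# Every rank calls this with a same-length list; rank 0's values win.
seeds = broadcast_data([1234, 5678])   # hypothetical per-run seeds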
+def reduce_sum(tensor):
+    if get_world_size() <= 1:
+        return tensor
+
+    tensor = tensor.clone()
+    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+    return tensor
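A sketch of using reduce_sum to turn a per-rank count into a global one, e.g. the number of positive samples used to normalize a loss (the count is illustrative):

import torch
from utils.distributed import reduce_sum, get_world_size

num_pos = torch.tensor(42.0, device='cuda')   # per-rank count, illustrative
total_pos = reduce_sum(num_pos)               # summed over all ranks
avg_pos = total_pos / get_world_size()        # average per rank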