Spaces:
Runtime error
Runtime error
File size: 4,528 Bytes
f239efc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import os
import torch
import torch.distributed as dist
import logging
# Module-level logger; init_distributed_mode reports its progress through it.
logger = logging.getLogger(__name__)
def setup_for_distributed(is_master):
    """Silence warnings and logging on every process except the master.

    Replaces the global ``warnings.warn`` with a wrapper. The wrapper only
    forwards to the previous ``warnings.warn`` on the master process, unless
    the caller passes ``force=True``. ``UserWarning`` is filtered to be shown
    only once. On non-master processes all logging is disabled via
    ``logging.disable()``.

    Args:
        is_master: True on the process that should keep emitting warnings
            and log records.
    """
    import warnings

    builtin_warn = warnings.warn

    def warn(message, category=None, stacklevel=1, *args, force=False, **kwargs):
        if is_master or force:
            # This wrapper adds one stack frame; bump stacklevel so the warning
            # is attributed to the code that called warnings.warn, not here.
            builtin_warn(message, category, stacklevel + 1, *args, **kwargs)

    warnings.warn = warn
    # Log warnings only once
    warnings.simplefilter("once", UserWarning)
    if not is_master:
        logging.disable()
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is built in AND a process group exists.

    ``is_initialized`` is consulted only after ``is_available`` succeeds.
    """
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Number of processes in the default group, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Rank of this process in the default group, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """True on rank 0 (which includes the non-distributed case)."""
    rank = get_rank()
    return rank == 0
def save_on_master(*args, **kwargs):
    """Forward all arguments to ``torch.save``, but only on the main process.

    Every other rank returns immediately without writing anything.
    """
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def is_port_in_use(port):
    """Return True if a TCP connection to ``localhost:port`` (IPv4) succeeds.

    Uses ``connect_ex`` so a refused connection yields an error code rather
    than an exception.
    """
    import socket
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        status = probe.connect_ex(('localhost', port))
    return status == 0
def init_distributed_mode(args):
    """Initialise ``torch.distributed`` from the launcher's environment.

    Two launch environments are recognised:

    * ``torch.distributed.launch`` / torchrun style: ``RANK``, ``WORLD_SIZE``
      and ``LOCAL_RANK`` must be set.
    * SLURM: ``SLURM_PROCID``, ``SLURM_LOCALID``, ``SLURM_NNODES`` and
      ``SLURM_TASKS_PER_NODE`` must be set.

    If neither is detected, ``args.distributed`` is set to False and the
    function returns without touching CUDA. Otherwise it:

    * fills in ``args.rank``, ``args.world_size``, ``args.gpu``,
      ``args.distributed`` and ``args.dist_backend`` (always ``'nccl'``);
    * selects CUDA device ``args.gpu``;
    * initialises the process group from ``args.dist_url`` and waits on a
      barrier;
    * calls ``setup_for_distributed`` so only rank 0 keeps warnings/logging.

    Args:
        args: Mutable namespace. It must provide ``dist_url`` whenever a
            launch environment is detected. A ``tcp`` URL may have its port
            rewritten in place.

    Raises:
        KeyError: If a required launcher variable (e.g. ``LOCAL_RANK``) is
            missing.
        ValueError: If ``SLURM_TASKS_PER_NODE`` does not start with a task
            count.
    """
    import re

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # job started by torch.distributed.launch
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # local rank on the current node / global rank
        local_rank = int(os.environ['SLURM_LOCALID'])
        global_rank = int(os.environ['SLURM_PROCID'])
        # number of processes = nodes * tasks per node.
        # SLURM_TASKS_PER_NODE looks like "8" or "16(x2)". Read the whole
        # leading integer, not just its first character, so that counts of
        # 10 or more are handled correctly.
        # NOTE(review): for uneven layouts such as "2(x3),1" only the first
        # node's count is used, as before.
        tasks_spec = os.environ["SLURM_TASKS_PER_NODE"]
        match = re.match(r"\d+", tasks_spec)
        if match is None:
            raise ValueError(
                f"Cannot parse SLURM_TASKS_PER_NODE={tasks_spec!r}")
        world_size = int(os.environ["SLURM_NNODES"]) * int(match.group())
        print(world_size)
        args.rank = global_rank
        args.gpu = local_rank
        args.world_size = world_size
    else:
        logger.info('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'

    if "tcp" in args.dist_url:  # in slurm, multiple program runs in a single node
        # Step the port by 10 until nothing on this host answers on it.
        # NOTE(review): every rank probes independently and only checks its
        # own localhost. Ranks on other nodes, or ranks probing after rank 0
        # has already bound the port, may pick a different port than rank 0
        # and fail to rendezvous. Confirm this is only relied on where that
        # cannot happen.
        dist_port = int(args.dist_url.split(":")[-1])
        while is_port_in_use(dist_port):
            dist_port += 10
        args.dist_url = ":".join(args.dist_url.split(":")[:-1] + [str(dist_port)])

    print(args.dist_url)

    logger.info('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url))
    if "SLURM_JOB_ID" in os.environ:
        logger.info(f"SLURM_JOB_ID {os.environ['SLURM_JOB_ID']}")

    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
# Copyright (c) Facebook, Inc. and its affiliates.
# copied from https://github.com/facebookresearch/vissl/blob/master/vissl/utils/distributed_gradients.py
class GatherLayer(torch.autograd.Function):
    """All-gather that keeps the autograd graph intact.

    ``torch.distributed.all_gather`` alone returns tensors with no gradient
    connection to the input. Here, forward returns one tensor per rank.
    Backward sums the incoming gradients across ranks (``all_reduce``) and
    hands each rank the slice matching its own rank index.
    """

    @staticmethod
    def forward(ctx, x):
        # One receive buffer per rank, shaped/typed like the local input.
        gathered = []
        for _ in range(dist.get_world_size()):
            gathered.append(torch.zeros_like(x))
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        # Stack per-rank gradients, sum them over all processes, keep ours.
        stacked = torch.stack(grads)
        dist.all_reduce(stacked)
        my_rank = dist.get_rank()
        return stacked[my_rank]
# copied from megavlt
def gather_tensor_along_batch_with_backward(tensor, dim=0):
    """All-gather ``tensor`` from every rank and concatenate along ``dim``.

    The gather goes through ``GatherLayer``, so the result stays connected to
    the autograd graph. With fewer than two processes the input object is
    returned unchanged.
    """
    if get_world_size() >= 2:
        pieces = GatherLayer.apply(tensor)
        return torch.cat(pieces, dim=dim)
    return tensor
@torch.no_grad()
def gather_tensor_along_batch(tensor, dim=0):
    """All-gather ``tensor`` from every rank and concatenate along ``dim``.

    Warning: the result carries no gradient (``torch.distributed.all_gather``
    is not differentiable, and the whole function runs under
    ``torch.no_grad``). With fewer than two processes the input object is
    returned unchanged.
    """
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    buffers = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(buffers, tensor)
    return torch.cat(buffers, dim=dim)
|