"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""

import gc
import math
import os
import pickle
import shutil
from typing import Union

import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def is_distributed():
    """Return True when more than one process takes part in the job."""
    return get_world_size() > 1


def get_world_size():
    """Return the number of processes, or 1 when torch.distributed is unusable.

    "Unusable" means the distributed package is not available or the default
    process group has not been initialized.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Return the global rank, or 0 when torch.distributed is unusable."""
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_local_rank():
    """Return the value of the ``LOCAL_RANK`` environment variable as an int.

    Returns 0 when torch.distributed is not available / not initialized, or
    when the variable is not set.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return int(os.getenv('LOCAL_RANK', 0))


def is_master():
    """Return True on the process whose global rank is 0."""
    return get_rank() == 0


def is_local_master():
    """Return True on the process whose local rank is 0."""
    return get_local_rank() == 0


def get_local_proc_group(group_size=8):
    """Return the process group of ``group_size`` consecutive ranks holding this rank.

    Ranks are partitioned into contiguous blocks
    ``[0, group_size)``, ``[group_size, 2 * group_size)``, ... . The groups
    for a given ``group_size`` are created once and cached on the function
    object, so repeated calls do not create new groups.

    Args:
        group_size (int): number of ranks per group.

    Returns:
        The process group of the calling rank, or ``None`` when the world
        size does not exceed ``group_size`` or ``group_size`` is 1.

    Raises:
        AssertionError: if the world size is not a multiple of ``group_size``.
    """
    world_size = get_world_size()
    if world_size <= group_size or group_size == 1:
        return None
    assert world_size % group_size == 0, \
        f'world size ({world_size}) should be evenly divided by group size ({group_size}).'

    process_groups = getattr(get_local_proc_group, 'process_groups', dict())
    if group_size not in process_groups:
        num_groups = world_size // group_size
        groups = [list(range(i * group_size, (i + 1) * group_size))
                  for i in range(num_groups)]
        # Every rank creates every group (not only its own) before picking one.
        process_groups.update(
            {group_size: [torch.distributed.new_group(group) for group in groups]})
        get_local_proc_group.process_groups = process_groups

    group_idx = get_rank() // group_size
    return get_local_proc_group.process_groups.get(group_size)[group_idx]


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    Args:
        data: any picklable object

    Returns:
        list[data]: list of data gathered from each rank
    """
    to_device = torch.device("cuda")

    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(to_device)

    # obtain Tensor size of each rank; keep the local size as a plain int so
    # the comparisons and size arithmetic below do not go through a tensor
    local_size = tensor.numel()
    local_size_tensor = torch.LongTensor([local_size]).to(to_device)
    size_list = [torch.LongTensor([0]).to(to_device) for _ in range(world_size)]
    dist.all_gather(size_list, local_size_tensor)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = [torch.ByteTensor(size=(max_size,)).to(to_device)
                   for _ in size_list]
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to(to_device)
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, gathered in zip(size_list, tensor_list):
        # drop the padding before unpickling
        buffer = gathered.cpu().numpy().tobytes()[:size]
        # NOTE: pickle.loads executes arbitrary code; only safe because the
        # bytes come from peer ranks of the same job.
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


def broadcast(data, **kwargs):
    """Broadcast a picklable object with ``dist.broadcast_object_list``.

    Args:
        data: the object to broadcast.
        **kwargs: forwarded unchanged to ``dist.broadcast_object_list``
            (e.g. ``src``).

    Returns:
        ``data`` itself in a single-process run, otherwise the first element
        of the one-element list after the broadcast.
    """
    if get_world_size() == 1:
        return data
    data = [data]
    dist.broadcast_object_list(data, **kwargs)
    return data[0]


def all_gather_cpu(result_part, tmpdir=None, collect_by_master=True):
    """Gather per-rank results by exchanging files in a directory.

    Each rank dumps ``result_part`` to ``<tmpdir>/part_<rank>.pkl``; the
    collecting ranks then load the files of all ranks in rank order. Every
    collecting rank must therefore be able to read the files written by the
    other ranks.

    Args:
        result_part: this rank's result, written with ``mmcv.dump``.
        tmpdir (str | None): exchange directory; ``'./tmp'`` when None.
            Rank 0 creates it and removes it (with its contents) at the end.
        collect_by_master (bool): when True only rank 0 loads the parts and
            other ranks return None; when False every rank loads them.

    Returns:
        list | None: the loaded parts ordered by rank, or None on non-zero
        ranks when ``collect_by_master`` is True.
    """
    rank, world_size = get_dist_info()
    if tmpdir is None:
        tmpdir = './tmp'
    if rank == 0:
        mmcv.mkdir_or_exist(tmpdir)
    synchronize()

    # dump the part result to the dir
    mmcv.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl'))
    synchronize()

    # collect all parts
    if collect_by_master and rank != 0:
        return None

    # load results of all parts from tmp dir
    results = []
    for i in range(world_size):
        part_file = os.path.join(tmpdir, f'part_{i}.pkl')
        results.append(mmcv.load(part_file))
    if not collect_by_master:
        # all ranks read the files, so wait until everyone has finished
        # before rank 0 deletes the directory
        synchronize()
    # remove tmp dir
    if rank == 0:
        shutil.rmtree(tmpdir)
    return results


def all_gather_tensor(tensor, group_size=None, group=None):
    """All-gather a tensor that has the same shape on every participant.

    Args:
        tensor (torch.Tensor): local tensor.
        group_size (int | None): number of output slots to allocate; the
            world size when None.
        group: process group forwarded to ``dist.all_gather``.

    Returns:
        list[torch.Tensor]: ``[tensor]`` when ``group_size`` is 1 (no
        communication), otherwise ``group_size`` gathered tensors.
    """
    if group_size is None:
        group_size = get_world_size()
    if group_size == 1:
        output = [tensor]
    else:
        output = [torch.zeros_like(tensor) for _ in range(group_size)]
        dist.all_gather(output, tensor, group=group)
    return output


def gather_difflen_tensor(feat, num_samples_list, concat=True, group=None, group_size=None):
    """All-gather tensors whose first dimension differs between ranks.

    The local tensor is zero-padded along dim 0 to ``max(num_samples_list)``,
    gathered with ``all_gather_tensor``, and each gathered tensor is cut back
    to the length listed for its rank.

    Args:
        feat (torch.Tensor): local tensor.
        num_samples_list (list): first-dimension length of ``feat`` on every
            rank, indexed by rank.
        concat (bool): concatenate the gathered tensors along dim 0 when
            True, return them as a list otherwise.
        group: process group forwarded to ``all_gather_tensor``.
        group_size (int | None): forwarded to ``all_gather_tensor``.

    Returns:
        torch.Tensor | list[torch.Tensor]: the gathered result. In a
        single-process run this is ``feat`` (or ``[feat]`` when ``concat`` is
        False) without any communication.
    """
    world_size = get_world_size()
    if world_size == 1:
        if not concat:
            return [feat]
        return feat
    num_samples, *feat_dim = feat.size()
    # padding to max number of samples
    feat_padding = feat.new_zeros((max(num_samples_list), *feat_dim))
    feat_padding[:num_samples] = feat
    # gather
    feat_gather = all_gather_tensor(feat_padding, group=group, group_size=group_size)
    for r, num in enumerate(num_samples_list):
        feat_gather[r] = feat_gather[r][:num]
    if concat:
        feat_gather = torch.cat(feat_gather)
    return feat_gather


class GatherLayer(torch.autograd.Function):
    '''Gather tensors from all process, supporting backward propagation.

    The first dimension of the input may differ between ranks.
    '''

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        num_samples = torch.tensor(input.size(0), dtype=torch.long, device=input.device)
        # Convert the gathered 0-dim tensors to plain ints once, so that the
        # max / sum / slicing done with them later work on Python numbers.
        ctx.num_samples_list = [int(n) for n in all_gather_tensor(num_samples)]
        output = gather_difflen_tensor(input, ctx.num_samples_list, concat=False)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):  # tuple(output)'s grad
        input, = ctx.saved_tensors
        num_samples_list = ctx.num_samples_list
        rank = get_rank()
        # rows [start, end) of the concatenated gradient belong to this rank
        start, end = sum(num_samples_list[:rank]), sum(num_samples_list[:rank + 1])
        grads = torch.cat(grads)
        if is_distributed():
            dist.all_reduce(grads)
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[start:end]
        # forward takes a single input, so a single gradient is returned
        return grad_out


class GatherLayerWithGroup(torch.autograd.Function):
    '''Gather same-shaped tensors inside a process group, supporting
    backward propagation.
    '''

    @staticmethod
    def forward(ctx, input, group, group_size):
        ctx.save_for_backward(input)
        ctx.group = group
        ctx.group_size = group_size
        output = all_gather_tensor(input, group=group, group_size=group_size)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):  # tuple(output)'s grad
        input, = ctx.saved_tensors
        grads = torch.stack(grads)
        if is_distributed():
            # Reduce inside the same group that was used for the gather;
            # reducing over the default group would add gradients coming
            # from ranks that never received this input. With group=None
            # this is the default group, as before.
            dist.all_reduce(grads, group=ctx.group)
        grad_out = torch.zeros_like(input)
        # NOTE(review): the index assumes groups are contiguous blocks of
        # ranks, as built by get_local_proc_group — confirm for other groups.
        grad_out[:] = grads[get_rank() % ctx.group_size]
        # one gradient per forward input: input, group, group_size
        return grad_out, None, None


def gather_layer_with_group(data, group=None, group_size=None):
    """Differentiable all-gather of ``data`` inside ``group``.

    Args:
        data (torch.Tensor): local tensor.
        group: process group to gather in; None selects the default group.
        group_size (int | None): number of ranks in ``group``; the world
            size when None.

    Returns:
        tuple[torch.Tensor]: the gathered tensors.
    """
    if group_size is None:
        group_size = get_world_size()
    # GatherLayer.forward only accepts the input tensor; the variant that
    # takes (input, group, group_size) is GatherLayerWithGroup.
    output = GatherLayerWithGroup.apply(data, group, group_size)
    return output


@torch.no_grad()
def clip_grad_norm_(
    self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
) -> torch.Tensor:
    """Clip the gradients of ``self.params_with_grad`` to a total norm of ``max_norm``.

    The local norm is combined across ``self.process_group`` (MAX for the
    infinity norm, otherwise a sum of ``norm ** norm_type``), and the
    gradients are scaled in place when the total norm exceeds ``max_norm``.

    NOTE(review): written as a method taking ``self``; presumably meant to be
    attached to an FSDP root module — confirm against the caller.

    Args:
        max_norm: maximum allowed total gradient norm.
        norm_type: order of the norm; ``math.inf`` selects the max norm.

    Returns:
        torch.Tensor: the total gradient norm before clipping.

    Raises:
        AssertionError: if ``self`` is not the root instance.
    """
    # Imported here because the module-level import is commented out in this
    # file, which left both names undefined when this function ran.
    # NOTE(review): these are private torch names and presumably exist only
    # in some torch releases — confirm the supported version.
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        TrainingState_,
        _calc_grad_norm,
    )

    self._lazy_init()
    self._wait_for_previous_optim_step()
    assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
    self._assert_state(TrainingState_.IDLE)

    max_norm = float(max_norm)
    norm_type = float(norm_type)
    # Computes the max norm for this shard's gradients and sync's across workers
    local_norm = _calc_grad_norm(self.params_with_grad, norm_type).cuda()  # type: ignore[arg-type]
    if norm_type == math.inf:
        total_norm = local_norm
        dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
    else:
        total_norm = local_norm ** norm_type
        dist.all_reduce(total_norm, group=self.process_group)
        total_norm = total_norm ** (1.0 / norm_type)

    clip_coef = torch.tensor(
        max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
    if clip_coef < 1:
        # multiply by clip_coef, aka, (max_norm/total_norm).
        for p in self.params_with_grad:
            assert p.grad is not None
            p.grad.detach().mul_(clip_coef.to(p.grad.device))
    return total_norm


def flush():
    """Run the Python garbage collector, then ``torch.cuda.empty_cache()``."""
    gc.collect()
    torch.cuda.empty_cache()