# Vishakaraj's picture
# Upload folder using huggingface_hub
# c709b60
# raw
# history blame contribute delete
# No virus
# 2.32 kB
import math
import torch
import torch.distributed as dist
from detectron2.modeling.roi_heads import FastRCNNConvFCHead, MaskRCNNConvUpsampleHead
from detectron2.utils import comm
from fvcore.nn.distributed import differentiable_all_gather
def concat_all_gather(input):
    """Gather a tensor whose first dimension may differ between ranks.

    The local tensor is zero-padded along dim 0 up to the largest row count
    reported by any rank, so that a fixed-shape all-gather can be used; the
    padding is then sliced off each gathered piece. The gather goes through
    fvcore's ``differentiable_all_gather`` (presumably so autograd can flow
    through it, as its name suggests).

    Args:
        input: local tensor; dim 0 is the per-rank batch dimension.

    Returns:
        tuple ``(inputs, size_list)``: ``inputs`` holds one unpadded tensor per
        rank, and ``size_list`` holds the per-rank row counts, both in the
        order returned by ``comm.all_gather``.
    """
    local_rows = input.shape[0]
    rows_per_rank = comm.all_gather(local_rows)

    # Pad up to the largest batch so every rank contributes the same shape.
    padded = input.new_zeros((max(rows_per_rank),) + tuple(input.shape[1:]))
    padded[:local_rows] = input

    gathered = differentiable_all_gather(padded)
    # Strip each rank's padding using that rank's true row count.
    trimmed = [chunk[:rows] for rows, chunk in zip(rows_per_rank, gathered)]
    return trimmed, rows_per_rank
def batch_shuffle(x):
    """Randomly redistribute the rows of ``x`` across all ranks.

    Every rank's rows are gathered, permuted with a single permutation taken
    from rank 0, and re-split into contiguous chunks of
    ``ceil(total / world_size)`` rows (trailing ranks may receive fewer rows
    or none). The caller runs its module on the returned chunk and hands the
    result, together with the returned index, to ``batch_unshuffle``.

    Args:
        x: this rank's tensor; dim 0 is the per-rank batch dimension.

    Returns:
        tuple ``(x_shuffled, idx_unshuffle)``:
        ``x_shuffled`` is this rank's chunk of the shuffled global batch.
        ``idx_unshuffle`` gives, for each of this rank's original rows (in
        original order), its position in the rank-ordered concatenation of
        all shuffled chunks.
    """
    batch_size_this = x.shape[0]
    all_xs, batch_size_all = concat_all_gather(x)
    all_xs_concat = torch.cat(all_xs, dim=0)
    total_bs = sum(batch_size_all)

    # comm.get_rank / comm.get_world_size fall back to 0 / 1 when
    # torch.distributed is not initialized, so single-process training does
    # not crash here (dist.get_rank() would raise in that case).
    rank = comm.get_rank()
    world_size = comm.get_world_size()
    assert batch_size_all[rank] == batch_size_this
    # Slice of the gathered global batch that originally came from this rank.
    start = sum(batch_size_all[:rank])
    end = start + batch_size_all[rank]

    # Each rank draws a permutation, then all ranks adopt rank 0's so that
    # they agree on the shuffle. Nothing to broadcast with a single process.
    idx_shuffle = torch.randperm(total_bs, device=x.device)
    if world_size > 1:
        dist.broadcast(idx_shuffle, src=0)
    # Inverse permutation: idx_unshuffle[i] is where global row i landed.
    idx_unshuffle = torch.argsort(idx_shuffle)

    # This rank's contiguous chunk of the shuffled order; ranks beyond the
    # last chunk get an empty index (and therefore an empty batch).
    splits = torch.split(idx_shuffle, math.ceil(total_bs / world_size))
    if rank < len(splits):
        idx_this = splits[rank]
    else:
        idx_this = idx_shuffle.new_zeros([0])
    return all_xs_concat[idx_this], idx_unshuffle[start:end]
def batch_unshuffle(x, idx_unshuffle):
    """Restore this rank's original rows after processing a shuffled chunk.

    Args:
        x: this rank's output for the chunk it received from ``batch_shuffle``.
        idx_unshuffle: the second value ``batch_shuffle`` returned on this
            rank; it indexes the rank-ordered concatenation of every rank's
            output.

    Returns:
        the rows of the gathered global output selected by ``idx_unshuffle``.
    """
    gathered_chunks = concat_all_gather(x)[0]
    global_output = torch.cat(gathered_chunks, dim=0)
    return global_output[idx_unshuffle]
def wrap_shuffle(module_type, method):
    """Build a subclass of ``module_type`` that shuffles ``method``'s input.

    While ``self.training`` is truthy, the input goes through ``batch_shuffle``
    before ``module_type``'s own ``method`` is called, and the output goes
    through ``batch_unshuffle`` afterwards. Otherwise the original method runs
    unchanged.

    Args:
        module_type: the class to subclass.
        method: name of a method on ``module_type`` taking ``(self, x)``.

    Returns:
        a new class named ``<module_type.__name__>WithShuffle`` whose only
        override is ``method``.
    """

    def shuffled_call(self, x):
        idx_unshuffle = None
        if self.training:
            x, idx_unshuffle = batch_shuffle(x)
        base_impl = getattr(module_type, method)
        y = base_impl(self, x)
        # ``self.training`` is deliberately re-read after the call.
        return batch_unshuffle(y, idx_unshuffle) if self.training else y

    class_name = f"{module_type.__name__}WithShuffle"
    return type(class_name, (module_type,), {method: shuffled_call})
# Start from the BN-head Mask R-CNN config. dataloader, lr_multiplier,
# optimizer and train are not used below; presumably they are imported so the
# config loader sees them in this module's namespace (re-export).
from .mask_rcnn_BNhead import model, dataloader, lr_multiplier, optimizer, train
# Swap the head classes (``_target_``, presumably the class the lazy config
# instantiates) for subclasses built by wrap_shuffle, so their inputs are
# shuffled across ranks while training.
model.roi_heads.box_head._target_ = wrap_shuffle(FastRCNNConvFCHead, "forward")
# For the mask head the "layers" method is wrapped instead of "forward";
# NOTE(review): presumably because the mask head's forward takes extra
# arguments that the single-argument wrapper cannot pass — confirm against
# detectron2's MaskRCNNConvUpsampleHead.
model.roi_heads.mask_head._target_ = wrap_shuffle(MaskRCNNConvUpsampleHead, "layers")