Spaces: Runtime error
Dense_Captioning_-_GRiT/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead_shuffle.py
import math | |
import torch | |
import torch.distributed as dist | |
from detectron2.modeling.roi_heads import FastRCNNConvFCHead, MaskRCNNConvUpsampleHead | |
from detectron2.utils import comm | |
from fvcore.nn.distributed import differentiable_all_gather | |
def concat_all_gather(input):
    """Gather a batch whose size may differ per rank, keeping autograd intact.

    The collective requires identically shaped tensors on every rank, so the
    local tensor is zero-padded along dim 0 to the largest row count reported
    by any rank, gathered with ``differentiable_all_gather``, and each gathered
    tensor is then trimmed back to its owner's true row count.

    Args:
        input: tensor whose first dimension is this rank's batch size.

    Returns:
        ``(inputs, size_list)`` -- ``inputs`` holds one tensor per gathered rank
        with the padding removed; ``size_list`` holds the row counts used to
        trim them, in the same order.
    """
    local_rows = input.shape[0]
    # Exchange the plain-int row counts first so every rank knows the pad size.
    rows_per_rank = comm.all_gather(local_rows)

    padded = input.new_zeros((max(rows_per_rank),) + input.shape[1:])
    padded[:local_rows] = input

    gathered = differentiable_all_gather(padded)
    # Drop the zero padding that was added to equalize shapes.
    trimmed = [tensor[:rows] for rows, tensor in zip(rows_per_rank, gathered)]
    return trimmed, rows_per_rank
def batch_shuffle(x):
    """Randomly redistribute the rows of ``x`` across all processes.

    The batches from every rank are gathered and concatenated. Every rank draws
    a random permutation, then all ranks adopt rank 0's copy via broadcast so
    they agree. The permuted rows are split into contiguous chunks of
    ``ceil(total / world_size)`` rows, one chunk per rank in rank order;
    trailing ranks may receive fewer rows, or none.

    Args:
        x: tensor whose first dimension is this rank's batch size.

    Returns:
        ``(x_shuffled, idx_unshuffle)`` -- the rows assigned to this rank, and,
        for each of this rank's original rows, its position in the shuffled
        global order. ``batch_unshuffle`` uses the latter to restore order after
        the per-rank outputs are gathered again.
    """
    batch_size_this = x.shape[0]
    all_xs, batch_size_all = concat_all_gather(x)
    all_xs_concat = torch.cat(all_xs, dim=0)
    total_bs = sum(batch_size_all)

    # Use the detectron2 comm helpers (as concat_all_gather does): they report
    # rank 0 / world size 1 when no process group is initialized, whereas the
    # raw torch.distributed calls would raise in single-process runs.
    rank = comm.get_rank()
    world_size = comm.get_world_size()
    assert batch_size_all[rank] == batch_size_this

    # Span of this rank's own rows inside the concatenated global batch.
    start = sum(batch_size_all[:rank])
    end = start + batch_size_all[rank]

    # random shuffle index
    idx_shuffle = torch.randperm(total_bs, device=x.device)
    if world_size > 1:
        # Each rank drew its own permutation; adopt rank 0's so all ranks agree.
        dist.broadcast(idx_shuffle, src=0)

    # Inverse permutation: idx_unshuffle[i] is where global row i landed.
    idx_unshuffle = torch.argsort(idx_shuffle)

    # This rank's contiguous chunk of the shuffled order. Slicing past the end
    # yields an empty index tensor, so ranks beyond the last chunk get no rows.
    chunk = math.ceil(total_bs / world_size)
    idx_this = idx_shuffle[rank * chunk : (rank + 1) * chunk]
    return all_xs_concat[idx_this], idx_unshuffle[start:end]
def batch_unshuffle(x, idx_unshuffle):
    """Restore original row order after a ``batch_shuffle`` round trip.

    The per-rank tensors are gathered with ``concat_all_gather`` and
    concatenated in the order it returns them. The rows at ``idx_unshuffle``
    are then selected.

    Args:
        x: this rank's output rows (computed on its shuffled chunk).
        idx_unshuffle: the index tensor that ``batch_shuffle`` returned for this rank.

    Returns:
        The gathered rows selected by ``idx_unshuffle``.
    """
    pieces = concat_all_gather(x)[0]
    return torch.cat(pieces, dim=0)[idx_unshuffle]
def wrap_shuffle(module_type, method):
    """Create a subclass of ``module_type`` whose ``method`` shuffles its input.

    In training mode, the input is redistributed across processes with
    ``batch_shuffle`` before the original ``method`` runs. The output is then
    put back into this rank's original order with ``batch_unshuffle``. In eval
    mode, the original method is called unchanged.

    Args:
        module_type: class to subclass. Its instances must expose a ``training``
            flag (e.g. an ``nn.Module``).
        method: name of a method with signature ``(self, x)`` to wrap.

    Returns:
        A new class named ``module_type.__name__ + "WithShuffle"``.
    """
    import functools

    # Resolve once at wrap time; a misspelled method name fails here, not mid-training.
    original = getattr(module_type, method)

    @functools.wraps(original)
    def new_method(self, x):
        # Read the mode once so an unshuffle only ever follows a matching shuffle,
        # even if the wrapped method switches the module's training flag.
        training = self.training
        if training:
            x, idx = batch_shuffle(x)
        x = original(self, x)
        if training:
            x = batch_unshuffle(x, idx)
        return x

    return type(module_type.__name__ + "WithShuffle", (module_type,), {method: new_method})
# Reuse every config object from the BN-head Mask R-CNN config; the only
# changes below swap the mask/box head classes for shuffling variants.
# NOTE(review): presumably loaded as a detectron2 LazyConfig, where the module
# namespace is the config -- avoid introducing extra module-level names here.
from .mask_rcnn_BNhead import (
    dataloader,
    lr_multiplier,
    model,
    optimizer,
    train,
)

# `_target_` appears to select the class instantiated for each head (TODO confirm).
# The mask head shuffles only around its `layers` method; the box head shuffles
# around its whole `forward`.
model.roi_heads.mask_head._target_ = wrap_shuffle(MaskRCNNConvUpsampleHead, "layers")
model.roi_heads.box_head._target_ = wrap_shuffle(FastRCNNConvFCHead, "forward")