rlawjdghek's picture
prep (#1)
61c2d32 verified
raw
history blame
1.16 kB
from typing import List
import torch
from torch import Tensor, nn
from detectron2.modeling.meta_arch.retinanet import RetinaNetHead
def apply_sequential(inputs: List[Tensor], modules: List[nn.Module]) -> List[Tensor]:
    """Apply ``modules`` in order to every tensor in ``inputs``.

    ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm`` layers are applied once to all
    inputs jointly: each input is flattened over its last two dimensions, the
    results are concatenated along that flattened axis, the layer is called
    on the single concatenated tensor, and the output is split back into the
    original shapes. Every other module is called on each input separately.

    Args:
        inputs: Tensors indexed as 4-D (dimensions 2 and 3 are multiplied to
            get the flattened size). They must agree on dimensions 0 and 1,
            otherwise the concatenation fails.
        modules: Modules to apply, in order.

    Returns:
        One output tensor per input, in input order. ``inputs`` itself is
        returned when no module rebuilds it (e.g. ``modules`` is empty).
    """
    for mod in modules:
        if isinstance(mod, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            # For a BN layer, normalize all inputs together.
            shapes = [t.shape for t in inputs]
            if not shapes:
                # torch.cat rejects an empty list; with nothing to normalize,
                # behave like the per-input branch and keep the empty result.
                continue
            spatial_sizes = [s[2] * s[3] for s in shapes]
            flat = [t.flatten(2) for t in inputs]
            # The layer expects a 4-D tensor, so add a trailing dim of size 1.
            joint = torch.cat(flat, dim=2).unsqueeze(3)
            chunks = mod(joint).split(spatial_sizes, dim=2)
            inputs = [chunk.view(s) for s, chunk in zip(shapes, chunks)]
        else:
            inputs = [mod(t) for t in inputs]
    return inputs
class RetinaNetHead_SharedTrainingBN(RetinaNetHead):
    """RetinaNet head that runs its layers through ``apply_sequential``.

    Both towers are applied to the whole list of feature maps at once, so any
    ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm`` layer among them is called on
    all feature maps jointly rather than once per map.
    """

    def forward(self, features: List[Tensor]):
        """Compute classification and box-regression outputs for ``features``.

        Args:
            features: One tensor per feature level.

        Returns:
            A ``(logits, bbox_reg)`` pair; each element is whatever
            ``apply_sequential`` returns for the corresponding tower, i.e.
            one output per entry of ``features``.
        """
        # cls_subnet / cls_score / bbox_subnet / bbox_pred are not defined in
        # this class; presumably they come from RetinaNetHead — confirm.
        cls_layers = [*self.cls_subnet, self.cls_score]
        box_layers = [*self.bbox_subnet, self.bbox_pred]
        return (
            apply_sequential(features, cls_layers),
            apply_sequential(features, box_layers),
        )
# Start from the sibling config ``retinanet_SyncBNhead``: these five names are
# imported unchanged, which also makes them module-level names of this file.
from .retinanet_SyncBNhead import model, dataloader, lr_multiplier, optimizer, train

# Point the head's ``_target_`` at the subclass defined above; nothing else on
# ``model.head`` is modified here.
# NOTE(review): presumably ``_target_`` is the detectron2 lazy-config field
# naming the class to instantiate — confirm.
model.head._target_ = RetinaNetHead_SharedTrainingBN