# Source: Hugging Face file page, uploaded by Arulkumar03
# (commit fcd8cdd, "Upload 1389 files"; 700 bytes; scan result: "No virus").
from torch.nn import BatchNorm2d
from torch.nn import functional as F
class BatchNormBatchStat(BatchNorm2d):
    """
    BatchNorm2d variant that normalizes with the current batch's statistics,
    even in eval mode.

    In training mode it defers entirely to ``BatchNorm2d.forward``. In eval
    mode it calls ``F.batch_norm`` with no running statistics and
    ``training=True``. The mean/variance therefore come from ``input`` itself,
    and the module's running buffers are neither read nor updated.
    """

    def forward(self, input):
        if not self.training:
            # With running_mean/running_var set to None, training=True makes
            # F.batch_norm use per-batch statistics. There are no buffers to
            # update, so the momentum value has no effect here.
            return F.batch_norm(
                input,
                running_mean=None,
                running_var=None,
                weight=self.weight,
                bias=self.bias,
                training=True,
                momentum=1.0,
                eps=self.eps,
            )
        return super().forward(input)
# Inference-only variant of the `mask_rcnn_BNhead` config.
# Train with the base config. This file reuses all of its settings and only
# swaps the `conv_norm` class of the box and mask ROI heads, so that BN uses
# batch statistics in eval mode. Training-time behavior is identical
# (BatchNormBatchStat defers to BatchNorm2d while training), so the model
# trained with the base config can be loaded with this config for inference.
from .mask_rcnn_BNhead import (
    dataloader,
    lr_multiplier,
    model,
    optimizer,
    train,
)

# Override the norm class for each ROI head (box head first, then mask head).
model.roi_heads.box_head.conv_norm = BatchNormBatchStat
model.roi_heads.mask_head.conv_norm = BatchNormBatchStat