# Wheat_HEAD_Detection_Counting_ComputerVision_Model/detectron2/projects/Rethinking-BatchNorm/configs/mask_rcnn_BNhead_batch_stats.py
from torch.nn import BatchNorm2d
from torch.nn import functional as F
class BatchNormBatchStat(BatchNorm2d):
    """
    BatchNorm2d that normalizes with the current batch's statistics at inference.

    In training mode it defers entirely to ``BatchNorm2d.forward``. In eval
    mode it ignores the stored running mean/variance. Instead it normalizes
    each call's ``input`` with that input's own mean and variance, using the
    module's affine ``weight``/``bias`` and ``eps``.
    """

    def forward(self, input):
        if not self.training:
            # No running buffers are passed, and training=True is set, so
            # F.batch_norm derives mean/var from ``input`` itself. The
            # momentum value only affects running-stat updates, and there
            # are no buffers here to update.
            return F.batch_norm(
                input,
                running_mean=None,
                running_var=None,
                weight=self.weight,
                bias=self.bias,
                training=True,
                momentum=1.0,
                eps=self.eps,
            )
        return super().forward(input)
# BatchNormBatchStat behaves exactly like BatchNorm2d during training, so no
# retraining is needed: a model trained with the base config can simply be
# loaded with this config for inference.
from .mask_rcnn_BNhead import model, dataloader, lr_multiplier, optimizer, train

# Use batch-statistics normalization in the conv layers of both ROI heads.
# NOTE(review): dataloader, lr_multiplier, optimizer and train are unused here;
# presumably they are imported so this config module exposes them unchanged.
# Confirm before removing any.
model.roi_heads.box_head.conv_norm = BatchNormBatchStat
model.roi_heads.mask_head.conv_norm = BatchNormBatchStat