OFA-OCR / models /ofa /frozen_bn.py
JustinLin610's picture
first commit
history blame
3.48 kB
# Modified from detectron2: https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py#L13
import torch
from torch import nn
from torch.nn import functional as F
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.

    Args:
        num_features: number of channels of the input; every buffer has
            this many elements.
        eps: value added to ``running_var`` before taking the reciprocal
            square root.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        # nn.Module must be initialised before any buffer can be registered.
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        # Buffers (not Parameters): they are saved in the state dict and follow
        # .to()/.cuda(), but are never returned by .parameters().
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        # `- eps` makes `running_var + eps` equal to 1, so a freshly built
        # layer is an identity transform.
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the frozen normalisation to ``x``.

        The buffers are reshaped to ``(1, C, 1, 1)`` in the gradient branch,
        so ``x`` is expected to carry its channels in dimension 1.
        """
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            # Fold the four buffers into a single per-channel scale and shift.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)

        # When gradients are not needed, F.batch_norm is a single fused op
        # and provides more optimization opportunities.
        return F.batch_norm(
            x,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            training=False,
            eps=self.eps,
        )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Load buffers, tolerating checkpoints written by a regular BatchNorm2d.

        ``state_dict`` is modified in place before being handed to the default
        loader: the ``num_batches_tracked`` entry (which this module has no
        buffer for) is removed, and for old or unversioned checkpoints missing
        running statistics are filled with identity values.
        """
        # Drop the counter so it is not reported as an unexpected key.
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # No running_mean/var in early versions.
            # Supplying identity statistics silences the missing-key warnings.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        # Delegate the actual copy into the buffers to nn.Module.
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return f"FrozenBatchNorm2d(num_features={self.num_features}, eps={self.eps})"