# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
def _fuse_conv_bn(conv: nn.Module, bn: nn.Module) -> nn.Module:
    """Fuse conv and bn into one module.

    Args:
        conv (nn.Module): Conv to be fused.
        bn (nn.Module): BN to be fused.

    Returns:
        nn.Module: Fused module.
    """
    conv_w = conv.weight
    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
        bn.running_mean)

    # In eval mode BN applies y = (x - mean) / sqrt(var + eps) * gamma + beta
    # per channel. Folding it into the conv gives W' = W * factor and
    # b' = (b - mean) * factor + beta, with factor = gamma / sqrt(var + eps).
    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    conv.weight = nn.Parameter(conv_w *
                               factor.reshape([conv.out_channels, 1, 1, 1]))
    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
    return conv
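

# Illustrative sanity check, not part of the original file: a minimal
# sketch verifying that folding BN into a conv leaves the output unchanged
# in eval mode. The shapes and tolerance below are arbitrary assumptions.
def _check_single_fusion() -> bool:
    import copy
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    bn = nn.BatchNorm2d(8)
    # A few training-mode forward passes so the BN running statistics
    # hold non-trivial values before fusing.
    with torch.no_grad():
        for _ in range(4):
            bn(conv(torch.randn(2, 3, 16, 16)))
    conv.eval()
    bn.eval()
    x = torch.randn(2, 3, 16, 16)
    with torch.no_grad():
        ref = bn(conv(x))
        # Fuse a copy so the reference conv is left untouched.
        fused = _fuse_conv_bn(copy.deepcopy(conv), bn)
        out = fused(x)
    return bool(torch.allclose(ref, out, atol=1e-5))
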
def fuse_conv_bn(module: nn.Module) -> nn.Module:
    """Recursively fuse conv and bn in a module.

    During inference, batch norm layers no longer update their statistics;
    they only apply a fixed per-channel affine transform based on the
    running mean and variance. This makes it possible to fold them into
    the preceding conv layers, saving computation and simplifying the
    network structure.

    Args:
        module (nn.Module): Module to be fused.

    Returns:
        nn.Module: Fused module.
    """
    last_conv = None
    last_conv_name = None

    for name, child in module.named_children():
        if isinstance(child,
                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
            if last_conv is None:  # only fuse BN that is after Conv
                continue
            fused_conv = _fuse_conv_bn(last_conv, child)
            module._modules[last_conv_name] = fused_conv
            # To reduce changes, set BN as Identity instead of deleting it.
            module._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            fuse_conv_bn(child)
    return module
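

# Hedged usage sketch, not from the original source: fuse a small
# hypothetical model and verify the outputs are unchanged in eval mode.
# The toy model is chosen only to exercise both the direct conv->bn
# fusion and the recursive descent into child modules.
if __name__ == '__main__':
    import copy

    model = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True),
        nn.Sequential(  # nested child to exercise the recursive branch
            nn.Conv2d(16, 32, 3, padding=1),
            nn.BatchNorm2d(32),
        ),
    )
    # Populate the BN running statistics with a few training-mode passes.
    with torch.no_grad():
        for _ in range(4):
            model(torch.randn(4, 3, 32, 32))
    model.eval()  # fusion assumes inference mode: BN uses running stats

    x = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        ref = model(x)
        fused = fuse_conv_bn(copy.deepcopy(model))
        out = fused(x)

    print('single-pair check passed:', _check_single_fusion())
    print('max abs diff after fusion:', (ref - out).abs().max().item())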