import torch
import torch.nn as nn


def _fuse_conv_bn(conv, bn):
    """Fuse conv and bn into one module.

    Args:
        conv (nn.Module): Conv to be fused.
        bn (nn.Module): BN to be fused.

    Returns:
        nn.Module: Fused module.
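
    Example:
        A minimal usage sketch. It assumes the default ``nn.Conv2d`` and
        ``nn.BatchNorm2d`` settings (affine bn with running stats and a
        matching channel count); note that ``conv`` is modified in place:

        >>> conv = nn.Conv2d(3, 8, 3)
        >>> bn = nn.BatchNorm2d(8)
        >>> fused = _fuse_conv_bn(conv, bn)
        >>> fused is conv
        True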
|
""" |
|
conv_w = conv.weight |
|
conv_b = conv.bias if conv.bias is not None else torch.zeros_like( |
|
bn.running_mean) |
|
|
|
factor = bn.weight / torch.sqrt(bn.running_var + bn.eps) |
|
conv.weight = nn.Parameter(conv_w * |
|
factor.reshape([conv.out_channels, 1, 1, 1])) |
|
conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias) |
|
return conv |
|
|
|
|
|
def fuse_conv_bn(module): |
|
"""Recursively fuse conv and bn in a module. |
|
|
|
During inference, the functionary of batch norm layers is turned off |
|
but only the mean and var alone channels are used, which exposes the |
|
chance to fuse it with the preceding conv layers to save computations and |
|
simplify network structures. |
|
|
|
Args: |
|
module (nn.Module): Module to be fused. |
|
|
|
Returns: |
|
nn.Module: Fused module. |
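
    Example:
        A minimal sketch on a toy model. Fusion happens in place, and the
        fused conv reproduces bn's eval-mode behaviour, so only fuse models
        meant for inference:

        >>> model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        >>> model = fuse_conv_bn(model.eval())
        >>> isinstance(model[1], nn.Identity)
        True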
|
""" |
|
last_conv = None |
|
last_conv_name = None |
|
|
|
for name, child in module.named_children(): |
|
if isinstance(child, |
|
(nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)): |
|
if last_conv is None: |
|
continue |
|
fused_conv = _fuse_conv_bn(last_conv, child) |
|
module._modules[last_conv_name] = fused_conv |
|
|
|
module._modules[name] = nn.Identity() |
|
last_conv = None |
|
elif isinstance(child, nn.Conv2d): |
|
last_conv = child |
|
last_conv_name = name |
|
else: |
|
fuse_conv_bn(child) |
|
return module |
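

# A minimal sanity check, assuming a toy conv + bn + relu stack: it verifies
# that fusing preserves the eval-mode output up to floating point tolerance.
if __name__ == '__main__':
    torch.manual_seed(0)
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    # One training-mode forward so bn accumulates non-trivial running stats.
    model.train()
    model(torch.randn(2, 3, 32, 32))
    model.eval()

    x = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        out_ref = model(x)
        out_fused = fuse_conv_bn(model)(x)
    assert torch.allclose(out_ref, out_fused, atol=1e-5)
    print('Fused model matches the original within 1e-5.')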
|
|