| | import torch |
| | from torch import nn |
| |
|
| |
|
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d layer into the Conv2d layer that precedes it.

    A new ``nn.Conv2d`` is built with the same geometry as ``conv`` (kernel
    size, stride, padding, dilation, groups, padding mode) and its weight and
    bias are set so that ``fused(x)`` equals ``bn(conv(x))`` when ``bn``
    normalizes with its running statistics.

    Args:
        conv: The ``nn.Conv2d`` whose output feeds ``bn``.
        bn: The ``nn.BatchNorm2d`` to fold in. Its ``running_mean`` and
            ``running_var`` are used, so the result matches ``bn`` only in
            the mode where those statistics are applied (eval mode).

    Returns:
        A new ``nn.Conv2d`` with ``bias=True`` and ``requires_grad`` disabled,
        located on the same device and using the same dtype as ``conv.weight``.
        Neither ``conv`` nor ``bn`` is modified.
    """
    fusedconv = (
        nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            # Dilation and padding mode are part of the conv geometry; leaving
            # them out silently changes the output for dilated convolutions.
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True,
            padding_mode=conv.padding_mode,
        )
        .requires_grad_(False)
        .to(device=conv.weight.device, dtype=conv.weight.dtype)
    )

    # The fused parameters are constants derived from conv/bn; keep the
    # in-place copies out of the autograd graph.
    with torch.no_grad():
        # BatchNorm2d(affine=False) has no learnable scale/shift: treat them
        # as 1 and 0 respectively.
        gamma = torch.ones_like(bn.running_var) if bn.weight is None else bn.weight
        beta = torch.zeros_like(bn.running_var) if bn.bias is None else bn.bias

        # Per-output-channel factor gamma / sqrt(var + eps).
        scale = gamma.div(torch.sqrt(bn.running_var + bn.eps))

        # Scaling each output channel is equivalent to left-multiplying the
        # flattened weight by diag(scale), without building the dense matrix.
        fusedconv.weight.copy_(conv.weight * scale.view(-1, 1, 1, 1))

        if conv.bias is None:
            b_conv = torch.zeros(
                conv.weight.size(0),
                device=conv.weight.device,
                dtype=conv.weight.dtype,
            )
        else:
            b_conv = conv.bias
        # bn(y) = scale * (y - mean) + beta, applied to the conv bias term.
        b_bn = beta - bn.running_mean.mul(scale)
        fusedconv.bias.copy_(b_conv.mul(scale) + b_bn)

    return fusedconv
| |
|
| |
|
def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes from ``b`` onto ``a``.

    Every entry of ``b.__dict__`` is assigned to ``a`` with ``setattr``
    unless it is filtered out by one of these rules:

    * ``include`` is non-empty and the attribute name is not in it;
    * the attribute name starts with an underscore;
    * the attribute name is in ``exclude``.
    """

    def _should_copy(name):
        # A non-empty include list acts as a whitelist and is checked first.
        if include and name not in include:
            return False
        return not (name.startswith("_") or name in exclude)

    for name, value in b.__dict__.items():
        if _should_copy(name):
            setattr(a, name, value)
| |
|