glenn-jocher committed on
Commit
89655a8
1 Parent(s): c4cb785

.fuse() gradient introduction bug fix

Browse files
Files changed (1) hide show
  1. utils/torch_utils.py +22 -22
utils/torch_utils.py CHANGED
@@ -104,28 +104,28 @@ def prune(model, amount=0.3):
104
 
105
 
106
  def fuse_conv_and_bn(conv, bn):
107
- # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
108
- with torch.no_grad():
109
- # init
110
- fusedconv = nn.Conv2d(conv.in_channels,
111
- conv.out_channels,
112
- kernel_size=conv.kernel_size,
113
- stride=conv.stride,
114
- padding=conv.padding,
115
- groups=conv.groups,
116
- bias=True).to(conv.weight.device)
117
-
118
- # prepare filters
119
- w_conv = conv.weight.clone().view(conv.out_channels, -1)
120
- w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
121
- fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
122
-
123
- # prepare spatial bias
124
- b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
125
- b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
126
- fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
127
-
128
- return fusedconv
129
 
130
 
131
  def model_info(model, verbose=False):
 
104
 
105
 
106
  def fuse_conv_and_bn(conv, bn):
107
+ # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
108
+
109
+ # init
110
+ fusedconv = nn.Conv2d(conv.in_channels,
111
+ conv.out_channels,
112
+ kernel_size=conv.kernel_size,
113
+ stride=conv.stride,
114
+ padding=conv.padding,
115
+ groups=conv.groups,
116
+ bias=True).requires_grad_(False).to(conv.weight.device)
117
+
118
+ # prepare filters
119
+ w_conv = conv.weight.clone().view(conv.out_channels, -1)
120
+ w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
121
+ fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
122
+
123
+ # prepare spatial bias
124
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
125
+ b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
126
+ fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
127
+
128
+ return fusedconv
129
 
130
 
131
  def model_info(model, verbose=False):