glenn-jocher committed on
Commit
121d90b
1 Parent(s): 6b95d6d

update fuse_conv_and_bn()

Browse files
Files changed (1) hide show
  1. utils/torch_utils.py +10 -13
utils/torch_utils.py CHANGED
@@ -90,7 +90,7 @@ def prune(model, amount=0.3):
90
  import torch.nn.utils.prune as prune
91
  print('Pruning model... ', end='')
92
  for name, m in model.named_modules():
93
- if isinstance(m, torch.nn.Conv2d):
94
  prune.l1_unstructured(m, name='weight', amount=amount) # prune
95
  prune.remove(m, 'weight') # make permanent
96
  print(' %.3g global sparsity' % sparsity(model))
@@ -100,12 +100,12 @@ def fuse_conv_and_bn(conv, bn):
100
  # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
101
  with torch.no_grad():
102
  # init
103
- fusedconv = torch.nn.Conv2d(conv.in_channels,
104
- conv.out_channels,
105
- kernel_size=conv.kernel_size,
106
- stride=conv.stride,
107
- padding=conv.padding,
108
- bias=True)
109
 
110
  # prepare filters
111
  w_conv = conv.weight.clone().view(conv.out_channels, -1)
@@ -113,10 +113,7 @@ def fuse_conv_and_bn(conv, bn):
113
  fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
114
 
115
  # prepare spatial bias
116
- if conv.bias is not None:
117
- b_conv = conv.bias
118
- else:
119
- b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
120
  b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
121
  fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
122
 
@@ -159,8 +156,8 @@ def load_classifier(name='resnet101', n=2):
159
 
160
  # Reshape output to n classes
161
  filters = model.fc.weight.shape[1]
162
- model.fc.bias = torch.nn.Parameter(torch.zeros(n), requires_grad=True)
163
- model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters), requires_grad=True)
164
  model.fc.out_features = n
165
  return model
166
 
 
90
  import torch.nn.utils.prune as prune
91
  print('Pruning model... ', end='')
92
  for name, m in model.named_modules():
93
+ if isinstance(m, nn.Conv2d):
94
  prune.l1_unstructured(m, name='weight', amount=amount) # prune
95
  prune.remove(m, 'weight') # make permanent
96
  print(' %.3g global sparsity' % sparsity(model))
 
100
  # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
101
  with torch.no_grad():
102
  # init
103
+ fusedconv = nn.Conv2d(conv.in_channels,
104
+ conv.out_channels,
105
+ kernel_size=conv.kernel_size,
106
+ stride=conv.stride,
107
+ padding=conv.padding,
108
+ bias=True).to(conv.weight.device)
109
 
110
  # prepare filters
111
  w_conv = conv.weight.clone().view(conv.out_channels, -1)
 
113
  fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
114
 
115
  # prepare spatial bias
116
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
 
 
 
117
  b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
118
  fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
119
 
 
156
 
157
  # Reshape output to n classes
158
  filters = model.fc.weight.shape[1]
159
+ model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
160
+ model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
161
  model.fc.out_features = n
162
  return model
163