glenn-jocher committed
Commit 121d90b · Parent(s): 6b95d6d

update fuse_conv_and_bn()
utils/torch_utils.py CHANGED (+10 -13)
@@ -90,7 +90,7 @@ def prune(model, amount=0.3):
     import torch.nn.utils.prune as prune
     print('Pruning model... ', end='')
     for name, m in model.named_modules():
-        if isinstance(m, torch.nn.Conv2d):
+        if isinstance(m, nn.Conv2d):
             prune.l1_unstructured(m, name='weight', amount=amount)  # prune
             prune.remove(m, 'weight')  # make permanent
     print(' %.3g global sparsity' % sparsity(model))
@@ -100,12 +100,12 @@ def fuse_conv_and_bn(conv, bn):
     # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
     with torch.no_grad():
         # init
-        fusedconv = torch.nn.Conv2d(conv.in_channels,
-                                    conv.out_channels,
-                                    kernel_size=conv.kernel_size,
-                                    stride=conv.stride,
-                                    padding=conv.padding,
-                                    bias=True)
+        fusedconv = nn.Conv2d(conv.in_channels,
+                              conv.out_channels,
+                              kernel_size=conv.kernel_size,
+                              stride=conv.stride,
+                              padding=conv.padding,
+                              bias=True).to(conv.weight.device)
 
         # prepare filters
         w_conv = conv.weight.clone().view(conv.out_channels, -1)
@@ -113,10 +113,7 @@ def fuse_conv_and_bn(conv, bn):
         fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
 
         # prepare spatial bias
-        if conv.bias is not None:
-            b_conv = conv.bias
-        else:
-            b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
+        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
         b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
         fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
@@ -159,8 +156,8 @@ def load_classifier(name='resnet101', n=2):
 
     # Reshape output to n classes
     filters = model.fc.weight.shape[1]
-    model.fc.bias = torch.nn.Parameter(torch.zeros(n))
-    model.fc.weight = torch.nn.Parameter(torch.zeros(n, filters))
+    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
+    model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
     model.fc.out_features = n
     return model
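For context on the prune() hunk: torch.nn.utils.prune.l1_unstructured zeroes the smallest-magnitude weights of each matched module, and prune.remove() folds the pruning mask into the weight tensor so the zeros become permanent. A minimal sketch of the same pattern; the sparsity() helper here is a stand-in written for this example, not the one defined in utils/torch_utils.py:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def sparsity(model):
    # Fraction of zero-valued weights across all parameters
    # (stand-in for the sparsity() helper in utils/torch_utils.py)
    total = sum(p.numel() for p in model.parameters())
    zeros = sum((p == 0).sum().item() for p in model.parameters())
    return zeros / total

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
for name, m in model.named_modules():
    if isinstance(m, nn.Conv2d):
        prune.l1_unstructured(m, name='weight', amount=0.3)  # zero the smallest 30% by |w|
        prune.remove(m, 'weight')  # make permanent
print('%.3g global sparsity' % sparsity(model))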
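The two fuse_conv_and_bn() hunks do three things: build the fused layer with nn.Conv2d instead of torch.nn.Conv2d, move it onto the source weights' device via .to(conv.weight.device), and collapse the conv-bias if/else into one conditional expression. The fusion itself folds frozen BatchNorm statistics into the convolution: with s = gamma / sqrt(running_var + eps), the fused weight is diag(s) @ W and the fused bias is s * (b_conv - running_mean) + beta. A quick sanity check of that identity, sketched under the assumption that fuse_conv_and_bn is importable from utils.torch_utils and returns the fused nn.Conv2d:

import torch
import torch.nn as nn
from utils.torch_utils import fuse_conv_and_bn  # import path assumed

conv = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1, bias=False)
bn = nn.BatchNorm2d(16)
# give BN non-trivial statistics so the check is meaningful
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)
conv.eval()
bn.eval()  # fusion is only valid with frozen (eval-mode) BN statistics

fused = fuse_conv_and_bn(conv, bn)

x = torch.randn(1, 8, 32, 32)
with torch.no_grad():
    err = (bn(conv(x)) - fused(x)).abs().max().item()
print('max abs error: %.3g' % err)  # expect ~1e-6, float32 round-off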
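The load_classifier() hunk rebuilds the fc head as zero-initialized nn.Parameter tensors with requires_grad=True spelled out, so the reshaped head is explicitly trainable. The same head-reshape pattern in isolation, assuming torchvision is available (resnet18 keeps the example light; the function itself defaults to resnet101):

import torch
import torch.nn as nn
import torchvision

n = 2  # desired number of output classes
model = torchvision.models.resnet18(pretrained=True)

# Reshape output to n classes, mirroring the hunk above
filters = model.fc.weight.shape[1]  # input features of the fc head (512 for resnet18)
model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
model.fc.out_features = n

print(model.fc.weight.shape)  # torch.Size([2, 512])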