glenn-jocher commited on
Commit
8fe299f
1 Parent(s): c672bef

model fuse

Browse files
Files changed (1) hide show
  1. utils/torch_utils.py +1 -1
utils/torch_utils.py CHANGED
@@ -90,7 +90,7 @@ def fuse_conv_and_bn(conv, bn):
90
  if conv.bias is not None:
91
  b_conv = conv.bias
92
  else:
93
- b_conv = torch.zeros(conv.weight.size(0))
94
  b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
95
  fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
96
 
 
90
  if conv.bias is not None:
91
  b_conv = conv.bias
92
  else:
93
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device)
94
  b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
95
  fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
96