glenn-jocher committed
Commit 5d4258f
1 Parent(s): 7f9bbf0

Fix MixConv2d() remove shortcut + apply depthwise (#5410)

Files changed (3)
  1. models/common.py +1 -1
  2. models/experimental.py +11 -10
  3. utils/torch_utils.py +1 -1
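
At a glance: MixConv2d() loses its residual shortcut, splits the c2 output channels across the kernel sizes in k, and runs each branch as a grouped (depthwise where channel counts allow) convolution via groups=math.gcd(c1, c_); the LeakyReLU(0.1) activations are swapped for SiLU, and initialize_weights() learns to set SiLU in-place. A minimal standalone sketch of the equal_ch channel split used by the patched module (the channel and kernel numbers here are illustrative, not taken from the diff):

import torch

c2, k = 64, (1, 3, 5)                            # illustrative output channels and kernel sizes
n = len(k)                                       # number of convolution branches
i = torch.linspace(0, n - 1E-6, c2).floor()      # group index (0..n-1) for each of the c2 channels
c_ = [int((i == g).sum()) for g in range(n)]     # channels per branch -> [22, 21, 21]
assert sum(c_) == c2
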
models/common.py CHANGED
@@ -113,7 +113,7 @@ class BottleneckCSP(nn.Module):
         self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
         self.cv4 = Conv(2 * c_, c2, 1, 1)
         self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
-        self.act = nn.LeakyReLU(0.1, inplace=True)
+        self.act = nn.SiLU()
         self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

     def forward(self, x):
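
The BottleneckCSP change above only swaps the activation. As a quick sanity check (not part of the diff), SiLU is x * sigmoid(x):

import torch
import torch.nn as nn

x = torch.randn(4, 8)
assert torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x))   # SiLU (a.k.a. swish) = x * sigmoid(x)
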
models/experimental.py CHANGED
@@ -2,7 +2,7 @@
 """
 Experimental modules
 """
-
+import math
 import numpy as np
 import torch
 import torch.nn as nn
@@ -48,26 +48,27 @@ class Sum(nn.Module):

 class MixConv2d(nn.Module):
     # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
-    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
+    def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):  # ch_in, ch_out, kernel, stride, ch_strategy
         super().__init__()
-        groups = len(k)
+        n = len(k)  # number of convolutions
         if equal_ch:  # equal c_ per group
-            i = torch.linspace(0, groups - 1E-6, c2).floor()  # c2 indices
-            c_ = [(i == g).sum() for g in range(groups)]  # intermediate channels
+            i = torch.linspace(0, n - 1E-6, c2).floor()  # c2 indices
+            c_ = [(i == g).sum() for g in range(n)]  # intermediate channels
         else:  # equal weight.numel() per group
-            b = [c2] + [0] * groups
-            a = np.eye(groups + 1, groups, k=-1)
+            b = [c2] + [0] * n
+            a = np.eye(n + 1, n, k=-1)
             a -= np.roll(a, 1, axis=1)
             a *= np.array(k) ** 2
             a[0] = 1
             c_ = np.linalg.lstsq(a, b, rcond=None)[0].round()  # solve for equal weight indices, ax = b

-        self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
+        self.m = nn.ModuleList(
+            [nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
         self.bn = nn.BatchNorm2d(c2)
-        self.act = nn.LeakyReLU(0.1, inplace=True)
+        self.act = nn.SiLU()

     def forward(self, x):
-        return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
+        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))


 class Ensemble(nn.ModuleList):
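
A small usage sketch of the patched module, assuming a YOLOv5 checkout with this change applied is on sys.path; it shows the grouped branches and that the forward pass no longer requires c1 == c2 now that the shortcut is gone:

import torch
from models.experimental import MixConv2d   # assumes the patched repo is importable

m = MixConv2d(c1=64, c2=32, k=(1, 3), s=1, equal_ch=True)
for conv in m.m:
    # each branch is grouped with groups = gcd(c1, c_); it becomes fully depthwise when c1 == c_
    print(conv.kernel_size, conv.out_channels, conv.groups)   # e.g. (1, 1) 16 16 and (3, 3) 16 16

y = m(torch.randn(2, 64, 32, 32))
print(y.shape)   # torch.Size([2, 32, 32, 32]) -- c1 != c2 is fine without the residual add
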
utils/torch_utils.py CHANGED
@@ -166,7 +166,7 @@ def initialize_weights(model):
         elif t is nn.BatchNorm2d:
             m.eps = 1e-3
             m.momentum = 0.03
-        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6]:
+        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
             m.inplace = True

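
With the extra nn.SiLU entry, initialize_weights() now also flips SiLU activations to in-place. A short sketch, again assuming the patched utils/torch_utils.py is importable:

import torch.nn as nn
from utils.torch_utils import initialize_weights   # assumes the patched repo is on sys.path

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.SiLU())
initialize_weights(model)
print(model[1].eps, model[1].momentum)   # 0.001 0.03
print(model[2].inplace)                  # True -- SiLU is now switched to in-place as well
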