glenn-jocher commited on
Commit
587c4b4
1 Parent(s): 9c6732f

Add `DWConvClass()` (#4274)

Browse files

* Add `DWConvClass()`

* Cleanup

* Cleanup2

Files changed (3) hide show
  1. models/common.py +9 -2
  2. models/experimental.py +1 -1
  3. models/yolo.py +2 -2
models/common.py CHANGED
@@ -30,7 +30,7 @@ def autopad(k, p=None): # kernel, padding
30
 
31
 
32
  def DWConv(c1, c2, k=1, s=1, act=True):
33
- # Depth-wise convolution
34
  return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
35
 
36
 
@@ -45,10 +45,17 @@ class Conv(nn.Module):
45
  def forward(self, x):
46
  return self.act(self.bn(self.conv(x)))
47
 
48
- def fuseforward(self, x):
49
  return self.act(self.conv(x))
50
 
51
 
 
 
 
 
 
 
 
52
  class TransformerLayer(nn.Module):
53
  # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
54
  def __init__(self, c, num_heads):
 
30
 
31
 
32
  def DWConv(c1, c2, k=1, s=1, act=True):
33
+ # Depth-wise convolution function
34
  return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
35
 
36
 
 
45
  def forward(self, x):
46
  return self.act(self.bn(self.conv(x)))
47
 
48
+ def forward_fuse(self, x):
49
  return self.act(self.conv(x))
50
 
51
 
52
+ class DWConvClass(Conv):
53
+ # Depth-wise convolution class
54
+ def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
55
+ super().__init__(c1, c2, k, s, act)
56
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k), groups=math.gcd(c1, c2), bias=False)
57
+
58
+
59
  class TransformerLayer(nn.Module):
60
  # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
61
  def __init__(self, c, num_heads):
models/experimental.py CHANGED
@@ -72,7 +72,7 @@ class GhostBottleneck(nn.Module):
72
 
73
 
74
  class MixConv2d(nn.Module):
75
- # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
76
  def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
77
  super().__init__()
78
  groups = len(k)
 
72
 
73
 
74
  class MixConv2d(nn.Module):
75
+ # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
76
  def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
77
  super().__init__()
78
  groups = len(k)
models/yolo.py CHANGED
@@ -202,10 +202,10 @@ class Model(nn.Module):
202
  def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
203
  LOGGER.info('Fusing layers... ')
204
  for m in self.model.modules():
205
- if type(m) is Conv and hasattr(m, 'bn'):
206
  m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
207
  delattr(m, 'bn') # remove batchnorm
208
- m.forward = m.fuseforward # update forward
209
  self.info()
210
  return self
211
 
 
202
  def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
203
  LOGGER.info('Fusing layers... ')
204
  for m in self.model.modules():
205
+ if isinstance(m, (Conv, DWConvClass)) and hasattr(m, 'bn'):
206
  m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
207
  delattr(m, 'bn') # remove batchnorm
208
+ m.forward = m.forward_fuse # update forward
209
  self.info()
210
  return self
211