Yonghye Kwon committed on
Commit
e96c74b
1 Parent(s): f409d8e

Simpler code for DWConvClass (#4310)

Browse files

* simpler code for DWConvClass

simpler code for DWConvClass

* remove DWConv function

* Replace DWConvClass with DWConv

Files changed (2) hide show
  1. models/common.py +2 -8
  2. models/yolo.py +1 -1
models/common.py CHANGED
@@ -29,11 +29,6 @@ def autopad(k, p=None): # kernel, padding
29
  return p
30
 
31
 
32
- def DWConv(c1, c2, k=1, s=1, act=True):
33
- # Depth-wise convolution function
34
- return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
35
-
36
-
37
  class Conv(nn.Module):
38
  # Standard convolution
39
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
@@ -49,11 +44,10 @@ class Conv(nn.Module):
49
  return self.act(self.conv(x))
50
 
51
 
52
- class DWConvClass(Conv):
53
  # Depth-wise convolution class
54
  def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
55
- super().__init__(c1, c2, k, s, act)
56
- self.conv = nn.Conv2d(c1, c2, k, s, autopad(k), groups=math.gcd(c1, c2), bias=False)
57
 
58
 
59
  class TransformerLayer(nn.Module):
 
29
  return p
30
 
31
 
 
 
 
 
 
32
  class Conv(nn.Module):
33
  # Standard convolution
34
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
 
44
  return self.act(self.conv(x))
45
 
46
 
47
+ class DWConv(Conv):
48
  # Depth-wise convolution class
49
  def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
50
+ super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
 
51
 
52
 
53
  class TransformerLayer(nn.Module):
models/yolo.py CHANGED
@@ -202,7 +202,7 @@ class Model(nn.Module):
202
  def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
203
  LOGGER.info('Fusing layers... ')
204
  for m in self.model.modules():
205
- if isinstance(m, (Conv, DWConvClass)) and hasattr(m, 'bn'):
206
  m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
207
  delattr(m, 'bn') # remove batchnorm
208
  m.forward = m.forward_fuse # update forward
 
202
  def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
203
  LOGGER.info('Fusing layers... ')
204
  for m in self.model.modules():
205
+ if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
206
  m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
207
  delattr(m, 'bn') # remove batchnorm
208
  m.forward = m.forward_fuse # update forward