glenn-jocher commited on
Commit
6bfa9c2
1 Parent(s): 2acbe96

GhostConv update (#2082)

Browse files
Files changed (2) hide show
  1. models/experimental.py +1 -1
  2. models/yolo.py +3 -2
models/experimental.py CHANGED
@@ -58,7 +58,7 @@ class GhostConv(nn.Module):
58
 
59
  class GhostBottleneck(nn.Module):
60
  # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
61
- def __init__(self, c1, c2, k, s):
62
  super(GhostBottleneck, self).__init__()
63
  c_ = c2 // 2
64
  self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
 
58
 
59
  class GhostBottleneck(nn.Module):
60
  # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
61
+ def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
62
  super(GhostBottleneck, self).__init__()
63
  c_ = c2 // 2
64
  self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
models/yolo.py CHANGED
@@ -8,7 +8,7 @@ sys.path.append('./') # to run '$ python *.py' files in subdirectories
8
  logger = logging.getLogger(__name__)
9
 
10
  from models.common import *
11
- from models.experimental import MixConv2d, CrossConv
12
  from utils.autoanchor import check_anchor_order
13
  from utils.general import make_divisible, check_file, set_logging
14
  from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
@@ -210,7 +210,8 @@ def parse_model(d, ch): # model_dict, input_channels(3)
210
  pass
211
 
212
  n = max(round(n * gd), 1) if n > 1 else n # depth gain
213
- if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
 
214
  c1, c2 = ch[f], args[0]
215
 
216
  # Normal
 
8
  logger = logging.getLogger(__name__)
9
 
10
  from models.common import *
11
+ from models.experimental import *
12
  from utils.autoanchor import check_anchor_order
13
  from utils.general import make_divisible, check_file, set_logging
14
  from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
 
210
  pass
211
 
212
  n = max(round(n * gd), 1) if n > 1 else n # depth gain
213
+ if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
214
+ C3]:
215
  c1, c2 = ch[f], args[0]
216
 
217
  # Normal