glenn-jocher commited on
Commit
01cdb76
1 Parent(s): 24bea5e

Add `SPPF()` layer (#4420)

Browse files

* Add `SPPF()` layer

* Cleanup

* Add credit

Files changed (2) hide show
  1. models/common.py +19 -1
  2. models/yolo.py +6 -4
models/common.py CHANGED
@@ -161,7 +161,7 @@ class C3Ghost(C3):
161
 
162
 
163
  class SPP(nn.Module):
164
- # Spatial pyramid pooling layer used in YOLOv3-SPP
165
  def __init__(self, c1, c2, k=(5, 9, 13)):
166
  super().__init__()
167
  c_ = c1 // 2 # hidden channels
@@ -176,6 +176,24 @@ class SPP(nn.Module):
176
  return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
177
 
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  class Focus(nn.Module):
180
  # Focus wh information into c-space
181
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
 
161
 
162
 
163
  class SPP(nn.Module):
164
+ # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
165
  def __init__(self, c1, c2, k=(5, 9, 13)):
166
  super().__init__()
167
  c_ = c1 // 2 # hidden channels
 
176
  return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
177
 
178
 
179
+ class SPPF(nn.Module):
180
+ # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
181
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
182
+ super().__init__()
183
+ c_ = c1 // 2 # hidden channels
184
+ self.cv1 = Conv(c1, c_, 1, 1)
185
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
186
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
187
+
188
+ def forward(self, x):
189
+ x = self.cv1(x)
190
+ with warnings.catch_warnings():
191
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
192
+ y1 = self.m(x)
193
+ y2 = self.m(y1)
194
+ return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
195
+
196
+
197
  class Focus(nn.Module):
198
  # Focus wh information into c-space
199
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
models/yolo.py CHANGED
@@ -237,8 +237,8 @@ def parse_model(d, ch): # model_dict, input_channels(3)
237
  pass
238
 
239
  n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
240
- if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
241
- C3, C3TR, C3SPP, C3Ghost]:
242
  c1, c2 = ch[f], args[0]
243
  if c2 != no: # if not output
244
  c2 = make_divisible(c2 * gw, 8)
@@ -279,6 +279,7 @@ if __name__ == '__main__':
279
  parser = argparse.ArgumentParser()
280
  parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
281
  parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
 
282
  opt = parser.parse_args()
283
  opt.cfg = check_file(opt.cfg) # check file
284
  set_logging()
@@ -289,8 +290,9 @@ if __name__ == '__main__':
289
  model.train()
290
 
291
  # Profile
292
- # img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 320, 320).to(device)
293
- # y = model(img, profile=True)
 
294
 
295
  # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
296
  # from torch.utils.tensorboard import SummaryWriter
 
237
  pass
238
 
239
  n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
240
+ if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
241
+ BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]:
242
  c1, c2 = ch[f], args[0]
243
  if c2 != no: # if not output
244
  c2 = make_divisible(c2 * gw, 8)
 
279
  parser = argparse.ArgumentParser()
280
  parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
281
  parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
282
+ parser.add_argument('--profile', action='store_true', help='profile model speed')
283
  opt = parser.parse_args()
284
  opt.cfg = check_file(opt.cfg) # check file
285
  set_logging()
 
290
  model.train()
291
 
292
  # Profile
293
+ if opt.profile:
294
+ img = torch.rand(8 if torch.cuda.is_available() else 1, 3, 640, 640).to(device)
295
+ y = model(img, profile=True)
296
 
297
  # Tensorboard (not working https://github.com/ultralytics/yolov5/issues/2898)
298
  # from torch.utils.tensorboard import SummaryWriter