glenn-jocher commited on
Commit
4fce009
1 Parent(s): 2f77cf3

model.add_nms() method

Browse files
Files changed (2) hide show
  1. hubconf.py +1 -4
  2. models/yolo.py +10 -1
hubconf.py CHANGED
@@ -37,10 +37,7 @@ def create(name, pretrained, channels, classes):
37
  state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
38
  model.load_state_dict(state_dict, strict=False) # load
39
 
40
- m = NMS()
41
- m.f = -1 # from
42
- m.i = model.model[-1].i + 1 # index
43
- model.model.add_module(name='%s' % m.i, module=m) # add NMS
44
  model.eval()
45
  return model
46
 
 
37
  state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter
38
  model.load_state_dict(state_dict, strict=False) # load
39
 
40
+ model.add_nms() # add NMS module
 
 
 
41
  model.eval()
42
  return model
43
 
models/yolo.py CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
7
  import torch
8
  import torch.nn as nn
9
 
10
- from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat
11
  from models.experimental import MixConv2d, CrossConv, C3
12
  from utils.general import check_anchor_order, make_divisible, check_file, set_logging
13
  from utils.torch_utils import (
@@ -168,6 +168,15 @@ class Model(nn.Module):
168
  self.info()
169
  return self
170
 
 
 
 
 
 
 
 
 
 
171
  def info(self, verbose=False): # print model information
172
  model_info(self, verbose)
173
 
 
7
  import torch
8
  import torch.nn as nn
9
 
10
+ from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS
11
  from models.experimental import MixConv2d, CrossConv, C3
12
  from utils.general import check_anchor_order, make_divisible, check_file, set_logging
13
  from utils.torch_utils import (
 
168
  self.info()
169
  return self
170
 
171
+ def add_nms(self): # fuse model Conv2d() + BatchNorm2d() layers
172
+ if type(self.model[-1]) is not NMS: # if missing NMS
173
+ print('Adding NMS module... ')
174
+ m = NMS() # module
175
+ m.f = -1 # from
176
+ m.i = self.model[-1].i + 1 # index
177
+ self.model.add_module(name='%s' % m.i, module=m) # add
178
+ return self
179
+
180
  def info(self, verbose=False): # print model information
181
  model_info(self, verbose)
182