glenn-jocher committed
Commit: 04bdbe4
Parent: 5ba1de0

fuse update

Files changed (3):
  1. detect.py       +3 -6
  2. models/yolo.py   +2 -2
  3. test.py          +1 -1
detect.py CHANGED
@@ -21,13 +21,10 @@ def detect(save_img=False):
 
     # Load model
     google_utils.attempt_download(weights)
-    model = torch.load(weights, map_location=device)['model'].float()  # load to FP32
-    # torch.save(torch.load(weights, map_location=device), weights)  # update model if SourceChangeWarning
-    # model.fuse()
-    model.to(device).eval()
-    imgsz = check_img_size(imgsz, s=model.model[-1].stride.max())  # check img_size
+    model = torch.load(weights, map_location=device)['model'].float().eval()  # load FP32 model
+    imgsz = check_img_size(imgsz, s=model.stride.max())  # check img_size
     if half:
-        model.half()  # to FP16
+        model.float()  # to FP16
 
     # Second-stage classifier
     classify = False
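
The new load path drops the SourceChangeWarning workaround and the commented-out fuse() call: the checkpointed module is loaded straight into eval mode, and the stride is read from the top-level model rather than from model.model[-1]. Below is a minimal sketch of that flow; 'yolov5s.pt' is a hypothetical local checkpoint path, the rounding helper merely stands in for the repo's check_img_size, and the FP16 conversion is written here as the conventional model.half().

    import math
    import torch

    weights = 'yolov5s.pt'    # hypothetical checkpoint (a dict whose 'model' entry is the nn.Module itself)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    half = device.type != 'cpu'    # FP16 only pays off on CUDA

    model = torch.load(weights, map_location=device)['model'].float().eval()  # FP32 model in inference mode

    stride = int(model.stride.max())    # Detect() stride is now exposed on the top-level model

    def check_img_size(img_size, s=32):
        # Stand-in for the repo's check_img_size: round the requested size up to a stride multiple
        return int(math.ceil(img_size / s) * s)

    imgsz = check_img_size(640, s=stride)

    if half:
        model.half()    # convert weights to FP16 for inference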
models/yolo.py CHANGED
@@ -142,14 +142,14 @@ class Model(nn.Module):
             # print('%10.3g' % (m.w.detach().sigmoid() * 2))  # shortcut weights
 
     def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
-        print('Fusing layers...')
+        print('Fusing layers... ', end='')
         for m in self.model.modules():
             if type(m) is Conv:
                 m.conv = torch_utils.fuse_conv_and_bn(m.conv, m.bn)  # update conv
                 m.bn = None  # remove batchnorm
                 m.forward = m.fuseforward  # update forward
         torch_utils.model_info(self)
-
+        return self
 
 def parse_model(md, ch):  # model_dict, input_channels(3)
     print('\n%3s%18s%3s%10s %-40s%-30s' % ('', 'from', 'n', 'params', 'module', 'arguments'))
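
fuse() now returns self, so it can be chained (for example model.fuse().eval()), and the end='' keeps the "Fusing layers..." message on the same line as the summary that model_info() prints afterwards. For reference, the sketch below shows the standard Conv2d + BatchNorm2d folding that a helper like torch_utils.fuse_conv_and_bn performs; the repository's exact implementation may differ in details.

    import torch
    import torch.nn as nn

    def fuse_conv_and_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
        # Fold BN into the preceding conv:  y = gamma * (W x + b - mean) / sqrt(var + eps) + beta
        fused = nn.Conv2d(conv.in_channels, conv.out_channels,
                          kernel_size=conv.kernel_size, stride=conv.stride,
                          padding=conv.padding, dilation=conv.dilation,
                          groups=conv.groups, bias=True)

        scale = bn.weight.data / torch.sqrt(bn.running_var + bn.eps)       # per-channel gamma / sigma

        fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)  # W' = diag(scale) @ W
        conv_b = conv.bias.data if conv.bias is not None else torch.zeros_like(bn.running_mean)
        fused.bias.data = scale * (conv_b - bn.running_mean) + bn.bias.data  # b' = scale*(b - mean) + beta
        return fused

    # Numerical check: in eval mode the fused conv matches conv -> bn
    conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8).eval()
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-0.5, 0.5)
    x = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        print(torch.allclose(bn(conv(x)), fuse_conv_and_bn(conv, bn)(x), atol=1e-5))  # True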
test.py CHANGED
@@ -22,6 +22,7 @@ def test(data,
     # Initialize/load model and set device
     if model is None:
         training = False
+        merge = opt.merge  # use Merge NMS
         device = torch_utils.select_device(opt.device, batch_size=batch_size)
 
         # Remove previous
@@ -59,7 +60,6 @@ def test(data,
 
     # Dataloader
     if dataloader is None:  # not training
-        merge = opt.merge  # use Merge NMS
         img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
         _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
         path = data['test'] if opt.task == 'test' else data['val']  # path to val/test images
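
The test.py change only hoists merge = opt.merge out of the dataloader branch into the earlier model-initialization branch, so the flag is assigned alongside the rest of the standalone setup. As background, "Merge NMS" replaces each box kept by NMS with the score-weighted average of the boxes it suppresses; the snippet below is an illustrative, self-contained re-implementation of that idea, not the repository's non_max_suppression().

    import torch

    def box_iou(box1, box2):
        # Pairwise IoU between (N, 4) and (M, 4) boxes in (x1, y1, x2, y2) format -> (N, M)
        area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
        area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
        lt = torch.max(box1[:, None, :2], box2[:, :2])   # intersection top-left
        rb = torch.min(box1[:, None, 2:], box2[:, 2:])   # intersection bottom-right
        wh = (rb - lt).clamp(min=0)
        inter = wh[..., 0] * wh[..., 1]
        return inter / (area1[:, None] + area2 - inter)

    def merge_nms(boxes, scores, iou_thres=0.5):
        # Greedy NMS where each kept detection's coordinates become the
        # score-weighted mean of every box it suppresses (including itself).
        order = scores.argsort(descending=True)
        boxes, scores = boxes[order], scores[order]
        active = torch.ones(len(boxes), dtype=torch.bool)
        out_boxes, out_scores = [], []
        for i in range(len(boxes)):
            if not active[i]:
                continue
            cluster = active & (box_iou(boxes[i:i + 1], boxes)[0] > iou_thres)  # overlapping, still-active boxes
            w = scores[cluster].unsqueeze(1)                                    # confidence weights
            out_boxes.append((w * boxes[cluster]).sum(0) / w.sum())             # merged coordinates
            out_scores.append(scores[i])
            active &= ~cluster                                                  # suppress the whole cluster
        return torch.stack(out_boxes), torch.stack(out_scores)

    # Two overlapping detections of the same object collapse into one merged box
    boxes = torch.tensor([[10., 10., 50., 50.], [12., 12., 52., 52.], [100., 100., 140., 140.]])
    scores = torch.tensor([0.9, 0.6, 0.8])
    print(merge_nms(boxes, scores))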