glenn-jocher committed
Commit 8aa2085
1 Parent(s): f000714

Refactor modules (#7823)

Files changed (3)
  1. models/experimental.py +10 -14
  2. models/tf.py +6 -8
  3. models/yolo.py +1 -1
models/experimental.py CHANGED
@@ -78,9 +78,7 @@ class Ensemble(nn.ModuleList):
         super().__init__()
 
     def forward(self, x, augment=False, profile=False, visualize=False):
-        y = []
-        for module in self:
-            y.append(module(x, augment, profile, visualize)[0])
+        y = [module(x, augment, profile, visualize)[0] for module in self]
         # y = torch.stack(y).max(0)[0]  # max ensemble
         # y = torch.stack(y).mean(0)  # mean ensemble
         y = torch.cat(y, 1)  # nms ensemble
@@ -102,10 +100,9 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
         t = type(m)
         if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
             m.inplace = inplace  # torch 1.7.0 compatibility
-            if t is Detect:
-                if not isinstance(m.anchor_grid, list):  # new Detect Layer compatibility
-                    delattr(m, 'anchor_grid')
-                    setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
+            if t is Detect and not isinstance(m.anchor_grid, list):
+                delattr(m, 'anchor_grid')
+                setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
         elif t is Conv:
             m._non_persistent_buffers_set = set()  # torch 1.6.0 compatibility
         elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
@@ -113,10 +110,9 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
 
     if len(model) == 1:
         return model[-1]  # return model
-    else:
-        print(f'Ensemble created with {weights}\n')
-        for k in 'names', 'nc', 'yaml':
-            setattr(model, k, getattr(model[0], k))
-        model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
-        assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
-        return model  # return ensemble
+    print(f'Ensemble created with {weights}\n')
+    for k in 'names', 'nc', 'yaml':
+        setattr(model, k, getattr(model[0], k))
+    model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
+    assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
+    return model  # return ensemble
 
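Note on the Ensemble change: the three-line loop is collapsed into a list comprehension with identical behaviour. Below is a minimal, self-contained sketch (not code from this commit; DummyDetector is a hypothetical stand-in for a loaded YOLOv5 model) of how the refactored forward collects each member's inference tensor and concatenates the candidate boxes along dim 1, so one downstream NMS pass covers the whole ensemble.

import torch
import torch.nn as nn


class DummyDetector(nn.Module):
    # Hypothetical stand-in: returns (predictions, aux) like a YOLOv5 model in eval mode
    def __init__(self, n_boxes):
        super().__init__()
        self.n_boxes = n_boxes

    def forward(self, x, augment=False, profile=False, visualize=False):
        return torch.rand(x.shape[0], self.n_boxes, 85), None  # (batch, boxes, xywh + conf + 80 classes)


class Ensemble(nn.ModuleList):
    # Same forward logic as the refactored models/experimental.py above
    def forward(self, x, augment=False, profile=False, visualize=False):
        y = [module(x, augment, profile, visualize)[0] for module in self]
        y = torch.cat(y, 1)  # nms ensemble: concatenate candidates from all members
        return y, None


ensemble = Ensemble([DummyDetector(100), DummyDetector(80)])
out, _ = ensemble(torch.zeros(1, 3, 640, 640))
print(out.shape)  # torch.Size([1, 180, 85]) -- 100 + 80 candidates merged for a single NMS pass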
models/tf.py CHANGED
@@ -362,7 +362,7 @@ class TFModel:
                 conf_thres=0.25):
         y = []  # outputs
         x = inputs
-        for i, m in enumerate(self.model.layers):
+        for m in self.model.layers:
             if m.f != -1:  # if not from previous layer
                 x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
 
@@ -377,7 +377,6 @@ class TFModel:
             scores = probs * classes
             if agnostic_nms:
                 nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
-                return nms, x[1]
             else:
                 boxes = tf.expand_dims(boxes, 2)
                 nms = tf.image.combined_non_max_suppression(boxes,
@@ -387,8 +386,7 @@ class TFModel:
                                                             iou_thres,
                                                             conf_thres,
                                                             clip_boxes=False)
-                return nms, x[1]
-
+            return nms, x[1]
         return x[0]  # output only first tensor [1,6300,85] = [xywh, conf, class0, class1, ...]
         # x = x[0][0]  # [x(1,6300,85), ...] to x(6300,85)
         # xywh = x[..., :4]  # x(6300,4) boxes
@@ -444,10 +442,10 @@ class AgnosticNMS(keras.layers.Layer):
 def representative_dataset_gen(dataset, ncalib=100):
     # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
     for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
-        input = np.transpose(img, [1, 2, 0])
-        input = np.expand_dims(input, axis=0).astype(np.float32)
-        input /= 255
-        yield [input]
+        im = np.transpose(img, [1, 2, 0])
+        im = np.expand_dims(im, axis=0).astype(np.float32)
+        im /= 255
+        yield [im]
         if n >= ncalib:
             break
 
 
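Note on representative_dataset_gen: renaming input (which shadows the Python builtin) to im does not change behaviour. Below is a sketch of how a generator like this is typically wired into TFLite post-training quantization, under stated assumptions: a tiny stand-in Keras model and a fake generator of CHW uint8 arrays replace YOLOv5's exported model and dataloader, and the loop iterates raw images rather than the dataloader's (path, img, im0s, vid_cap, string) tuples.

import numpy as np
import tensorflow as tf
from tensorflow import keras


def representative_dataset_gen(dataset, ncalib=100):
    # Same transform as the renamed generator above: CHW uint8 -> 1xHxWxC float32 in [0, 1]
    for n, img in enumerate(dataset):
        im = np.transpose(img, [1, 2, 0])
        im = np.expand_dims(im, axis=0).astype(np.float32)
        im /= 255
        yield [im]
        if n >= ncalib:
            break


# Hypothetical stand-ins: a tiny Keras model and random CHW uint8 "images"
inputs = keras.Input(shape=(64, 64, 3))
outputs = keras.layers.Conv2D(8, 3)(inputs)
model = keras.Model(inputs, outputs)
fake_dataset = (np.random.randint(0, 255, (3, 64, 64), dtype=np.uint8) for _ in range(8))

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = lambda: representative_dataset_gen(fake_dataset, ncalib=8)
tflite_model = converter.convert()  # calibrates quantization ranges from the yielded arrays
print(f'{len(tflite_model)} bytes')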
models/yolo.py CHANGED
@@ -197,7 +197,7 @@ class Model(nn.Module):
             m(x.copy() if c else x)
         dt.append((time_sync() - t) * 100)
         if m == self.model[0]:
-            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} {'module'}")
+            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
         LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
         if c:
             LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")