glenn-jocher committed on
Commit
d133968
1 Parent(s): 1922dde

Clip TTA Augmented Tails (#5028)

Browse files

* Clip TTA Augmented Tails

Experimental TTA update.

* Update yolo.py

* Update yolo.py

* Update yolo.py

* Update yolo.py

Files changed (1) hide show
  1. models/yolo.py +12 -0
models/yolo.py CHANGED
@@ -134,6 +134,7 @@ class Model(nn.Module):
134
  # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
135
  yi = self._descale_pred(yi, fi, si, img_size)
136
  y.append(yi)
 
137
  return torch.cat(y, 1), None # augmented inference, train
138
 
139
  def _forward_once(self, x, profile=False, visualize=False):
@@ -166,6 +167,17 @@ class Model(nn.Module):
166
  p = torch.cat((x, y, wh, p[..., 4:]), -1)
167
  return p
168
 
 
 
 
 
 
 
 
 
 
 
 
169
  def _profile_one_layer(self, m, x, dt):
170
  c = isinstance(m, Detect) # is final layer, copy input as inplace fix
171
  o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
 
134
  # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
135
  yi = self._descale_pred(yi, fi, si, img_size)
136
  y.append(yi)
137
+ y = self._clip_augmented(y) # clip augmented tails
138
  return torch.cat(y, 1), None # augmented inference, train
139
 
140
  def _forward_once(self, x, profile=False, visualize=False):
 
167
  p = torch.cat((x, y, wh, p[..., 4:]), -1)
168
  return p
169
 
170
+ def _clip_augmented(self, y):
171
+ # Clip YOLOv5 augmented inference tails
172
+ nl = self.model[-1].nl # number of detection layers (P3-P5)
173
+ g = sum(4 ** x for x in range(nl)) # grid points
174
+ e = 1 # exclude layer count
175
+ i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
176
+ y[0] = y[0][:, :-i] # large
177
+ i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
178
+ y[-1] = y[-1][:, i:] # small
179
+ return y
180
+
181
  def _profile_one_layer(self, m, x, dt):
182
  c = isinstance(m, Detect) # is final layer, copy input as inplace fix
183
  o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs