glenn-jocher committed on
Commit
406ee52
1 Parent(s): aa542ce

Loss and IoU speed improvements (#7361)

Browse files

* Loss speed improvements

* bbox_iou speed improvements

* bbox_ioa speed improvements

* box_iou speed improvements

* box_iou speed improvements

Files changed (3) hide show
  1. utils/loss.py +4 -4
  2. utils/metrics.py +26 -28
  3. val.py +2 -2
utils/loss.py CHANGED
@@ -138,7 +138,7 @@ class ComputeLoss:
138
  pxy = pxy.sigmoid() * 2 - 0.5
139
  pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
140
  pbox = torch.cat((pxy, pwh), 1) # predicted box
141
- iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
142
  lbox += (1.0 - iou).mean() # iou loss
143
 
144
  # Objectness
@@ -180,7 +180,7 @@ class ComputeLoss:
180
  tcls, tbox, indices, anch = [], [], [], []
181
  gain = torch.ones(7, device=self.device) # normalized to gridspace gain
182
  ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
183
- targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
184
 
185
  g = 0.5 # bias
186
  off = torch.tensor(
@@ -199,10 +199,10 @@ class ComputeLoss:
199
  gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
200
 
201
  # Match targets to anchors
202
- t = targets * gain
203
  if nt:
204
  # Matches
205
- r = t[:, :, 4:6] / anchors[:, None] # wh ratio
206
  j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t'] # compare
207
  # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
208
  t = t[j] # filter
 
138
  pxy = pxy.sigmoid() * 2 - 0.5
139
  pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
140
  pbox = torch.cat((pxy, pwh), 1) # predicted box
141
+ iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target)
142
  lbox += (1.0 - iou).mean() # iou loss
143
 
144
  # Objectness
 
180
  tcls, tbox, indices, anch = [], [], [], []
181
  gain = torch.ones(7, device=self.device) # normalized to gridspace gain
182
  ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
183
+ targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2) # append anchor indices
184
 
185
  g = 0.5 # bias
186
  off = torch.tensor(
 
199
  gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
200
 
201
  # Match targets to anchors
202
+ t = targets * gain # shape(3,n,7)
203
  if nt:
204
  # Matches
205
+ r = t[..., 4:6] / anchors[:, None] # wh ratio
206
  j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t'] # compare
207
  # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
208
  t = t[j] # filter
utils/metrics.py CHANGED
@@ -206,37 +206,36 @@ class ConfusionMatrix:
206
  print(' '.join(map(str, self.matrix[i])))
207
 
208
 
209
- def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
210
- # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
211
- box2 = box2.T
212
 
213
  # Get the coordinates of bounding boxes
214
- if x1y1x2y2: # x1, y1, x2, y2 = box1
215
- b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
216
- b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
217
- else: # transform from xywh to xyxy
218
- b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
219
- b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
220
- b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
221
- b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
 
 
222
 
223
  # Intersection area
224
  inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
225
  (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
226
 
227
  # Union Area
228
- w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
229
- w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
230
  union = w1 * h1 + w2 * h2 - inter + eps
231
 
 
232
  iou = inter / union
233
  if CIoU or DIoU or GIoU:
234
  cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
235
  ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
236
  if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
237
  c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
238
- rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
239
- (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
240
  if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
241
  v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
242
  with torch.no_grad():
@@ -248,6 +247,11 @@ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=
248
  return iou # IoU
249
 
250
 
 
 
 
 
 
251
  def box_iou(box1, box2):
252
  # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
253
  """
@@ -261,16 +265,12 @@ def box_iou(box1, box2):
261
  IoU values for every element in boxes1 and boxes2
262
  """
263
 
264
- def box_area(box):
265
- # box = 4xn
266
- return (box[2] - box[0]) * (box[3] - box[1])
267
-
268
- area1 = box_area(box1.T)
269
- area2 = box_area(box2.T)
270
-
271
  # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
272
- inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
273
- return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
 
 
 
274
 
275
 
276
  def bbox_ioa(box1, box2, eps=1E-7):
@@ -280,11 +280,9 @@ def bbox_ioa(box1, box2, eps=1E-7):
280
  returns: np.array of shape(n)
281
  """
282
 
283
- box2 = box2.transpose()
284
-
285
  # Get the coordinates of bounding boxes
286
- b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
287
- b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
288
 
289
  # Intersection area
290
  inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
 
206
  print(' '.join(map(str, self.matrix[i])))
207
 
208
 
209
+ def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
210
+ # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
 
211
 
212
  # Get the coordinates of bounding boxes
213
+ if xywh: # transform from xywh to xyxy
214
+ (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, 1), box2.chunk(4, 1)
215
+ w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
216
+ b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
217
+ b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
218
+ else: # x1, y1, x2, y2 = box1
219
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, 1)
220
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, 1)
221
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
222
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
223
 
224
  # Intersection area
225
  inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
226
  (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
227
 
228
  # Union Area
 
 
229
  union = w1 * h1 + w2 * h2 - inter + eps
230
 
231
+ # IoU
232
  iou = inter / union
233
  if CIoU or DIoU or GIoU:
234
  cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
235
  ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
236
  if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
237
  c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
238
+ rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
 
239
  if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
240
  v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
241
  with torch.no_grad():
 
247
  return iou # IoU
248
 
249
 
250
+ def box_area(box):
251
+ # box = xyxy(4,n)
252
+ return (box[2] - box[0]) * (box[3] - box[1])
253
+
254
+
255
  def box_iou(box1, box2):
256
  # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
257
  """
 
265
  IoU values for every element in boxes1 and boxes2
266
  """
267
 
 
 
 
 
 
 
 
268
  # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
269
+ (a1, a2), (b1, b2) = box1[:, None].chunk(2, 2), box2.chunk(2, 1)
270
+ inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
271
+
272
+ # IoU = inter / (area1 + area2 - inter)
273
+ return inter / (box_area(box1.T)[:, None] + box_area(box2.T) - inter)
274
 
275
 
276
  def bbox_ioa(box1, box2, eps=1E-7):
 
280
  returns: np.array of shape(n)
281
  """
282
 
 
 
283
  # Get the coordinates of bounding boxes
284
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1
285
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
286
 
287
  # Intersection area
288
  inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
val.py CHANGED
@@ -38,10 +38,10 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
38
  from models.common import DetectMultiBackend
39
  from utils.callbacks import Callbacks
40
  from utils.datasets import create_dataloader
41
- from utils.general import (LOGGER, box_iou, check_dataset, check_img_size, check_requirements, check_yaml,
42
  coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
43
  scale_coords, xywh2xyxy, xyxy2xywh)
44
- from utils.metrics import ConfusionMatrix, ap_per_class
45
  from utils.plots import output_to_target, plot_images, plot_val_study
46
  from utils.torch_utils import select_device, time_sync
47
 
 
38
  from models.common import DetectMultiBackend
39
  from utils.callbacks import Callbacks
40
  from utils.datasets import create_dataloader
41
+ from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_yaml,
42
  coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
43
  scale_coords, xywh2xyxy, xyxy2xywh)
44
+ from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
45
  from utils.plots import output_to_target, plot_images, plot_val_study
46
  from utils.torch_utils import select_device, time_sync
47