glenn-jocher commited on
Commit
ce0c58f
1 Parent(s): af41083

update compute_loss()

Browse files
Files changed (1) hide show
  1. utils/utils.py +3 -3
utils/utils.py CHANGED
@@ -437,7 +437,8 @@ def compute_loss(p, targets, model): # predictions, targets, model
437
  BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
438
 
439
  # per output
440
- nt = 0 # targets
 
441
  balance = [1.0, 1.0, 1.0]
442
  for i, pi in enumerate(p): # layer index, layer predictions
443
  b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
@@ -470,7 +471,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
470
 
471
  lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
472
 
473
- s = 3 / (i + 1) # output count scaling
474
  lbox *= h['giou'] * s
475
  lobj *= h['obj'] * s
476
  lcls *= h['cls'] * s
@@ -517,7 +518,6 @@ def build_targets(p, targets, model):
517
  j, k = ((gxy % 1. < g) & (gxy > 1.)).T
518
  a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
519
  offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
520
-
521
  elif style == 'rect4':
522
  j, k = ((gxy % 1. < g) & (gxy > 1.)).T
523
  l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
 
437
  BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
438
 
439
  # per output
440
+ nt = 0 # number of targets
441
+ np = len(p) # number of outputs
442
  balance = [1.0, 1.0, 1.0]
443
  for i, pi in enumerate(p): # layer index, layer predictions
444
  b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
 
471
 
472
  lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
473
 
474
+ s = 3 / np # output count scaling
475
  lbox *= h['giou'] * s
476
  lobj *= h['obj'] * s
477
  lcls *= h['cls'] * s
 
518
  j, k = ((gxy % 1. < g) & (gxy > 1.)).T
519
  a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
520
  offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
 
521
  elif style == 'rect4':
522
  j, k = ((gxy % 1. < g) & (gxy > 1.)).T
523
  l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T