glenn-jocher
commited on
Commit
•
ce0c58f
1
Parent(s):
af41083
update compute_loss()
Browse files- utils/utils.py +3 -3
utils/utils.py
CHANGED
@@ -437,7 +437,8 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|
437 |
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
438 |
|
439 |
# per output
|
440 |
-
nt = 0 # targets
|
|
|
441 |
balance = [1.0, 1.0, 1.0]
|
442 |
for i, pi in enumerate(p): # layer index, layer predictions
|
443 |
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
@@ -470,7 +471,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|
470 |
|
471 |
lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
|
472 |
|
473 |
-
s = 3 /
|
474 |
lbox *= h['giou'] * s
|
475 |
lobj *= h['obj'] * s
|
476 |
lcls *= h['cls'] * s
|
@@ -517,7 +518,6 @@ def build_targets(p, targets, model):
|
|
517 |
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
518 |
a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
|
519 |
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
|
520 |
-
|
521 |
elif style == 'rect4':
|
522 |
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
523 |
l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
|
|
|
437 |
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
438 |
|
439 |
# per output
|
440 |
+
nt = 0 # number of targets
|
441 |
+
np = len(p) # number of outputs
|
442 |
balance = [1.0, 1.0, 1.0]
|
443 |
for i, pi in enumerate(p): # layer index, layer predictions
|
444 |
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
|
|
471 |
|
472 |
lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
|
473 |
|
474 |
+
s = 3 / np # output count scaling
|
475 |
lbox *= h['giou'] * s
|
476 |
lobj *= h['obj'] * s
|
477 |
lcls *= h['cls'] * s
|
|
|
518 |
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
519 |
a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
|
520 |
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
|
|
|
521 |
elif style == 'rect4':
|
522 |
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
523 |
l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
|