glenn-jocher committed
Update loss criteria constructor (#1711)

Files changed:
- train.py +4 -4
- tutorial.ipynb +1 -1
- utils/loss.py +6 -6
- utils/torch_utils.py +7 -4
train.py
CHANGED
@@ -1,5 +1,6 @@
 import argparse
 import logging
+import math
 import os
 import random
 import time
@@ -7,7 +8,6 @@ from pathlib import Path
 from threading import Thread
 from warnings import warn
 
-import math
 import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
@@ -217,7 +217,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     model.nc = nc  # attach number of classes to model
     model.hyp = hyp  # attach hyperparameters to model
     model.gr = 1.0  # iou loss ratio (obj_loss = 1.0 or iou)
-    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights
+    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
     model.names = names
 
     # Start training
@@ -238,7 +238,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         if opt.image_weights:
            # Generate indices
            if rank in [-1, 0]:
-                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
+                cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
                iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
                dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
            # Broadcast if DDP
@@ -330,7 +330,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         if rank in [-1, 0]:
            # mAP
            if ema:
-                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
+                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
            final_epoch = epoch + 1 == epochs
            if not opt.notest or final_epoch:  # Calculate mAP
                results, maps, times = test.test(opt.data,
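Note on the class-weights change above: model.class_weights is now stored scaled by nc, and the image-weights branch divides by nc again, so the sampling weights come out the same as before. A standalone sketch with made-up numbers (the inverse-frequency formula below only approximates labels_to_class_weights):

import numpy as np

nc = 3                              # number of classes (made-up)
counts = np.array([50., 30., 20.])  # hypothetical per-class label counts
class_weights = (1 / counts) / (1 / counts).sum()  # normalized inverse-frequency weights (rough sketch of labels_to_class_weights)

model_class_weights = class_weights * nc  # stored on the model, as in the new train.py

maps = np.array([0.5, 0.2, 0.1])    # hypothetical per-class mAP
cw = model_class_weights * (1 - maps) ** 2 / nc  # the / nc undoes the * nc, recovering the previous scale
print(cw)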
tutorial.ipynb
CHANGED
@@ -1199,7 +1199,7 @@
     "\n",
     "m1 = lambda x: x * torch.sigmoid(x)\n",
     "m2 = torch.nn.SiLU()\n",
-    "profile(x=torch.randn(16, 3, 640, 640), [m1, m2], n=100)"
+    "profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)"
    ],
    "execution_count": null,
    "outputs": []
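The notebook fix is needed because a positional argument cannot follow the keyword argument x=, so the ops list must also be passed by keyword. A minimal sketch of the corrected call, assuming the YOLOv5 repo root is on sys.path so that utils.torch_utils is importable:

import torch
from utils.torch_utils import profile  # profile(x, ops, n=100, device=None)

m1 = lambda x: x * torch.sigmoid(x)  # hand-written SiLU
m2 = torch.nn.SiLU()                 # built-in SiLU
profile(x=torch.randn(16, 3, 640, 640), ops=[m1, m2], n=100)  # profile both ops over 100 iterations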
utils/loss.py
CHANGED
@@ -57,8 +57,8 @@ class FocalLoss(nn.Module):
             return loss.sum()
         else:  # 'none'
             return loss
-
-
+
+
 class QFocalLoss(nn.Module):
     # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
     def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
@@ -71,7 +71,7 @@ class QFocalLoss(nn.Module):
 
     def forward(self, pred, true):
         loss = self.loss_fcn(pred, true)
-
+
         pred_prob = torch.sigmoid(pred)  # prob from logits
         alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
         modulating_factor = torch.abs(true - pred_prob) ** self.gamma
@@ -92,8 +92,8 @@ def compute_loss(p, targets, model):  # predictions, targets, model
     h = model.hyp  # hyperparameters
 
     # Define criteria
-    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
-    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)
+    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))  # weight=model.class_weights)
+    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
 
     # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
     cp, cn = smooth_BCE(eps=0.0)
@@ -119,7 +119,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
            # Regression
            pxy = ps[:, :2].sigmoid() * 2. - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
-            pbox = torch.cat((pxy, pwh), 1)
+            pbox = torch.cat((pxy, pwh), 1)  # predicted box
            iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
            lbox += (1.0 - iou).mean()  # iou loss
 
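For reference, a standalone sketch of the updated criteria constructor, which builds the pos_weight tensor directly on the target device instead of moving the loss module afterwards; cls_pw, pred and true below are made-up stand-ins for h['cls_pw'] and the real predictions/targets:

import torch
import torch.nn as nn

cls_pw = 1.0  # hypothetical value for h['cls_pw']
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# pos_weight tensor created directly on the target device
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([cls_pw], device=device))

pred = torch.randn(8, 80, device=device)                    # dummy class logits
true = torch.randint(0, 2, (8, 80), device=device).float()  # dummy binary targets
print(BCEcls(pred, true))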
utils/torch_utils.py
CHANGED
@@ -81,8 +81,8 @@ def profile(x, ops, n=100, device=None):
     # m1 = lambda x: x * torch.sigmoid(x)
     # m2 = nn.SiLU()
     # profile(x, [m1, m2], n=100)  # profile speed over 100 iterations
-
-    device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+    device = device or torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
     x = x.to(device)
     x.requires_grad = True
     print(torch.__version__, device.type, torch.cuda.get_device_properties(0) if device.type == 'cuda' else '')
@@ -99,8 +99,11 @@ def profile(x, ops, n=100, device=None):
            t[0] = time_synchronized()
            y = m(x)
            t[1] = time_synchronized()
-            _ = y.sum().backward()
-            t[2] = time_synchronized()
+            try:
+                _ = y.sum().backward()
+                t[2] = time_synchronized()
+            except:  # no backward method
+                t[2] = float('nan')
            dtf += (t[1] - t[0]) * 1000 / n  # ms per op forward
            dtb += (t[2] - t[1]) * 1000 / n  # ms per op backward
 
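The try/except around the backward pass lets profile() keep timing ops whose output cannot be backpropagated, reporting NaN instead of raising. A minimal standalone sketch of that pattern (time_forward_backward is a hypothetical helper, not part of the repo):

import time

import torch

def time_forward_backward(m, x):
    # Time the forward pass, then try the backward pass and fall back to NaN
    # when no backward path is available.
    t0 = time.time()
    y = m(x)
    t1 = time.time()
    try:
        y.sum().backward()
        t2 = time.time()
    except Exception:
        t2 = float('nan')
    return (t1 - t0) * 1e3, (t2 - t1) * 1e3  # ms per forward, ms per backward

x = torch.randn(16, 3, 64, 64, requires_grad=True)
print(time_forward_backward(torch.nn.SiLU(), x))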
|