PyTorch 1.6.0 update with native AMP (#573)
* PyTorch 1.6.0 now has native Automatic Mixed Precision (AMP) training.
* Fixed inconsistent code indentation.
* Mixed precision training is turned on by default
- train.py +36 -44
- utils/torch_utils.py +2 -2
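
Note: this change swaps the optional NVIDIA Apex dependency for the torch.cuda.amp API that ships with PyTorch 1.6.0. As a minimal sketch of the native AMP pattern the diff below adopts (toy model and data, assumes a CUDA device; not code from this repo):

import torch
from torch.cuda import amp

model = torch.nn.Linear(10, 2).cuda()  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = amp.GradScaler()  # rescales the loss so fp16 gradients do not underflow

x = torch.randn(4, 10, device='cuda')
y = torch.randint(0, 2, (4,), device='cuda')
with amp.autocast():  # forward pass runs in mixed fp16/fp32
    loss = torch.nn.functional.cross_entropy(model(x), y)
scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)  # unscales gradients, then calls optimizer.step()
scaler.update()  # adjusts the scale factor for the next iteration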
train.py
CHANGED
@@ -5,6 +5,7 @@ import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
+from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

@@ -14,13 +15,6 @@ from utils import google_utils
from utils.datasets import *
from utils.utils import *

-mixed_precision = True
-try:  # Mixed precision training https://github.com/NVIDIA/apex
-    from apex import amp
-except:
-    print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
-    mixed_precision = False  # not installed
-
# Hyperparameters
hyp = {'optimizer': 'SGD',  # ['adam', 'SGD', None] if none, default is SGD
       'lr0': 0.01,  # initial learning rate (SGD=1E-2, Adam=1E-3)
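
The try/except Apex import above is deleted because amp now ships with PyTorch itself. A project that still had to support older PyTorch could guard on the version instead; a hypothetical sketch (not part of this commit, assumes the packaging module is installed):

import torch
from packaging import version  # assumption: packaging is available

if version.parse(torch.__version__) >= version.parse('1.6.0'):
    from torch.cuda import amp  # native AMP, no Apex needed
else:
    raise RuntimeError('native AMP requires torch>=1.6.0, found %s' % torch.__version__)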
@@ -63,6 +57,7 @@ def train(hyp, tb_writer, opt, device):
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
+    cuda = device.type != 'cpu'
    init_seeds(2 + rank)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
@@ -113,7 +108,7 @@ def train(hyp, tb_writer, opt, device):
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2
-
+
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
@@ -160,16 +155,12 @@ def train(hyp, tb_writer, opt, device):

        del ckpt

-    # Mixed precision training https://github.com/NVIDIA/apex
-    if mixed_precision:
-        model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
-
    # DP mode
-    if device.type != 'cpu' and rank == -1 and torch.cuda.device_count() > 1:
+    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
-    if opt.sync_bn and device.type != 'cpu' and rank != -1:
+    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        print('Using SyncBatchNorm()')

@@ -177,7 +168,7 @@ def train(hyp, tb_writer, opt, device):
    ema = torch_utils.ModelEMA(model) if rank in [-1, 0] else None

    # DDP mode
-    if device.type != 'cpu' and rank != -1:
+    if cuda and rank != -1:
        model = DDP(model, device_ids=[rank], output_device=rank)

    # Trainloader
@@ -223,6 +214,7 @@ def train(hyp, tb_writer, opt, device):
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    scheduler.last_epoch = start_epoch - 1  # do not move
+    scaler = amp.GradScaler(enabled=cuda)
    if rank in [0, -1]:
        print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
        print('Using %g dataloader workers' % dataloader.num_workers)
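
GradScaler(enabled=cuda) keeps a single code path for CPU and GPU: with enabled=False, scale() returns the loss unchanged and step() simply calls optimizer.step(). A short sketch of the same idea with a placeholder flag:

import torch
from torch.cuda import amp

use_cuda = torch.cuda.is_available()  # mirrors `cuda = device.type != 'cpu'` above
scaler = amp.GradScaler(enabled=use_cuda)  # a disabled scaler is a pass-through, so CPU runs are unaffected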
@@ -232,15 +224,14 @@ def train(hyp, tb_writer, opt, device):
        model.train()

        # Update image weights (optional)
-        # When in DDP mode, the generated indices will be broadcasted to synchronize dataset.
        if dataset.image_weights:
-            # Generate indices
+            # Generate indices
            if rank in [-1, 0]:
                w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
                image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
                dataset.indices = random.choices(range(dataset.n), weights=image_weights,
                                                 k=dataset.n)  # rand weighted idx
-            # Broadcast
+            # Broadcast if DDP
            if rank != -1:
                indices = torch.zeros([dataset.n], dtype=torch.int)
                if rank == 0:
@@ -263,7 +254,7 @@ def train(hyp, tb_writer, opt, device):
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
-            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
+            imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
@@ -284,27 +275,26 @@ def train(hyp, tb_writer, opt, device):
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

-            # Forward
-            pred = model(imgs)
-
-            # Loss
-            loss, loss_items = compute_loss(pred, targets.to(device), model)  # scaled by batch_size
-            if rank != -1:
-                loss *= opt.world_size  # gradient averaged between devices in DDP mode
-            if not torch.isfinite(loss):
-                print('WARNING: non-finite loss, ending training ', loss_items)
-                return results
+            # Autocast
+            with amp.autocast():
+                # Forward
+                pred = model(imgs)
+
+                # Loss
+                loss, loss_items = compute_loss(pred, targets.to(device), model)  # scaled by batch_size
+                if rank != -1:
+                    loss *= opt.world_size  # gradient averaged between devices in DDP mode
+                # if not torch.isfinite(loss):
+                #     print('WARNING: non-finite loss, ending training ', loss_items)
+                #     return results

            # Backward
-            if mixed_precision:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-            else:
-                loss.backward()
+            scaler.scale(loss).backward()

            # Optimize
            if ni % accumulate == 0:
-                optimizer.step()
+                scaler.step(optimizer)  # optimizer.step
+                scaler.update()
                optimizer.zero_grad()
                if ema is not None:
                    ema.update(model)
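
This hunk preserves gradient accumulation: scaler.scale(loss).backward() runs on every batch, while scaler.step() and scaler.update() run only every `accumulate` batches, and only the forward pass and loss sit under autocast(). A runnable sketch of that interaction with placeholder model and data:

import torch
from torch.cuda import amp

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(10, 1).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = amp.GradScaler(enabled=device == 'cuda')
accumulate = 4  # step the optimizer every 4 batches

optimizer.zero_grad()
for ni in range(12):
    x = torch.randn(8, 10, device=device)
    y = torch.randn(8, 1, device=device)
    with amp.autocast(enabled=device == 'cuda'):  # forward + loss only
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # gradients accumulate across batches
    if ni % accumulate == 0:  # same condition as the diff
        scaler.step(optimizer)  # unscale gradients, then optimizer.step()
        scaler.update()
        optimizer.zero_grad()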
@@ -312,7 +302,7 @@ def train(hyp, tb_writer, opt, device):
            # Print
            if rank in [-1, 0]:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
-                mem = '%.3gG' % (torch.cuda.memory_cached() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
+                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                s = ('%10s' * 2 + '%10.4g' * 6) % (
                    '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
                pbar.set_description(s)
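
torch.cuda.memory_cached() is deprecated in PyTorch 1.6 in favor of the new name used here; both report the bytes reserved by the CUDA caching allocator. For example:

import torch

if torch.cuda.is_available():
    print('%.3gG' % (torch.cuda.memory_reserved() / 1E9))  # formerly torch.cuda.memory_cached()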
@@ -330,7 +320,7 @@ def train(hyp, tb_writer, opt, device):
        # Scheduler
        scheduler.step()

-        #
+        # DDP process 0 or single-GPU
        if rank in [-1, 0]:
            # mAP
            if ema is not None:
@@ -377,7 +367,7 @@ def train(hyp, tb_writer, opt, device):

            # Save last, best and delete
            torch.save(ckpt, last)
-            if best_fitness == fi:
+            if best_fitness == fi:
                torch.save(ckpt, best)
            del ckpt
        # end epoch ----------------------------------------------------------------------------------------------------
@@ -429,10 +419,12 @@ if __name__ == '__main__':
    parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
    opt = parser.parse_args()

+    # Resume
    last = get_latest_run() if opt.resume == 'get_last' else opt.resume  # resume from most recent run
    if last and not opt.weights:
        print(f'Resuming training from {last}')
    opt.weights = last if opt.resume and not opt.weights else opt.weights
+
    if opt.local_rank in [-1, 0]:
        check_git_status()
    opt.cfg = check_file(opt.cfg)  # check file
@@ -442,21 +434,20 @@ if __name__ == '__main__':
    with open(opt.hyp) as f:
        hyp.update(yaml.load(f, Loader=yaml.FullLoader))  # update hyps
    opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
-    device = torch_utils.select_device(opt.device, apex=mixed_precision, batch_size=opt.batch_size)
+    device = torch_utils.select_device(opt.device, batch_size=opt.batch_size)
    opt.total_batch_size = opt.batch_size
    opt.world_size = 1
-    if device.type == 'cpu':
-        mixed_precision = False
-    elif opt.local_rank != -1:
-        # DDP mode
+
+    # DDP mode
+    if opt.local_rank != -1:
        assert torch.cuda.device_count() > opt.local_rank
        torch.cuda.set_device(opt.local_rank)
        device = torch.device("cuda", opt.local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')  # distributed backend
-
        opt.world_size = dist.get_world_size()
        assert opt.batch_size % opt.world_size == 0, "Batch size is not a multiple of the number of devices given!"
        opt.batch_size = opt.total_batch_size // opt.world_size
+
    print(opt)

    # Train
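
In the DDP branch above, each process receives its own --local_rank (the torch.distributed.launch helper supplies it automatically), and the global batch is split evenly across processes. A small sketch of the batch-size bookkeeping the diff performs, with example numbers:

# e.g. two GPUs: python -m torch.distributed.launch --nproc_per_node 2 train.py ...
total_batch_size, world_size = 64, 2
assert total_batch_size % world_size == 0, 'batch size must divide evenly across processes'
batch_size = total_batch_size // world_size  # 32 images per process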
@@ -466,11 +457,12 @@ if __name__ == '__main__':
            tb_writer = SummaryWriter(log_dir=increment_dir('runs/exp', opt.name))
        else:
            tb_writer = None
+
        train(hyp, tb_writer, opt, device)

    # Evolve hyperparameters (optional)
    else:
-        assert opt.local_rank == -1,
+        assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'

        tb_writer = None
        opt.notest, opt.nosave = True, True  # only test/save final epoch
utils/torch_utils.py
CHANGED
@@ -22,7 +22,7 @@ def init_seeds(seed=0):
        cudnn.benchmark = True


-def select_device(device='', apex=False, batch_size=None):
+def select_device(device='', batch_size=None):
    # device = 'cpu' or '0' or '0,1,2,3'
    cpu_request = device.lower() == 'cpu'
    if device and not cpu_request:  # if device requested other than 'cpu'
@@ -36,7 +36,7 @@ def select_device(device='', apex=False, batch_size=None):
        if ng > 1 and batch_size:  # check that batch_size is compatible with device_count
            assert batch_size % ng == 0, 'batch-size %g not multiple of GPU count %g' % (batch_size, ng)
        x = [torch.cuda.get_device_properties(i) for i in range(ng)]
-        s = 'Using CUDA ' + ('Apex ' if apex else '')  # apex for mixed precision https://github.com/NVIDIA/apex
+        s = 'Using CUDA '
        for i in range(0, ng):
            if i == 1:
                s = ' ' * len(s)
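
With the apex keyword gone, call sites pass only the device string and batch size, and the Apex suffix drops out of the startup banner. A hypothetical call, assuming the repo's utils package is importable as in train.py:

from utils import torch_utils

device = torch_utils.select_device('0', batch_size=16)  # new signature, no apex flag
# torch_utils.select_device('0', apex=True, batch_size=16)  # old signature, removed by this commit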
|