single command --resume (#756)
* single command --resume
* else check files, remove TODO
* argparse.Namespace()
* tensorboard lr
* bug fix in get_latest_run()
- train.py +30 -25
- utils/general.py +1 -1
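For context, the reworked --resume flag relies on argparse's nargs='?' behaviour: absent it stays False, bare it becomes True, and given a value it carries the checkpoint path. A small standalone sketch, reproducing only that one argument from train.py (not part of the diff; the example path is illustrative):

import argparse

# Only the new --resume definition from train.py, reproduced for illustration.
parser = argparse.ArgumentParser()
parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')

print(parser.parse_args([]).resume)            # False -> normal training run
print(parser.parse_args(['--resume']).resume)  # True  -> resume most recent run via get_latest_run()
print(parser.parse_args(['--resume', 'runs/exp0/weights/last.pt']).resume)  # explicit checkpoint (example path)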
train.py
CHANGED
@@ -42,7 +42,6 @@ def train(hyp, opt, device, tb_writer=None):
     epochs, batch_size, total_batch_size, weights, rank = \
         opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank

-    # TODO: Use DDP logging. Only the first process is allowed to log.
     # Save run settings
     with open(log_dir / 'hyp.yaml', 'w') as f:
         yaml.dump(hyp, f, sort_keys=False)
@@ -130,6 +129,8 @@ def train(hyp, opt, device, tb_writer=None):

         # Epochs
         start_epoch = ckpt['epoch'] + 1
+        if opt.resume:
+            assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
         if epochs < start_epoch:
             logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                         (weights, ckpt['epoch'], epochs))
@@ -158,19 +159,19 @@ def train(hyp, opt, device, tb_writer=None):
         model = DDP(model, device_ids=[opt.local_rank], output_device=(opt.local_rank))

     # Trainloader
-    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
-                                            cache=opt.cache_images, rect=opt.rect, rank=rank,
+    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
+                                            hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
                                             world_size=opt.world_size, workers=opt.workers)
     mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
     nb = len(dataloader)  # number of batches
+    ema.updates = start_epoch * nb // accumulate  # set EMA updates
     assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)

     # Testloader
     if rank in [-1, 0]:
-        ...
-        ...
-        ...
-                                       workers=opt.workers)[0]
+        testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt,
+                                       hyp=hyp, augment=False, cache=opt.cache_images, rect=True, rank=-1,
+                                       world_size=opt.world_size, workers=opt.workers)[0]  # only runs on process 0

     # Model parameters
     hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
@@ -283,7 +284,7 @@ def train(hyp, opt, device, tb_writer=None):
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
-                if ema
+                if ema:
                     ema.update(model)

             # Print
@@ -305,12 +306,13 @@ def train(hyp, opt, device, tb_writer=None):
             # end batch ------------------------------------------------------------------------------------------------

         # Scheduler
+        lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
         scheduler.step()

         # DDP process 0 or single-GPU
         if rank in [-1, 0]:
             # mAP
-            if ema
+            if ema:
                 ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
@@ -330,10 +332,11 @@ def train(hyp, opt, device, tb_writer=None):

             # Tensorboard
             if tb_writer:
-                tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',
+                tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                         'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
-                        'val/giou_loss', 'val/obj_loss', 'val/cls_loss'
-                        ...
+                        'val/giou_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
+                        'x/lr0', 'x/lr1', 'x/lr2']  # params
+                for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                     tb_writer.add_scalar(tag, x, epoch)

             # Update best mAP
@@ -389,8 +392,7 @@ if __name__ == '__main__':
     parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
     parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='train,test sizes')
     parser.add_argument('--rect', action='store_true', help='rectangular training')
-    parser.add_argument('--resume', nargs='?', const=
-                        help='resume from given path/last.pt, or most recent run if blank')
+    parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
     parser.add_argument('--notest', action='store_true', help='only test final epoch')
     parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
@@ -413,21 +415,24 @@ if __name__ == '__main__':
     opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
     opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
     set_logging(opt.global_rank)
-
-    # Resume
-    if opt.resume:
-        last = get_latest_run() if opt.resume == 'get_last' else opt.resume  # resume from most recent run
-        if last and not opt.weights:
-            logger.info(f'Resuming training from {last}')
-        opt.weights = last if opt.resume and not opt.weights else opt.weights
     if opt.global_rank in [-1, 0]:
         check_git_status()

-    ...
-    ...
-    ...
+    # Resume
+    if opt.resume:  # resume an interrupted run
+        ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run()  # specified or most recent path
+        assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
+        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
+            opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader))  # replace
+        opt.cfg, opt.weights, opt.resume = '', ckpt, True
+        logger.info('Resuming training from %s' % ckpt)
+
+    else:
+        opt.hyp = opt.hyp or ('data/hyp.finetune.yaml' if opt.weights else 'data/hyp.scratch.yaml')
+        opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
+        assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
+        opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)

-    opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
     device = select_device(opt.device, batch_size=opt.batch_size)

     # DDP mode
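The Resume branch above restores every training option from the opt.yaml saved alongside the run's checkpoints. A minimal standalone sketch of that round trip, assuming opt.yaml is written with yaml.dump(vars(opt), ...) the same way hyp.yaml is saved in the first hunk (file name and values here are illustrative):

import argparse
import yaml

# Save run settings (illustrative values; scratch file for the demo).
opt = argparse.Namespace(weights='yolov5s.pt', epochs=300, batch_size=16)
with open('opt.yaml', 'w') as f:
    yaml.dump(vars(opt), f, sort_keys=False)

# Restore them the way the new resume branch does: the YAML mapping becomes
# keyword arguments to argparse.Namespace, replacing the freshly parsed opt.
with open('opt.yaml') as f:
    opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader))
print(opt.weights, opt.epochs, opt.batch_size)  # yolov5s.pt 300 16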
utils/general.py
CHANGED
@@ -61,7 +61,7 @@ def init_seeds(seed=0):
 def get_latest_run(search_dir='./runs'):
     # Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
     last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
-    return max(last_list, key=os.path.getctime)
+    return max(last_list, key=os.path.getctime) if last_list else ''


 def check_git_status():
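The guard added above matters because max() raises ValueError on an empty sequence, which is what happened when --resume was used before any run existed. A standalone sketch of the patched helper showing the new behaviour:

import glob
import os

def get_latest_run(search_dir='./runs'):
    # Newest 'last*.pt' under runs/ (recursive), or '' when no run has been saved yet.
    last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    return max(last_list, key=os.path.getctime) if last_list else ''

print(repr(get_latest_run()))  # '' on a fresh checkout; previously max([]) raised ValueError,
                               # now train.py's assert os.path.isfile(ckpt) reports a clear error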