Add `train.run()` method (#3700)
Browse files* Update train.py explicit arguments
* Update train.py
* Add run method
train.py
CHANGED
@@ -46,8 +46,9 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
46 |
opt,
|
47 |
device,
|
48 |
):
|
49 |
-
save_dir, epochs, batch_size, weights, single_cls = \
|
50 |
-
opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls
|
|
|
51 |
|
52 |
# Directories
|
53 |
save_dir = Path(save_dir)
|
@@ -70,34 +71,34 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
70 |
yaml.safe_dump(vars(opt), f, sort_keys=False)
|
71 |
|
72 |
# Configure
|
73 |
-
plots = not
|
74 |
cuda = device.type != 'cpu'
|
75 |
init_seeds(2 + RANK)
|
76 |
-
with open(
|
77 |
data_dict = yaml.safe_load(f) # data dict
|
78 |
|
79 |
# Loggers
|
80 |
loggers = {'wandb': None, 'tb': None} # loggers dict
|
81 |
if RANK in [-1, 0]:
|
82 |
# TensorBoard
|
83 |
-
if not
|
84 |
prefix = colorstr('tensorboard: ')
|
85 |
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
|
86 |
-
loggers['tb'] = SummaryWriter(
|
87 |
|
88 |
# W&B
|
89 |
opt.hyp = hyp # add hyperparameters
|
90 |
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
|
91 |
wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
|
92 |
loggers['wandb'] = wandb_logger.wandb
|
93 |
-
|
94 |
-
|
95 |
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update weights, epochs if resuming
|
96 |
|
97 |
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
98 |
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
99 |
-
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc,
|
100 |
-
is_coco =
|
101 |
|
102 |
# Model
|
103 |
pretrained = weights.endswith('.pt')
|
@@ -105,14 +106,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
105 |
with torch_distributed_zero_first(RANK):
|
106 |
weights = attempt_download(weights) # download if not found locally
|
107 |
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
108 |
-
model = Model(
|
109 |
-
exclude = ['anchor'] if (
|
110 |
state_dict = ckpt['model'].float().state_dict() # to FP32
|
111 |
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
|
112 |
model.load_state_dict(state_dict, strict=False) # load
|
113 |
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
|
114 |
else:
|
115 |
-
model = Model(
|
116 |
with torch_distributed_zero_first(RANK):
|
117 |
check_dataset(data_dict) # check
|
118 |
train_path = data_dict['train']
|
@@ -182,7 +183,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
182 |
|
183 |
# Epochs
|
184 |
start_epoch = ckpt['epoch'] + 1
|
185 |
-
if
|
186 |
assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
|
187 |
if epochs < start_epoch:
|
188 |
logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
|
@@ -210,20 +211,20 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
210 |
# Trainloader
|
211 |
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
|
212 |
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
|
213 |
-
workers=
|
214 |
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
|
215 |
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
216 |
nb = len(dataloader) # number of batches
|
217 |
-
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc,
|
218 |
|
219 |
# Process 0
|
220 |
if RANK in [-1, 0]:
|
221 |
testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
|
222 |
-
hyp=hyp, cache=opt.cache_images and not
|
223 |
-
workers=
|
224 |
pad=0.5, prefix=colorstr('val: '))[0]
|
225 |
|
226 |
-
if not
|
227 |
labels = np.concatenate(dataset.labels, 0)
|
228 |
c = torch.tensor(labels[:, 0]) # classes
|
229 |
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
@@ -356,8 +357,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
356 |
with warnings.catch_warnings():
|
357 |
warnings.simplefilter('ignore') # suppress jit trace warning
|
358 |
loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
359 |
-
elif plots and ni == 10 and
|
360 |
-
wandb_logger.log({'Mosaics': [
|
361 |
save_dir.glob('train*.jpg') if x.exists()]})
|
362 |
|
363 |
# end batch ------------------------------------------------------------------------------------------------
|
@@ -371,7 +372,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
371 |
# mAP
|
372 |
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
|
373 |
final_epoch = epoch + 1 == epochs
|
374 |
-
if not
|
375 |
wandb_logger.current_epoch = epoch + 1
|
376 |
results, maps, _ = test.test(data_dict,
|
377 |
batch_size=batch_size // WORLD_SIZE * 2,
|
@@ -398,7 +399,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
398 |
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
|
399 |
if loggers['tb']:
|
400 |
loggers['tb'].add_scalar(tag, x, epoch) # TensorBoard
|
401 |
-
if
|
402 |
wandb_logger.log({tag: x}) # W&B
|
403 |
|
404 |
# Update best mAP
|
@@ -408,7 +409,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
408 |
wandb_logger.end_epoch(best_result=best_fitness == fi)
|
409 |
|
410 |
# Save model
|
411 |
-
if (not
|
412 |
ckpt = {'epoch': epoch,
|
413 |
'best_fitness': best_fitness,
|
414 |
'training_results': results_file.read_text(),
|
@@ -416,13 +417,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
416 |
'ema': deepcopy(ema.ema).half(),
|
417 |
'updates': ema.updates,
|
418 |
'optimizer': optimizer.state_dict(),
|
419 |
-
'wandb_id': wandb_logger.wandb_run.id if
|
420 |
|
421 |
# Save last, best and delete
|
422 |
torch.save(ckpt, last)
|
423 |
if best_fitness == fi:
|
424 |
torch.save(ckpt, best)
|
425 |
-
if
|
426 |
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
|
427 |
wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
|
428 |
del ckpt
|
@@ -433,15 +434,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
433 |
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
|
434 |
if plots:
|
435 |
plot_results(save_dir=save_dir) # save as results.png
|
436 |
-
if
|
437 |
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
|
438 |
-
wandb_logger.log({"Results": [
|
439 |
if (save_dir / f).exists()]})
|
440 |
|
441 |
-
if not
|
442 |
if is_coco: # COCO dataset
|
443 |
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
|
444 |
-
results, _, _ = test.test(
|
445 |
batch_size=batch_size // WORLD_SIZE * 2,
|
446 |
imgsz=imgsz_test,
|
447 |
conf_thres=0.001,
|
@@ -457,17 +458,17 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
|
457 |
for f in last, best:
|
458 |
if f.exists():
|
459 |
strip_optimizer(f) # strip optimizers
|
460 |
-
if
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
wandb_logger.finish_run()
|
465 |
|
466 |
torch.cuda.empty_cache()
|
467 |
return results
|
468 |
|
469 |
|
470 |
-
def parse_opt():
|
471 |
parser = argparse.ArgumentParser()
|
472 |
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
|
473 |
parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
|
@@ -503,7 +504,7 @@ def parse_opt():
|
|
503 |
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
504 |
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
505 |
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
506 |
-
opt = parser.parse_args()
|
507 |
return opt
|
508 |
|
509 |
|
@@ -633,6 +634,14 @@ def main(opt):
|
|
633 |
f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
|
634 |
|
635 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
636 |
if __name__ == "__main__":
|
637 |
opt = parse_opt()
|
638 |
main(opt)
|
|
|
46 |
opt,
|
47 |
device,
|
48 |
):
|
49 |
+
save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, notest, nosave, workers, = \
|
50 |
+
opt.save_dir, opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
|
51 |
+
opt.resume, opt.notest, opt.nosave, opt.workers
|
52 |
|
53 |
# Directories
|
54 |
save_dir = Path(save_dir)
|
|
|
71 |
yaml.safe_dump(vars(opt), f, sort_keys=False)
|
72 |
|
73 |
# Configure
|
74 |
+
plots = not evolve # create plots
|
75 |
cuda = device.type != 'cpu'
|
76 |
init_seeds(2 + RANK)
|
77 |
+
with open(data) as f:
|
78 |
data_dict = yaml.safe_load(f) # data dict
|
79 |
|
80 |
# Loggers
|
81 |
loggers = {'wandb': None, 'tb': None} # loggers dict
|
82 |
if RANK in [-1, 0]:
|
83 |
# TensorBoard
|
84 |
+
if not evolve:
|
85 |
prefix = colorstr('tensorboard: ')
|
86 |
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
|
87 |
+
loggers['tb'] = SummaryWriter(str(save_dir))
|
88 |
|
89 |
# W&B
|
90 |
opt.hyp = hyp # add hyperparameters
|
91 |
run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
|
92 |
wandb_logger = WandbLogger(opt, save_dir.stem, run_id, data_dict)
|
93 |
loggers['wandb'] = wandb_logger.wandb
|
94 |
+
if loggers['wandb']:
|
95 |
+
data_dict = wandb_logger.data_dict
|
96 |
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update weights, epochs if resuming
|
97 |
|
98 |
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
|
99 |
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
|
100 |
+
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data) # check
|
101 |
+
is_coco = data.endswith('coco.yaml') and nc == 80 # COCO dataset
|
102 |
|
103 |
# Model
|
104 |
pretrained = weights.endswith('.pt')
|
|
|
106 |
with torch_distributed_zero_first(RANK):
|
107 |
weights = attempt_download(weights) # download if not found locally
|
108 |
ckpt = torch.load(weights, map_location=device) # load checkpoint
|
109 |
+
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
110 |
+
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
|
111 |
state_dict = ckpt['model'].float().state_dict() # to FP32
|
112 |
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
|
113 |
model.load_state_dict(state_dict, strict=False) # load
|
114 |
logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
|
115 |
else:
|
116 |
+
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
|
117 |
with torch_distributed_zero_first(RANK):
|
118 |
check_dataset(data_dict) # check
|
119 |
train_path = data_dict['train']
|
|
|
183 |
|
184 |
# Epochs
|
185 |
start_epoch = ckpt['epoch'] + 1
|
186 |
+
if resume:
|
187 |
assert start_epoch > 0, '%s training to %g epochs is finished, nothing to resume.' % (weights, epochs)
|
188 |
if epochs < start_epoch:
|
189 |
logger.info('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
|
|
|
211 |
# Trainloader
|
212 |
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size // WORLD_SIZE, gs, single_cls,
|
213 |
hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
|
214 |
+
workers=workers,
|
215 |
image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
|
216 |
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
217 |
nb = len(dataloader) # number of batches
|
218 |
+
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, data, nc - 1)
|
219 |
|
220 |
# Process 0
|
221 |
if RANK in [-1, 0]:
|
222 |
testloader = create_dataloader(test_path, imgsz_test, batch_size // WORLD_SIZE * 2, gs, single_cls,
|
223 |
+
hyp=hyp, cache=opt.cache_images and not notest, rect=True, rank=-1,
|
224 |
+
workers=workers,
|
225 |
pad=0.5, prefix=colorstr('val: '))[0]
|
226 |
|
227 |
+
if not resume:
|
228 |
labels = np.concatenate(dataset.labels, 0)
|
229 |
c = torch.tensor(labels[:, 0]) # classes
|
230 |
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
|
|
357 |
with warnings.catch_warnings():
|
358 |
warnings.simplefilter('ignore') # suppress jit trace warning
|
359 |
loggers['tb'].add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), [])
|
360 |
+
elif plots and ni == 10 and loggers['wandb']:
|
361 |
+
wandb_logger.log({'Mosaics': [loggers['wandb'].Image(str(x), caption=x.name) for x in
|
362 |
save_dir.glob('train*.jpg') if x.exists()]})
|
363 |
|
364 |
# end batch ------------------------------------------------------------------------------------------------
|
|
|
372 |
# mAP
|
373 |
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
|
374 |
final_epoch = epoch + 1 == epochs
|
375 |
+
if not notest or final_epoch: # Calculate mAP
|
376 |
wandb_logger.current_epoch = epoch + 1
|
377 |
results, maps, _ = test.test(data_dict,
|
378 |
batch_size=batch_size // WORLD_SIZE * 2,
|
|
|
399 |
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
|
400 |
if loggers['tb']:
|
401 |
loggers['tb'].add_scalar(tag, x, epoch) # TensorBoard
|
402 |
+
if loggers['wandb']:
|
403 |
wandb_logger.log({tag: x}) # W&B
|
404 |
|
405 |
# Update best mAP
|
|
|
409 |
wandb_logger.end_epoch(best_result=best_fitness == fi)
|
410 |
|
411 |
# Save model
|
412 |
+
if (not nosave) or (final_epoch and not evolve): # if save
|
413 |
ckpt = {'epoch': epoch,
|
414 |
'best_fitness': best_fitness,
|
415 |
'training_results': results_file.read_text(),
|
|
|
417 |
'ema': deepcopy(ema.ema).half(),
|
418 |
'updates': ema.updates,
|
419 |
'optimizer': optimizer.state_dict(),
|
420 |
+
'wandb_id': wandb_logger.wandb_run.id if loggers['wandb'] else None}
|
421 |
|
422 |
# Save last, best and delete
|
423 |
torch.save(ckpt, last)
|
424 |
if best_fitness == fi:
|
425 |
torch.save(ckpt, best)
|
426 |
+
if loggers['wandb']:
|
427 |
if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
|
428 |
wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
|
429 |
del ckpt
|
|
|
434 |
logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
|
435 |
if plots:
|
436 |
plot_results(save_dir=save_dir) # save as results.png
|
437 |
+
if loggers['wandb']:
|
438 |
files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
|
439 |
+
wandb_logger.log({"Results": [loggers['wandb'].Image(str(save_dir / f), caption=f) for f in files
|
440 |
if (save_dir / f).exists()]})
|
441 |
|
442 |
+
if not evolve:
|
443 |
if is_coco: # COCO dataset
|
444 |
for m in [last, best] if best.exists() else [last]: # speed, mAP tests
|
445 |
+
results, _, _ = test.test(data,
|
446 |
batch_size=batch_size // WORLD_SIZE * 2,
|
447 |
imgsz=imgsz_test,
|
448 |
conf_thres=0.001,
|
|
|
458 |
for f in last, best:
|
459 |
if f.exists():
|
460 |
strip_optimizer(f) # strip optimizers
|
461 |
+
if loggers['wandb']: # Log the stripped model
|
462 |
+
loggers['wandb'].log_artifact(str(best if best.exists() else last), type='model',
|
463 |
+
name='run_' + wandb_logger.wandb_run.id + '_model',
|
464 |
+
aliases=['latest', 'best', 'stripped'])
|
465 |
wandb_logger.finish_run()
|
466 |
|
467 |
torch.cuda.empty_cache()
|
468 |
return results
|
469 |
|
470 |
|
471 |
+
def parse_opt(known=False):
|
472 |
parser = argparse.ArgumentParser()
|
473 |
parser.add_argument('--weights', type=str, default='yolov5s.pt', help='initial weights path')
|
474 |
parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
|
|
|
504 |
parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
|
505 |
parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
|
506 |
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
507 |
+
opt = parser.parse_known_args()[0] if known else parser.parse_args()
|
508 |
return opt
|
509 |
|
510 |
|
|
|
634 |
f'Command to train a new model with these hyperparameters: $ python train.py --hyp {yaml_file}')
|
635 |
|
636 |
|
637 |
+
def run(**kwargs):
|
638 |
+
# Usage: import train; train.run(imgsz=320, weights='yolov5m.pt')
|
639 |
+
opt = parse_opt(True)
|
640 |
+
for k, v in kwargs.items():
|
641 |
+
setattr(opt, k, v)
|
642 |
+
main(opt)
|
643 |
+
|
644 |
+
|
645 |
if __name__ == "__main__":
|
646 |
opt = parse_opt()
|
647 |
main(opt)
|