Ayush Chaurasia committed on
Commit
518c095
·
unverified ·
1 Parent(s): dc51e80

W&B resume ddp from run link fix (#2579)

Browse files

* W&B resume ddp from run link fix

* Native DDP W&B support for training, resuming

Files changed (2) hide show
  1. train.py +2 -2
  2. utils/wandb_logging/wandb_utils.py +52 -14
train.py CHANGED
@@ -33,7 +33,7 @@ from utils.google_utils import attempt_download
33
  from utils.loss import ComputeLoss
34
  from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
35
  from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
36
- from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id
37
 
38
  logger = logging.getLogger(__name__)
39
 
@@ -496,7 +496,7 @@ if __name__ == '__main__':
496
  check_requirements()
497
 
498
  # Resume
499
- wandb_run = resume_and_get_id(opt)
500
  if opt.resume and not wandb_run: # resume an interrupted run
501
  ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
502
  assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
 
33
  from utils.loss import ComputeLoss
34
  from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
35
  from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
36
+ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
37
 
38
  logger = logging.getLogger(__name__)
39
 
 
496
  check_requirements()
497
 
498
  # Resume
499
+ wandb_run = check_wandb_resume(opt)
500
  if opt.resume and not wandb_run: # resume an interrupted run
501
  ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
502
  assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
utils/wandb_logging/wandb_utils.py CHANGED
@@ -23,7 +23,7 @@ except ImportError:
23
  WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
24
 
25
 
26
- def remove_prefix(from_string, prefix):
27
  return from_string[len(prefix):]
28
 
29
 
@@ -33,35 +33,73 @@ def check_wandb_config_file(data_config_file):
33
  return wandb_config
34
  return data_config_file
35
 
 
 
 
 
 
 
36
 
37
- def resume_and_get_id(opt):
38
- # It's more elegant to stick to 1 wandb.init call, but as useful config data is overwritten in the WandbLogger's wandb.init call
39
  if isinstance(opt.resume, str):
40
  if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
41
- run_path = Path(remove_prefix(opt.resume, WANDB_ARTIFACT_PREFIX))
42
- run_id = run_path.stem
43
- project = run_path.parent.stem
44
- model_artifact_name = WANDB_ARTIFACT_PREFIX + 'run_' + run_id + '_model'
45
- assert wandb, 'install wandb to resume wandb runs'
46
- # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
47
- run = wandb.init(id=run_id, project=project, resume='allow')
48
- opt.resume = model_artifact_name
49
- return run
50
  return None
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  class WandbLogger():
54
  def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
55
  # Pre-training routine --
56
  self.job_type = job_type
57
  self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
58
- if self.wandb:
 
 
 
 
 
 
 
 
 
59
  self.wandb_run = wandb.init(config=opt,
60
  resume="allow",
61
  project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
62
  name=name,
63
  job_type=job_type,
64
- id=run_id) if not wandb.run else wandb.run
 
65
  if self.job_type == 'Training':
66
  if not opt.resume:
67
  wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
 
23
  WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
24
 
25
 
26
def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
    """Return `from_string` with `prefix` stripped from its start.

    Robustness fix: the original sliced unconditionally, so a string that
    did not actually begin with `prefix` lost unrelated leading characters.
    All visible callers check `startswith(prefix)` first, so guarding here
    preserves their behavior while making the helper safe on its own.
    """
    if from_string.startswith(prefix):
        return from_string[len(prefix):]
    return from_string
28
 
29
 
 
33
  return wandb_config
34
  return data_config_file
35
 
36
def get_run_info(run_path):
    """Parse a `wandb-artifact://<project>/<run_id>` link.

    Returns a `(run_id, project, model_artifact_name)` tuple, where the
    artifact name follows the `run_<id>_model` naming convention.
    """
    path = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
    run_id, project = path.stem, path.parent.stem
    return run_id, project, 'run_' + run_id + '_model'
42
 
43
def check_wandb_resume(opt):
    """Set up resuming from a `wandb-artifact://` run link, DDP-aware.

    On DDP worker ranks (global_rank not in {-1, 0}) this first rewrites any
    wandb-artifact dataset paths in `opt.data` to local downloads, then — if
    `opt.resume` is a wandb-artifact link — downloads the run's latest model
    checkpoint and points `opt.weights` at its `last.pt` so every worker
    resumes from the same weights. Rank -1/0 resumes inside WandbLogger.

    Returns True when a DDP worker was set up to resume from wandb,
    otherwise None.
    """
    # Idiom fix: the original used a side-effecting conditional expression
    # (`f(opt) if cond else None`); a plain `if` statement says the same thing.
    if opt.global_rank not in [-1, 0]:
        process_wandb_config_ddp_mode(opt)
    if isinstance(opt.resume, str):
        if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
            if opt.global_rank not in [-1, 0]:  # For resuming DDP runs
                run_id, project, model_artifact_name = get_run_info(opt.resume)
                api = wandb.Api()
                artifact = api.artifact(project + '/' + model_artifact_name + ':latest')
                modeldir = artifact.download()
                opt.weights = str(Path(modeldir) / "last.pt")
                return True
    return None
55
 
56
def process_wandb_config_ddp_mode(opt):
    """Resolve wandb-artifact dataset links in `opt.data` for DDP workers.

    Downloads any `wandb-artifact://` train/val dataset artifacts to local
    directories, rewrites the data dict to point at the downloaded images,
    dumps the rewritten dict to `wandb_local_data.yaml`, and repoints
    `opt.data` at that file. Mutates `opt.data`; returns None.
    """
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict
    train_dir, val_dir = None, None
    if data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX):
        api = wandb.Api()
        train_artifact = api.artifact(remove_prefix(data_dict['train']) + ':' + opt.artifact_alias)
        train_dir = train_artifact.download()
        train_path = Path(train_dir) / 'data/images/'
        data_dict['train'] = str(train_path)

    if data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX):
        api = wandb.Api()
        val_artifact = api.artifact(remove_prefix(data_dict['val']) + ':' + opt.artifact_alias)
        val_dir = val_artifact.download()
        val_path = Path(val_dir) / 'data/images/'
        data_dict['val'] = str(val_path)
    if train_dir or val_dir:
        # Bug fix: the original built the path from Path(val_dir)
        # unconditionally, raising TypeError when only the train split was a
        # wandb artifact (val_dir is None). Fall back to train_dir.
        ddp_data_path = str(Path(val_dir or train_dir) / 'wandb_local_data.yaml')
        with open(ddp_data_path, 'w') as f:
            yaml.dump(data_dict, f)
        opt.data = ddp_data_path
78
+
79
+
80
 
81
  class WandbLogger():
82
  def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
83
  # Pre-training routine --
84
  self.job_type = job_type
85
  self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
86
+ # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
87
+ if isinstance(opt.resume, str): # checks resume from artifact
88
+ if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
89
+ run_id, project, model_artifact_name = get_run_info(opt.resume)
90
+ model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
91
+ assert wandb, 'install wandb to resume wandb runs'
92
+ # Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
93
+ self.wandb_run = wandb.init(id=run_id, project=project, resume='allow')
94
+ opt.resume = model_artifact_name
95
+ elif self.wandb:
96
  self.wandb_run = wandb.init(config=opt,
97
  resume="allow",
98
  project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
99
  name=name,
100
  job_type=job_type,
101
+ id=run_id) if not wandb.run else wandb.run
102
+ if self.wandb_run:
103
  if self.job_type == 'Training':
104
  if not opt.resume:
105
  wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict