Ayush Chaurasia
committed on
W&B resume ddp from run link fix (#2579)
Browse files* W&B resume ddp from run link fix
* Native DDP W&B support for training, resuming
- train.py +2 -2
- utils/wandb_logging/wandb_utils.py +52 -14
train.py
CHANGED
@@ -33,7 +33,7 @@ from utils.google_utils import attempt_download
|
|
33 |
from utils.loss import ComputeLoss
|
34 |
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
|
35 |
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
|
36 |
-
from utils.wandb_logging.wandb_utils import WandbLogger,
|
37 |
|
38 |
logger = logging.getLogger(__name__)
|
39 |
|
@@ -496,7 +496,7 @@ if __name__ == '__main__':
|
|
496 |
check_requirements()
|
497 |
|
498 |
# Resume
|
499 |
-
wandb_run =
|
500 |
if opt.resume and not wandb_run: # resume an interrupted run
|
501 |
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
|
502 |
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
|
|
|
33 |
from utils.loss import ComputeLoss
|
34 |
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
|
35 |
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
|
36 |
+
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
|
37 |
|
38 |
logger = logging.getLogger(__name__)
|
39 |
|
|
|
496 |
check_requirements()
|
497 |
|
498 |
# Resume
|
499 |
+
wandb_run = check_wandb_resume(opt)
|
500 |
if opt.resume and not wandb_run: # resume an interrupted run
|
501 |
ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
|
502 |
assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
|
utils/wandb_logging/wandb_utils.py
CHANGED
@@ -23,7 +23,7 @@ except ImportError:
|
|
23 |
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
|
24 |
|
25 |
|
26 |
-
def remove_prefix(from_string, prefix):
|
27 |
return from_string[len(prefix):]
|
28 |
|
29 |
|
@@ -33,35 +33,73 @@ def check_wandb_config_file(data_config_file):
|
|
33 |
return wandb_config
|
34 |
return data_config_file
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
-
def
|
38 |
-
|
39 |
if isinstance(opt.resume, str):
|
40 |
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
opt.resume = model_artifact_name
|
49 |
-
return run
|
50 |
return None
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
class WandbLogger():
|
54 |
def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
|
55 |
# Pre-training routine --
|
56 |
self.job_type = job_type
|
57 |
self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
self.wandb_run = wandb.init(config=opt,
|
60 |
resume="allow",
|
61 |
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
|
62 |
name=name,
|
63 |
job_type=job_type,
|
64 |
-
id=run_id) if not wandb.run else wandb.run
|
|
|
65 |
if self.job_type == 'Training':
|
66 |
if not opt.resume:
|
67 |
wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
|
|
|
23 |
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
|
24 |
|
25 |
|
26 |
+
def remove_prefix(from_string, prefix=WANDB_ARTIFACT_PREFIX):
    """Return *from_string* with its first ``len(prefix)`` characters dropped.

    The slice is positional: the function does not verify that *from_string*
    actually begins with *prefix*, so callers check ``startswith`` first.
    """
    offset = len(prefix)
    return from_string[offset:]
|
28 |
|
29 |
|
|
|
33 |
return wandb_config
|
34 |
return data_config_file
|
35 |
|
36 |
+
def get_run_info(run_path):
    """Split a ``wandb-artifact://`` run link into its components.

    Returns a ``(run_id, project, model_artifact_name)`` tuple, where the run id
    is the last path component, the project is the component before it, and the
    model artifact name is ``'run_<run_id>_model'``.
    """
    location = Path(remove_prefix(run_path, WANDB_ARTIFACT_PREFIX))
    # NOTE(review): .stem drops everything after the last dot, so a run id or
    # project name containing '.' would be truncated - confirm they never do.
    run_id, project = location.stem, location.parent.stem
    return run_id, project, f'run_{run_id}_model'
|
42 |
|
43 |
+
def check_wandb_resume(opt):
    """Prepare *opt* for resuming from a ``wandb-artifact://`` run link.

    On processes whose ``opt.global_rank`` is neither -1 nor 0, the data config
    is first localised via ``process_wandb_config_ddp_mode``. If ``opt.resume``
    is a string starting with WANDB_ARTIFACT_PREFIX, those same processes also
    download the run's latest model artifact and point ``opt.weights`` at its
    ``last.pt``.

    Returns True when ``opt.resume`` is a W&B artifact link, otherwise None.
    """
    is_worker = opt.global_rank not in [-1, 0]
    if is_worker:
        process_wandb_config_ddp_mode(opt)

    resume = opt.resume
    if not isinstance(resume, str) or not resume.startswith(WANDB_ARTIFACT_PREFIX):
        return None

    if is_worker:  # For resuming DDP runs
        run_id, project, model_artifact_name = get_run_info(resume)
        artifact = wandb.Api().artifact(project + '/' + model_artifact_name + ':latest')
        modeldir = artifact.download()
        opt.weights = str(Path(modeldir) / "last.pt")
    return True
|
55 |
|
56 |
+
def process_wandb_config_ddp_mode(opt):
    """Replace W&B artifact references in the data config with local paths.

    Loads the yaml file at ``opt.data``. For each of the ``'train'`` and
    ``'val'`` entries whose value starts with WANDB_ARTIFACT_PREFIX, the
    artifact (at alias ``opt.artifact_alias``) is downloaded and the entry is
    rewritten to ``<download dir>/data/images``. When at least one artifact was
    downloaded, the updated config is written to ``wandb_local_data.yaml``
    inside a download directory and ``opt.data`` is pointed at that file.

    The file is written next to the val download when there is one, otherwise
    next to the train download. Previously the val directory was used
    unconditionally, which raised TypeError when only ``'train'`` referenced
    an artifact.
    """
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.SafeLoader)  # data dict

    api = None
    local_dirs = {}
    for split in ('train', 'val'):
        source = data_dict[split]
        if not source.startswith(WANDB_ARTIFACT_PREFIX):
            continue
        if api is None:  # only touch the W&B API when something must be fetched
            api = wandb.Api()
        artifact = api.artifact(remove_prefix(source, WANDB_ARTIFACT_PREFIX) + ':' + opt.artifact_alias)
        local_dirs[split] = artifact.download()
        data_dict[split] = str(Path(local_dirs[split]) / 'data/images/')

    # Prefer the val directory (original location), fall back to train.
    target_dir = local_dirs.get('val') or local_dirs.get('train')
    if target_dir:
        ddp_data_path = str(Path(target_dir) / 'wandb_local_data.yaml')
        with open(ddp_data_path, 'w') as f:
            yaml.dump(data_dict, f)
        opt.data = ddp_data_path
|
78 |
+
|
79 |
+
|
80 |
|
81 |
class WandbLogger():
|
82 |
def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
|
83 |
# Pre-training routine --
|
84 |
self.job_type = job_type
|
85 |
self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
|
86 |
+
# It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
|
87 |
+
if isinstance(opt.resume, str): # checks resume from artifact
|
88 |
+
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
|
89 |
+
run_id, project, model_artifact_name = get_run_info(opt.resume)
|
90 |
+
model_artifact_name = WANDB_ARTIFACT_PREFIX + model_artifact_name
|
91 |
+
assert wandb, 'install wandb to resume wandb runs'
|
92 |
+
# Resume wandb-artifact:// runs here| workaround for not overwriting wandb.config
|
93 |
+
self.wandb_run = wandb.init(id=run_id, project=project, resume='allow')
|
94 |
+
opt.resume = model_artifact_name
|
95 |
+
elif self.wandb:
|
96 |
self.wandb_run = wandb.init(config=opt,
|
97 |
resume="allow",
|
98 |
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
|
99 |
name=name,
|
100 |
job_type=job_type,
|
101 |
+
id=run_id) if not wandb.run else wandb.run
|
102 |
+
if self.wandb_run:
|
103 |
if self.job_type == 'Training':
|
104 |
if not opt.resume:
|
105 |
wandb_data_dict = self.check_and_upload_dataset(opt) if opt.upload_dataset else data_dict
|