TensorBoard DP/DDP graph fix (#3325)
- train.py (+3 −3)
- utils/torch_utils.py (+6 −0)
train.py
CHANGED
@@ -32,7 +32,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
-from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first,
+from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
 from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume

 logger = logging.getLogger(__name__)
@@ -331,7 +331,7 @@ def train(hyp, opt, device, tb_writer=None):
 f = save_dir / f'train_batch{ni}.jpg'  # filename
 Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
 if tb_writer:
-    tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), [])  #
+    tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), [])  # model graph
     # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
 elif plots and ni == 10 and wandb_logger.wandb:
     wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
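Why the de_parallel() call matters here: under DP/DDP, `model` is the wrapper, so torch.jit.trace records the wrapper's forward rather than the underlying module, which is what broke the TensorBoard graph for multi-GPU runs. Below is a minimal sketch of the same pattern with a stand-in nn.Sequential model and a dummy batch (the real `model`/`imgs` come from train.py); the helpers are copied from the utils/torch_utils.py change further down.

import torch
import torch.nn as nn


def is_parallel(model):
    # Copied from the utils/torch_utils.py change in this PR
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def de_parallel(model):
    # Returns the underlying module for DP/DDP, the model itself otherwise
    return model.module if is_parallel(model) else model


net = nn.DataParallel(nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()))  # stand-in for the wrapped YOLO model
imgs = torch.zeros(1, 3, 64, 64)  # stand-in for a training batch

# Tracing the unwrapped module gives a clean graph of the model itself
traced = torch.jit.trace(de_parallel(net), imgs, strict=False)

# train.py then logs it to TensorBoard (requires tensorboard installed):
#   tb_writer.add_graph(traced, [])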
@@ -390,7 +390,7 @@ def train(hyp, opt, device, tb_writer=None):
 ckpt = {'epoch': epoch,
         'best_fitness': best_fitness,
         'training_results': results_file.read_text(),
-        'model': deepcopy(
+        'model': deepcopy(de_parallel(model)).half(),
         'ema': deepcopy(ema.ema).half(),
         'updates': ema.updates,
         'optimizer': optimizer.state_dict(),
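On the checkpoint side, DP/DDP prefix every parameter key with "module.", so a checkpoint saved from the wrapper cannot be loaded into a plain model without renaming keys. Saving deepcopy(de_parallel(model)).half() keeps checkpoints identical across single- and multi-GPU runs. A small sketch with a stand-in model and a hypothetical filename:

from copy import deepcopy

import torch
import torch.nn as nn


def is_parallel(model):
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def de_parallel(model):
    return model.module if is_parallel(model) else model


model = nn.DataParallel(nn.Linear(4, 2))          # stand-in for a DP-wrapped model
print(list(model.state_dict())[:1])               # ['module.weight'] - wrapper-prefixed keys
print(list(de_parallel(model).state_dict())[:1])  # ['weight'] - plain keys, loadable anywhere

ckpt = {'model': deepcopy(de_parallel(model)).half()}  # mirrors the train.py change above
torch.save(ckpt, 'ckpt_demo.pt')                       # hypothetical path for this sketch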
utils/torch_utils.py
CHANGED
@@ -134,9 +134,15 @@ def profile(x, ops, n=100, device=None):
 
 
 def is_parallel(model):
+    # Returns True if model is of type DP or DDP
     return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
 
 
+def de_parallel(model):
+    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+    return model.module if is_parallel(model) else model
+
+
 def intersect_dicts(da, db, exclude=()):
     # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
     return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
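Note that de_parallel() is a no-op for an unwrapped model, so call sites such as train.py can apply it unconditionally without an is_parallel() check of their own. A self-contained check (helpers copied from the diff, stand-in nn.Linear model):

import torch.nn as nn


def is_parallel(model):
    # Returns True if model is of type DP or DDP
    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)


def de_parallel(model):
    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
    return model.module if is_parallel(model) else model


plain = nn.Linear(4, 2)
wrapped = nn.DataParallel(plain)

assert not is_parallel(plain) and is_parallel(wrapped)
assert de_parallel(plain) is plain    # unchanged for single-GPU/CPU models
assert de_parallel(wrapped) is plain  # unwrapped for DP (and likewise DDP) models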