glenn-jocher committed on
Commit
aad99b6
·
unverified ·
1 Parent(s): 407dc50

TensorBoard DP/DDP graph fix (#3325)

Browse files
Files changed (2) hide show
  1. train.py +3 -3
  2. utils/torch_utils.py +6 -0
train.py CHANGED
@@ -32,7 +32,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
32
  from utils.google_utils import attempt_download
33
  from utils.loss import ComputeLoss
34
  from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
35
- from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
36
  from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
37
 
38
  logger = logging.getLogger(__name__)
@@ -331,7 +331,7 @@ def train(hyp, opt, device, tb_writer=None):
331
  f = save_dir / f'train_batch{ni}.jpg' # filename
332
  Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
333
  if tb_writer:
334
- tb_writer.add_graph(torch.jit.trace(model, imgs, strict=False), []) # add model graph
335
  # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
336
  elif plots and ni == 10 and wandb_logger.wandb:
337
  wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
@@ -390,7 +390,7 @@ def train(hyp, opt, device, tb_writer=None):
390
  ckpt = {'epoch': epoch,
391
  'best_fitness': best_fitness,
392
  'training_results': results_file.read_text(),
393
- 'model': deepcopy(model.module if is_parallel(model) else model).half(),
394
  'ema': deepcopy(ema.ema).half(),
395
  'updates': ema.updates,
396
  'optimizer': optimizer.state_dict(),
 
32
  from utils.google_utils import attempt_download
33
  from utils.loss import ComputeLoss
34
  from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
35
+ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
36
  from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
37
 
38
  logger = logging.getLogger(__name__)
 
331
  f = save_dir / f'train_batch{ni}.jpg' # filename
332
  Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
333
  if tb_writer:
334
+ tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # model graph
335
  # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
336
  elif plots and ni == 10 and wandb_logger.wandb:
337
  wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
 
390
  ckpt = {'epoch': epoch,
391
  'best_fitness': best_fitness,
392
  'training_results': results_file.read_text(),
393
+ 'model': deepcopy(de_parallel(model)).half(),
394
  'ema': deepcopy(ema.ema).half(),
395
  'updates': ema.updates,
396
  'optimizer': optimizer.state_dict(),
utils/torch_utils.py CHANGED
@@ -134,9 +134,15 @@ def profile(x, ops, n=100, device=None):
134
 
135
 
136
def is_parallel(model):
    """Return True when `model` is a DP/DDP wrapper (exact type comparison)."""
    # Exact type check — subclasses of DP/DDP would not be matched
    parallel_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return type(model) in parallel_types
138
 
139
 
 
 
 
 
 
140
def intersect_dicts(da, db, exclude=()):
    """Return the entries of `da` whose key also exists in `db` with an identical
    tensor shape, skipping any key containing one of the `exclude` substrings.
    Values always come from `da`."""
    out = {}
    for key, value in da.items():
        # key must exist in db before its shape can be compared
        if key not in db or value.shape != db[key].shape:
            continue
        if any(pattern in key for pattern in exclude):
            continue
        out[key] = value
    return out
 
134
 
135
 
136
def is_parallel(model):
    # Returns True if model is of type DP or DDP (exact type, not isinstance)
    return type(model) in {nn.parallel.DataParallel, nn.parallel.DistributedDataParallel}
139
 
140
 
141
def de_parallel(model):
    """De-parallelize a model: return the wrapped `.module` when `model` is a
    DP/DDP wrapper, otherwise return `model` unchanged."""
    wrapper_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    if type(model) in wrapper_types:
        return model.module
    return model
144
+
145
+
146
def intersect_dicts(da, db, exclude=()):
    """Dictionary intersection of `da` and `db` by key, keeping `da`'s values,
    requiring matching tensor shapes and no `exclude` substring in the key."""
    def _keep(k, v):
        # present in db, not excluded, and shape-compatible
        return k in db and not any(s in k for s in exclude) and v.shape == db[k].shape
    return {k: v for k, v in da.items() if _keep(k, v)}