Janne Hellsten committed
Commit f7e4867
1 Parent(s): d3a616a

Add --allow-tf32 perf tuning argument that can be used to enable tf32


Defaults to keeping tf32 disabled. This is because we haven't fully
verified training results with tf32 enabled.
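For context, here is a minimal stand-alone sketch (not part of this commit; the helper name and matrix size are illustrative) of the numerical effect the flag controls. With TF32 enabled, Ampere-class GPUs execute float32 matmuls with a reduced (~10-bit) mantissa, so results can differ slightly from full fp32, which is why the option stays off by default here:

    import torch

    def compare_tf32_matmul(n=1024):
        # Requires an Ampere-or-newer CUDA GPU; on older GPUs both paths are identical.
        a = torch.randn(n, n, device='cuda')
        b = torch.randn(n, n, device='cuda')

        torch.backends.cuda.matmul.allow_tf32 = False
        ref = a @ b    # full-precision fp32 matmul

        torch.backends.cuda.matmul.allow_tf32 = True
        fast = a @ b   # TF32 tensor-core matmul (reduced mantissa)

        print('max abs difference:', (ref - fast).abs().max().item())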

Files changed (3)
  1. docs/train-help.txt +1 -0
  2. train.py +8 -0
  3. training/training_loop.py +3 -0
docs/train-help.txt CHANGED
@@ -65,5 +65,6 @@ Options:
   --fp32 BOOL Disable mixed-precision training
   --nhwc BOOL Use NHWC memory format with FP16
   --nobench BOOL Disable cuDNN benchmarking
+  --allow-tf32 BOOL Allow PyTorch to use TF32 internally
   --workers INT Override number of DataLoader workers
   --help Show this message and exit.
train.py CHANGED
@@ -61,6 +61,7 @@ def setup_training_loop_kwargs(
     # Performance options (not included in desc).
     fp32       = None, # Disable mixed-precision training: <bool>, default = False
     nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False
+    allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
     nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False
     workers    = None, # Override number of DataLoader workers: <int>, default = 3
 ):
@@ -343,6 +344,12 @@ def setup_training_loop_kwargs(
     if nobench:
         args.cudnn_benchmark = False
 
+    if allow_tf32 is None:
+        allow_tf32 = False
+    assert isinstance(allow_tf32, bool)
+    if allow_tf32:
+        args.allow_tf32 = True
+
     if workers is not None:
         assert isinstance(workers, int)
         if not workers >= 1:
@@ -425,6 +432,7 @@ class CommaSeparatedList(click.ParamType):
 @click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
 @click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
 @click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
+@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
 @click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')
 
 def main(ctx, outdir, dry_run, **config_kwargs):
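As a usage note: because the new option is declared with type=bool, click parses textual booleans such as true/false or 1/0 from the command line, and leaving the flag out yields None, which setup_training_loop_kwargs then treats as False. A minimal, hypothetical sketch of that parsing behaviour (the demo command below is not part of the repository):

    import click

    @click.command()
    @click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
    def demo(allow_tf32):
        # e.g. `--allow-tf32=true` or `--allow-tf32=1` -> True; omitting the flag -> None.
        click.echo(f'allow_tf32 = {allow_tf32}')

    if __name__ == '__main__':
        demo()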
training/training_loop.py CHANGED
@@ -115,6 +115,7 @@ def training_loop(
     network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
     resume_pkl              = None,     # Network pickle to resume training from.
     cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
+    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
     abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
     progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
 ):
@@ -124,6 +125,8 @@ def training_loop(
     np.random.seed(random_seed * num_gpus + rank)
     torch.manual_seed(random_seed * num_gpus + rank)
     torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
+    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
+    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions
     conv2d_gradfix.enabled = True                       # Improves training speed.
     grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.
 
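For completeness, a small sanity-check sketch (an assumed helper, not from the repository) that prints whether the two backend switches ended up in the state training_loop() was asked for:

    import torch

    def report_tf32_state():
        # Both switches are plain booleans on torch.backends and should mirror allow_tf32.
        print('torch.backends.cuda.matmul.allow_tf32 =', torch.backends.cuda.matmul.allow_tf32)
        print('torch.backends.cudnn.allow_tf32       =', torch.backends.cudnn.allow_tf32)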