glenn-jocher committed on
Commit
1df494e
2 Parent(s): 2268f9c 4e2b9ec

Merge remote-tracking branch 'origin/master'

Browse files
Files changed (1) hide show
  1. train.py +6 -6
train.py CHANGED
@@ -113,6 +113,12 @@ def train(hyp, tb_writer, opt, device):
113
  optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
114
  print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
115
  del pg0, pg1, pg2
 
 
 
 
 
 
116
 
117
  # Load Model
118
  with torch_distributed_zero_first(rank):
@@ -158,12 +164,6 @@ def train(hyp, tb_writer, opt, device):
158
  if mixed_precision:
159
  model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
160
 
161
- # Scheduler https://arxiv.org/pdf/1812.01187.pdf
162
- lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
163
- scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
164
- # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
165
- # plot_lr_scheduler(optimizer, scheduler, epochs)
166
-
167
  # DP mode
168
  if device.type != 'cpu' and rank == -1 and torch.cuda.device_count() > 1:
169
  model = torch.nn.DataParallel(model)
 
113
  optimizer.add_param_group({'params': pg2}) # add pg2 (biases)
114
  print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
115
  del pg0, pg1, pg2
116
+
117
+ # Scheduler https://arxiv.org/pdf/1812.01187.pdf
118
+ lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2 # cosine
119
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
120
+ # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
121
+ # plot_lr_scheduler(optimizer, scheduler, epochs)
122
 
123
  # Load Model
124
  with torch_distributed_zero_first(rank):
 
164
  if mixed_precision:
165
  model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)
166
 
 
 
 
 
 
 
167
  # DP mode
168
  if device.type != 'cpu' and rank == -1 and torch.cuda.device_count() > 1:
169
  model = torch.nn.DataParallel(model)