glenn-jocher
committed on
Merge remote-tracking branch 'origin/master'
train.py
CHANGED
@@ -113,6 +113,12 @@ def train(hyp, tb_writer, opt, device):
     optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
     print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
     del pg0, pg1, pg2
+
+    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
+    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2  # cosine
+    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
+    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
+    # plot_lr_scheduler(optimizer, scheduler, epochs)

     # Load Model
     with torch_distributed_zero_first(rank):
@@ -158,12 +164,6 @@ def train(hyp, tb_writer, opt, device):
     if mixed_precision:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

-    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
-    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2  # cosine
-    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
-    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
-    # plot_lr_scheduler(optimizer, scheduler, epochs)
-
     # DP mode
     if device.type != 'cpu' and rank == -1 and torch.cuda.device_count() > 1:
         model = torch.nn.DataParallel(model)
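The relocated block is the cosine learning-rate schedule (the arXiv 1812.01187 reference in the comment); after this commit the scheduler is built right after the optimizer parameter groups instead of after amp.initialize. Below is a minimal sketch of how that LambdaLR schedule behaves on its own, not taken from train.py: epochs=300, the SGD settings, and the toy nn.Linear model are illustrative assumptions.

import math

import torch
from torch import nn
from torch.optim import SGD, lr_scheduler

epochs = 300                                   # assumed; train.py gets this from opt
model = nn.Linear(10, 1)                       # toy stand-in for the detection model
optimizer = SGD(model.parameters(), lr=0.01)   # assumed initial lr

# Same lambda as in the diff: lf(0) = 1.0 and lf(epochs) = 0.2, so the learning
# rate decays along a cosine from 100% down to 20% of its initial value.
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

for epoch in range(epochs):
    loss = model(torch.randn(4, 10)).sum()     # dummy step in place of a training epoch
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()                           # advance the cosine schedule once per epoch

print(optimizer.param_groups[0]['lr'])         # ~0.002, i.e. 0.2 * the initial 0.01

Creating the scheduler immediately after the optimizer, before any wrappers touch it, also lines up with the linked PyTorch forum thread about problems when resuming an optimizer, though the commit itself does not state its motivation.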