glenn-jocher and pre-commit-ci[bot] committed
Commit 51fb467 · unverified · 1 parent: 2430578

Refactor optimizer initialization (#8607)

* Refactor optimizer initialization

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

* Update train.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (2)
  1. train.py +4 -25
  2. utils/torch_utils.py +31 -1
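
This commit factors YOLOv5's three-way parameter grouping (weights with decay, normalization weights without decay, biases without decay) out of train.py into a reusable smart_optimizer() helper in utils/torch_utils.py. For context, here is a minimal sketch of the underlying per-group weight-decay pattern in plain PyTorch; the toy model and values are illustrative only and are not code from this repository:

# Minimal sketch of per-group weight decay in plain PyTorch (toy model, illustrative only)
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(), nn.Conv2d(8, 4, 3))

decay, no_decay = [], []
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):  # normalization params are excluded from decay
        no_decay += list(m.parameters(recurse=False))
    elif hasattr(m, 'weight') and isinstance(m.weight, nn.Parameter):
        decay.append(m.weight)  # conv/linear weights get weight decay
        if m.bias is not None:
            no_decay.append(m.bias)  # biases do not

optimizer = torch.optim.SGD(no_decay, lr=0.01, momentum=0.9, nesterov=True)  # group 0: no decay
optimizer.add_param_group({'params': decay, 'weight_decay': 1e-5})  # group 1: decayed weights

smart_optimizer() below applies the same idea, but detects every torch.nn normalization class by name and selects the optimizer class from a string argument.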
train.py CHANGED
@@ -28,7 +28,7 @@ import torch.distributed as dist
 import torch.nn as nn
 import yaml
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD, Adam, AdamW, lr_scheduler
+from torch.optim import lr_scheduler
 from tqdm import tqdm

 FILE = Path(__file__).resolve()
@@ -54,7 +54,8 @@ from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.loss import ComputeLoss
 from utils.metrics import fitness
 from utils.plots import plot_evolve, plot_labels
-from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
+from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_optimizer,
+                               torch_distributed_zero_first)

 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
@@ -149,29 +150,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
-
-    g = [], [], []  # optimizer parameter groups
-    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
-    for v in model.modules():
-        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
-            g[2].append(v.bias)
-        if isinstance(v, bn):  # weight (no decay)
-            g[1].append(v.weight)
-        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
-            g[0].append(v.weight)
-
-    if opt.optimizer == 'Adam':
-        optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
-    elif opt.optimizer == 'AdamW':
-        optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999), weight_decay=0.0)
-    else:
-        optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
-
-    optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']})  # add g0 with weight_decay
-    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
-    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
-                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
-    del g
+    optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])

     # Scheduler
     if opt.cos_lr:
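
A detail of the grouping logic, unchanged by the move into smart_optimizer(): normalization layers are found by scanning torch.nn for any class whose name contains 'Norm', not only BatchNorm2d. A quick illustrative snippet (not from the repo; output depends on the installed torch version):

# Illustration only: which torch.nn entries the `'Norm' in k` filter in smart_optimizer() matches
import torch.nn as nn

norm_names = sorted(k for k in nn.__dict__ if 'Norm' in k)
print(norm_names)
# Typically: BatchNorm1d/2d/3d, SyncBatchNorm, InstanceNorm1d/2d/3d, LayerNorm,
# GroupNorm, LocalResponseNorm (exact list depends on the torch version).
# The weights of all of these land in the no-decay group.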
utils/torch_utils.py CHANGED
@@ -18,7 +18,7 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

-from utils.general import LOGGER, file_date, git_describe
+from utils.general import LOGGER, colorstr, file_date, git_describe

 try:
     import thop  # for FLOPs computation
@@ -260,6 +260,36 @@ def copy_attr(a, b, include=(), exclude=()):
             setattr(a, k, v)


+def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-5):
+    # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
+    g = [], [], []  # optimizer parameter groups
+    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+    for v in model.modules():
+        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
+            g[2].append(v.bias)
+        if isinstance(v, bn):  # weight (no decay)
+            g[1].append(v.weight)
+        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
+            g[0].append(v.weight)
+
+    if name == 'Adam':
+        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
+    elif name == 'AdamW':
+        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+    elif name == 'RMSProp':
+        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
+    elif name == 'SGD':
+        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+    else:
+        raise NotImplementedError(f'Optimizer {name} not implemented.')
+
+    optimizer.add_param_group({'params': g[0], 'weight_decay': weight_decay})  # add g0 with weight_decay
+    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
+    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
+                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
+    return optimizer
+
+
 class EarlyStopping:
     # YOLOv5 simple early stopper
     def __init__(self, patience=30):
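
As a rough usage sketch of the new helper (it assumes a YOLOv5 checkout on the import path; the toy model and hyperparameter values are placeholders, not taken from this commit):

# Illustrative use of smart_optimizer() on a toy model (values are placeholders)
import torch.nn as nn
from utils.torch_utils import smart_optimizer

toy = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.SiLU(), nn.Conv2d(16, 32, 3))
optimizer = smart_optimizer(toy, name='SGD', lr=0.01, momentum=0.937, weight_decay=5e-4)

# Param groups as constructed above: [biases (no decay), weights (decay), norm weights (no decay)]
for i, pg in enumerate(optimizer.param_groups):
    print(i, len(pg['params']), pg['weight_decay'])

In train.py the same helper is now invoked once as smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay']), replacing the removed inline block.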