Refactor optimizer initialization (#8607)
* Refactor optimizer initialization
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update train.py
* Update train.py
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- train.py +4 -25
- utils/torch_utils.py +31 -1
train.py
CHANGED
@@ -28,7 +28,7 @@ import torch.distributed as dist
 import torch.nn as nn
 import yaml
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD, Adam, AdamW, lr_scheduler
+from torch.optim import lr_scheduler
 from tqdm import tqdm

 FILE = Path(__file__).resolve()
@@ -54,7 +54,8 @@ from utils.loggers.wandb.wandb_utils import check_wandb_resume
 from utils.loss import ComputeLoss
 from utils.metrics import fitness
 from utils.plots import plot_evolve, plot_labels
-from utils.torch_utils import EarlyStopping, ModelEMA, de_parallel, select_device, torch_distributed_zero_first
+from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_optimizer,
+                               torch_distributed_zero_first)

 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
@@ -149,29 +150,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
     accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
     LOGGER.info(f"Scaled weight_decay = {hyp['weight_decay']}")
-
-    g = [], [], []  # optimizer parameter groups
-    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
-    for v in model.modules():
-        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias
-            g[2].append(v.bias)
-        if isinstance(v, bn):  # weight (no decay)
-            g[1].append(v.weight)
-        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
-            g[0].append(v.weight)
-
-    if opt.optimizer == 'Adam':
-        optimizer = Adam(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
-    elif opt.optimizer == 'AdamW':
-        optimizer = AdamW(g[2], lr=hyp['lr0'], betas=(hyp['momentum'], 0.999), weight_decay=0.0)
-    else:
-        optimizer = SGD(g[2], lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
-
-    optimizer.add_param_group({'params': g[0], 'weight_decay': hyp['weight_decay']})  # add g0 with weight_decay
-    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
-    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
-                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
-    del g
+    optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])

     # Scheduler
     if opt.cos_lr:
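
For context (not part of the diff): the per-group setup that used to live in train() is now a single call, with the relevant hyperparameters passed explicitly. A minimal annotated sketch of the new call site (annotations are ours; argument order follows the smart_optimizer signature added in utils/torch_utils.py below):

optimizer = smart_optimizer(model,                # model assembled earlier in train()
                            opt.optimizer,        # value of the --optimizer flag, e.g. 'SGD', 'Adam', 'AdamW'
                            hyp['lr0'],           # initial learning rate
                            hyp['momentum'],      # SGD momentum / Adam-style beta1
                            hyp['weight_decay'])  # applied only to the decayed-weight group (g[0])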
utils/torch_utils.py
CHANGED
@@ -18,7 +18,7 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

-from utils.general import LOGGER, file_date, git_describe
+from utils.general import LOGGER, colorstr, file_date, git_describe

 try:
     import thop  # for FLOPs computation
@@ -260,6 +260,36 @@ def copy_attr(a, b, include=(), exclude=()):
             setattr(a, k, v)


+def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-5):
+    # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
+    g = [], [], []  # optimizer parameter groups
+    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+    for v in model.modules():
+        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
+            g[2].append(v.bias)
+        if isinstance(v, bn):  # weight (no decay)
+            g[1].append(v.weight)
+        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
+            g[0].append(v.weight)
+
+    if name == 'Adam':
+        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
+    elif name == 'AdamW':
+        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+    elif name == 'RMSProp':
+        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
+    elif name == 'SGD':
+        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+    else:
+        raise NotImplementedError(f'Optimizer {name} not implemented.')
+
+    optimizer.add_param_group({'params': g[0], 'weight_decay': weight_decay})  # add g0 with weight_decay
+    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
+    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__} with parameter groups "
+                f"{len(g[1])} weight (no decay), {len(g[0])} weight, {len(g[2])} bias")
+    return optimizer
+
+
 class EarlyStopping:
     # YOLOv5 simple early stopper
     def __init__(self, patience=30):
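
For reference, a small standalone sketch (our illustration, not part of the PR) of how smart_optimizer buckets a toy model's parameters into the three groups, assuming it is run from a YOLOv5 checkout that contains this change:

import torch.nn as nn

from utils.torch_utils import smart_optimizer

# Toy model: a bias-free conv, a BatchNorm layer, and a conv with bias
model = nn.Sequential(nn.Conv2d(3, 16, 3, bias=False),  # weight -> g[0] (decayed)
                      nn.BatchNorm2d(16),               # weight -> g[1], bias -> g[2] (no decay)
                      nn.Conv2d(16, 8, 3))              # weight -> g[0], bias -> g[2]

optimizer = smart_optimizer(model, 'SGD', lr=0.01, momentum=0.937, weight_decay=5e-4)

# Group 0 holds the biases passed to the SGD constructor; groups 1 and 2 were added
# via add_param_group() for decayed weights and norm weights respectively
assert len(optimizer.param_groups) == 3
assert optimizer.param_groups[1]['weight_decay'] == 5e-4  # conv weights (decayed)
assert optimizer.param_groups[2]['weight_decay'] == 0.0   # BatchNorm weights (no decay)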