bilzard committed
Commit e1dc894
1 Parent(s): d95978a

Enable AdamW optimizer (#6152)
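Replaces the boolean --adam flag with a string-valued --optimizer argument (choices: SGD, Adam, AdamW; default SGD) and dispatches on it in train(), so torch.optim.AdamW becomes selectable from the command line.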

Files changed (1): train.py (+5 -3)
train.py CHANGED
@@ -22,7 +22,7 @@ import torch.nn as nn
 import yaml
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import SGD, Adam, lr_scheduler
+from torch.optim import SGD, Adam, AdamW, lr_scheduler
 from tqdm import tqdm
 
 FILE = Path(__file__).resolve()
@@ -155,8 +155,10 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
             g1.append(v.weight)
 
-    if opt.adam:
+    if opt.optimizer == 'Adam':
         optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
+    elif opt.optimizer == 'AdamW':
+        optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))  # adjust beta1 to momentum
     else:
         optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
 
@@ -460,7 +462,7 @@ def parse_opt(known=False):
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
     parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
-    parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
+    parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
     parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
     parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
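For readers outside the repo, a minimal, self-contained sketch of the new dispatch path follows. The toy nn.Linear model, the hyp values, and the hard-coded argv are illustrative stand-ins, not part of the patch:

import argparse

import torch.nn as nn
from torch.optim import SGD, Adam, AdamW

# Illustrative stand-in for YOLOv5's hyp dict; the values are hypothetical here.
hyp = {'lr0': 0.01, 'momentum': 0.937}

parser = argparse.ArgumentParser()
parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
opt = parser.parse_args(['--optimizer', 'AdamW'])  # stands in for: python train.py --optimizer AdamW

model = nn.Linear(10, 2)  # toy model in place of the YOLOv5 network
g0 = model.parameters()   # single parameter group; train.py splits g0/g1/g2

# Same dispatch as the patched train(): beta1 doubles as the momentum setting.
if opt.optimizer == 'Adam':
    optimizer = Adam(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))
elif opt.optimizer == 'AdamW':
    optimizer = AdamW(g0, lr=hyp['lr0'], betas=(hyp['momentum'], 0.999))
else:
    optimizer = SGD(g0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)

print(type(optimizer).__name__)  # AdamW

In the surrounding train.py (not shown in this diff), the decay-carrying weight group g1 is then attached via add_param_group; the practical difference of AdamW over Adam is that this weight decay is applied decoupled from the adaptive gradient update.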