glenn-jocher commited on
Commit
e931b9d
1 Parent(s): a3ecf0f

Resume with custom anchors fix (#2361)

Browse files

* Resume with custom anchors fix

* Update train.py

Files changed (1) hide show
  1. train.py +3 -4
train.py CHANGED
@@ -75,10 +75,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
75
  with torch_distributed_zero_first(rank):
76
  attempt_download(weights) # download if not found locally
77
  ckpt = torch.load(weights, map_location=device) # load checkpoint
78
- if hyp.get('anchors'):
79
- ckpt['model'].yaml['anchors'] = round(hyp['anchors']) # force autoanchor
80
- model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device) # create
81
- exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else [] # exclude keys
82
  state_dict = ckpt['model'].float().state_dict() # to FP32
83
  state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
84
  model.load_state_dict(state_dict, strict=False) # load
@@ -216,6 +214,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
216
  # Anchors
217
  if not opt.noautoanchor:
218
  check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
219
 
220
  # Model parameters
221
  hyp['box'] *= 3. / nl # scale to layers
 
75
  with torch_distributed_zero_first(rank):
76
  attempt_download(weights) # download if not found locally
77
  ckpt = torch.load(weights, map_location=device) # load checkpoint
78
+ model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
79
+ exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else [] # exclude keys
 
 
80
  state_dict = ckpt['model'].float().state_dict() # to FP32
81
  state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude) # intersect
82
  model.load_state_dict(state_dict, strict=False) # load
 
214
  # Anchors
215
  if not opt.noautoanchor:
216
  check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
217
+ model.half().float() # pre-reduce anchor precision
218
 
219
  # Model parameters
220
  hyp['box'] *= 3. / nl # scale to layers