glenn-jocher commited on
Commit
5e970d4
1 Parent(s): 999804f

Update train.py (#462)

Browse files
Files changed (1) hide show
  1. train.py +4 -1
train.py CHANGED
@@ -123,9 +123,12 @@ def train(hyp, tb_writer, opt, device):
123
 
124
  # load model
125
  try:
 
126
  ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
127
- if k in model.state_dict() and model.state_dict()[k].shape == v.shape}
 
128
  model.load_state_dict(ckpt['model'], strict=False)
 
129
  except KeyError as e:
130
  s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
131
  "Please delete or update %s and try again, or use --weights '' to train from scratch." \
 
123
 
124
  # load model
125
  try:
126
+ exclude = ['anchor'] # exclude keys
127
  ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
128
+ if k in model.state_dict() and not any(x in k for x in exclude)
129
+ and model.state_dict()[k].shape == v.shape}
130
  model.load_state_dict(ckpt['model'], strict=False)
131
+ print('Transferred %g/%g items from %s' % (len(ckpt['model']), len(model.state_dict()), weights))
132
  except KeyError as e:
133
  s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
134
  "Please delete or update %s and try again, or use --weights '' to train from scratch." \