glenn-jocher commited on
Commit
14523bb
1 Parent(s): c5966ab

FP16 to FP32 ckpt load

Browse files
Files changed (1) hide show
  1. train.py +2 -2
train.py CHANGED
@@ -112,8 +112,8 @@ def train(hyp):
112
 
113
  # load model
114
  try:
115
- ckpt['model'] = \
116
- {k: v for k, v in ckpt['model'].state_dict().items() if model.state_dict()[k].numel() == v.numel()}
117
  model.load_state_dict(ckpt['model'], strict=False)
118
  except KeyError as e:
119
  s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
 
112
 
113
  # load model
114
  try:
115
+ ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
116
+ if model.state_dict()[k].shape == v.shape} # to FP32, filter
117
  model.load_state_dict(ckpt['model'], strict=False)
118
  except KeyError as e:
119
  s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \