SeanWei2 commited on
Commit
ab03c65
1 Parent(s): 787c919

Update models/experimental.py

Browse files
Files changed (1) hide show
  1. models/experimental.py +7 -6
models/experimental.py CHANGED
@@ -251,11 +251,12 @@ def attempt_load(weights, map_location=None):
251
  attempt_download(w)
252
  ckpt = torch.load(w, map_location=map_location) # load
253
 
254
- # 处理没有 'ema' 键的情况
255
- state_dict = ckpt['model'] if 'model' in ckpt else ckpt
256
-
257
- # 将 state_dict 作为模型参数
258
- model.append(state_dict.float().eval()) # 不使用 .fuse()
 
259
 
260
  # Compatibility updates
261
  for m in model.modules():
@@ -272,4 +273,4 @@ def attempt_load(weights, map_location=None):
272
  print('Ensemble created with %s\n' % weights)
273
  for k in ['names', 'stride']:
274
  setattr(model, k, getattr(model[-1], k))
275
- return model # return ensemble
 
251
  attempt_download(w)
252
  ckpt = torch.load(w, map_location=map_location) # load
253
 
254
+ if isinstance(ckpt, dict):
255
+ state_dict = ckpt['model'] if 'model' in ckpt else ckpt
256
+ model.append(state_dict.float().eval()) # 不使用 .fuse()
257
+ else:
258
+ # 如果 ckpt 是 GANLearner 实例,直接添加到模型中
259
+ model.append(ckpt)
260
 
261
  # Compatibility updates
262
  for m in model.modules():
 
273
  print('Ensemble created with %s\n' % weights)
274
  for k in ['names', 'stride']:
275
  setattr(model, k, getattr(model[-1], k))
276
+ return model # return ensemble