Spaces:
Sleeping
Sleeping
Update models/experimental.py
Browse files- models/experimental.py +7 -6
models/experimental.py
CHANGED
@@ -251,11 +251,12 @@ def attempt_load(weights, map_location=None):
|
|
251 |
attempt_download(w)
|
252 |
ckpt = torch.load(w, map_location=map_location) # load
|
253 |
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
|
|
259 |
|
260 |
# Compatibility updates
|
261 |
for m in model.modules():
|
@@ -272,4 +273,4 @@ def attempt_load(weights, map_location=None):
|
|
272 |
print('Ensemble created with %s\n' % weights)
|
273 |
for k in ['names', 'stride']:
|
274 |
setattr(model, k, getattr(model[-1], k))
|
275 |
-
return model # return ensemble
|
|
|
251 |
attempt_download(w)
|
252 |
ckpt = torch.load(w, map_location=map_location) # load
|
253 |
|
254 |
+
if isinstance(ckpt, dict):
|
255 |
+
state_dict = ckpt['model'] if 'model' in ckpt else ckpt
|
256 |
+
model.append(state_dict.float().eval()) # 不使用 .fuse()
|
257 |
+
else:
|
258 |
+
# 如果 ckpt 是 GANLearner 实例,直接添加到模型中
|
259 |
+
model.append(ckpt)
|
260 |
|
261 |
# Compatibility updates
|
262 |
for m in model.modules():
|
|
|
273 |
print('Ensemble created with %s\n' % weights)
|
274 |
for k in ['names', 'stride']:
|
275 |
setattr(model, k, getattr(model[-1], k))
|
276 |
+
return model # return ensemble
|