fancyfeast
commited on
Commit
•
df9e86f
1
Parent(s):
6982e15
Bugfix
Browse files
Models.py
CHANGED
@@ -134,13 +134,13 @@ class VisionModel(nn.Module):
|
|
134 |
from safetensors.torch import load_file
|
135 |
resume = load_file(Path(path) / 'model.safetensors', device='cpu')
|
136 |
else:
|
137 |
-
resume = torch.load(Path(path) / 'model.pt', map_location=torch.device('cpu'))
|
138 |
|
139 |
model_classes = VisionModel.__subclasses__()
|
140 |
model_cls = next(cls for cls in model_classes if cls.__name__ == config['class'])
|
141 |
|
142 |
model = model_cls(**{k: v for k, v in config.items() if k != 'class'})
|
143 |
-
model.load(resume
|
144 |
if device is not None:
|
145 |
model = model.to(device)
|
146 |
|
|
|
134 |
from safetensors.torch import load_file
|
135 |
resume = load_file(Path(path) / 'model.safetensors', device='cpu')
|
136 |
else:
|
137 |
+
resume = torch.load(Path(path) / 'model.pt', map_location=torch.device('cpu'))['model']
|
138 |
|
139 |
model_classes = VisionModel.__subclasses__()
|
140 |
model_cls = next(cls for cls in model_classes if cls.__name__ == config['class'])
|
141 |
|
142 |
model = model_cls(**{k: v for k, v in config.items() if k != 'class'})
|
143 |
+
model.load(resume)
|
144 |
if device is not None:
|
145 |
model = model.to(device)
|
146 |
|