Update model/utils.py
Browse files- model/utils.py +1 -2
model/utils.py
CHANGED
@@ -570,8 +570,7 @@ def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
|
|
570 |
|
571 |
checkpoint = load_file(ckpt_path)
|
572 |
else:
|
573 |
-
checkpoint = torch.load(ckpt_path, weights_only=True)
|
574 |
-
|
575 |
if use_ema:
|
576 |
if ckpt_type == "safetensors":
|
577 |
checkpoint = {"ema_model_state_dict": checkpoint}
|
|
|
570 |
|
571 |
checkpoint = load_file(ckpt_path)
|
572 |
else:
|
573 |
+
checkpoint = torch.load(ckpt_path, weights_only=True, map_location=torch.device('cpu'))
|
|
|
574 |
if use_ema:
|
575 |
if ckpt_type == "safetensors":
|
576 |
checkpoint = {"ema_model_state_dict": checkpoint}
|