Gregniuki commited on
Commit
c0ed55a
1 Parent(s): b38f0ce

Update model/utils.py

Browse files
Files changed (1) hide show
  1. model/utils.py +1 -2
model/utils.py CHANGED
@@ -570,8 +570,7 @@ def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
570
 
571
  checkpoint = load_file(ckpt_path)
572
  else:
573
- checkpoint = torch.load(ckpt_path, weights_only=True)
574
-
575
  if use_ema:
576
  if ckpt_type == "safetensors":
577
  checkpoint = {"ema_model_state_dict": checkpoint}
 
570
 
571
  checkpoint = load_file(ckpt_path)
572
  else:
573
+ checkpoint = torch.load(ckpt_path, weights_only=True, map_location=torch.device('cpu'))
 
574
  if use_ema:
575
  if ckpt_type == "safetensors":
576
  checkpoint = {"ema_model_state_dict": checkpoint}