Yw22 commited on
Commit
b5bc9f7
1 Parent(s): be61342
Files changed (1) hide show
  1. utils/utils.py +1 -1
utils/utils.py CHANGED
@@ -229,7 +229,7 @@ def load_checkpoint(model_file, model):
229
  def load_model(model, model_path):
230
  if model_path != "":
231
  print(f"init model from checkpoint: {model_path}")
232
- model_ckpt = torch.load(model_path, map_location="cuda")
233
  if "global_step" in model_ckpt: print(f"global_step: {model_ckpt['global_step']}")
234
  state_dict = model_ckpt["state_dict"] if "state_dict" in model_ckpt else model_ckpt
235
  m, u = model.load_state_dict(state_dict, strict=False)
 
229
  def load_model(model, model_path):
230
  if model_path != "":
231
  print(f"init model from checkpoint: {model_path}")
232
+ model_ckpt = torch.load(model_path, map_location="cpu")
233
  if "global_step" in model_ckpt: print(f"global_step: {model_ckpt['global_step']}")
234
  state_dict = model_ckpt["state_dict"] if "state_dict" in model_ckpt else model_ckpt
235
  m, u = model.load_state_dict(state_dict, strict=False)