mimbres commited on
Commit
630576a
1 Parent(s): 43abcd8

Update model_helper.py

Browse files
Files changed (1) hide show
  1. model_helper.py +1 -1
model_helper.py CHANGED
@@ -110,7 +110,7 @@ def load_model_checkpoint(args=None):
110
  eval_subtask_key=args.eval_subtask_key,
111
  write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
112
  ).to(device)
113
- checkpoint = torch.load(dir_info["last_ckpt_path"])
114
  state_dict = checkpoint['state_dict']
115
  new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
116
  model.load_state_dict(new_state_dict, strict=False)
 
110
  eval_subtask_key=args.eval_subtask_key,
111
  write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
112
  ).to(device)
113
+ checkpoint = torch.load(dir_info["last_ckpt_path"], map_location=device)
114
  state_dict = checkpoint['state_dict']
115
  new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
116
  model.load_state_dict(new_state_dict, strict=False)