Update model_helper.py
Browse files- model_helper.py +1 -1
model_helper.py
CHANGED
@@ -110,7 +110,7 @@ def load_model_checkpoint(args=None):
|
|
110 |
eval_subtask_key=args.eval_subtask_key,
|
111 |
write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
|
112 |
).to(device)
|
113 |
-
checkpoint = torch.load(dir_info["last_ckpt_path"])
|
114 |
state_dict = checkpoint['state_dict']
|
115 |
new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
|
116 |
model.load_state_dict(new_state_dict, strict=False)
|
|
|
110 |
eval_subtask_key=args.eval_subtask_key,
|
111 |
write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None
|
112 |
).to(device)
|
113 |
+
checkpoint = torch.load(dir_info["last_ckpt_path"], map_location=device)
|
114 |
state_dict = checkpoint['state_dict']
|
115 |
new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k}
|
116 |
model.load_state_dict(new_state_dict, strict=False)
|