Rongjiehuang commited on
Commit
a85ad15
1 Parent(s): fc3a39d
Files changed (1) hide show
  1. data_gen/tts/emotion/inference.py +1 -1
data_gen/tts/emotion/inference.py CHANGED
@@ -30,7 +30,7 @@ def load_model(weights_fpath: Path, device=None):
30
  elif isinstance(device, str):
31
  _device = torch.device(device)
32
  _model = EmotionEncoder(_device, torch.device("cpu"))
33
- checkpoint = torch.load(weights_fpath)
34
  _model.load_state_dict(checkpoint["model_state"])
35
  _model.eval()
36
  print("Loaded encoder trained to step %d" % (checkpoint["step"]))
 
30
  elif isinstance(device, str):
31
  _device = torch.device(device)
32
  _model = EmotionEncoder(_device, torch.device("cpu"))
33
+ checkpoint = torch.load(weights_fpath, map_location="cpu")
34
  _model.load_state_dict(checkpoint["model_state"])
35
  _model.eval()
36
  print("Loaded encoder trained to step %d" % (checkpoint["step"]))