shethjenil commited on
Commit
5064db7
·
verified ·
1 Parent(s): feac94a

Update musc/model.py

Browse files
Files changed (1) hide show
  1. musc/model.py +2 -1
musc/model.py CHANGED
@@ -238,8 +238,9 @@ class PretrainedModel(FourHeads):
238
  self.download_weights(instrument)
239
  package_dir = os.path.dirname(os.path.realpath(__file__))
240
  filename = "{}_model.pt".format(instrument)
241
- self.load_state_dict(torch.load(os.path.join(package_dir, filename)))
242
 
 
243
  def download_weights(self, instrument):
244
  weight_file = "{}_model.pt".format(instrument)
245
  package_dir = os.path.dirname(os.path.realpath(__file__))
 
238
  self.download_weights(instrument)
239
  package_dir = os.path.dirname(os.path.realpath(__file__))
240
  filename = "{}_model.pt".format(instrument)
241
+ self.load_state_dict(torch.load(os.path.join(package_dir, filename), map_location=torch.device('cpu')))
242
 
243
+
244
  def download_weights(self, instrument):
245
  weight_file = "{}_model.pt".format(instrument)
246
  package_dir = os.path.dirname(os.path.realpath(__file__))