osimeoni commited on
Commit
1d605a5
1 Parent(s): 5c00c7e
Files changed (1) hide show
  1. model.py +2 -1
model.py CHANGED
@@ -156,7 +156,8 @@ class FoundModel(nn.Module):
156
  def decoder_load_weights(self, weights_path):
157
  print(f"Loading model from weights {weights_path}.")
158
  # Load states
159
- state_dict = torch.load(weights_path)
 
160
 
161
  # Decoder
162
  self.decoder.load_state_dict(state_dict["decoder"])
 
156
  def decoder_load_weights(self, weights_path):
157
  print(f"Loading model from weights {weights_path}.")
158
  # Load states
159
+ map_location=torch.device('cpu')
160
+ state_dict = torch.load(weights_path, map_location=map_location)
161
 
162
  # Decoder
163
  self.decoder.load_state_dict(state_dict["decoder"])