Songyou commited on
Commit
d4c9e57
·
verified ·
1 Parent(s): f93cc1d

Update models/transformer/encode_decode/model.py

Browse files
models/transformer/encode_decode/model.py CHANGED
@@ -63,7 +63,7 @@ class EncoderDecoder(nn.Module):
63
  @classmethod
64
  def load_from_file(cls, file_path):
65
  # Load model
66
- checkpoint = torch.load(file_path, map_location='cuda:0')
67
  para_dict = checkpoint['model_parameters']
68
  vocab_size = para_dict['vocab_size']
69
  model = EncoderDecoder.make_model(vocab_size, vocab_size, para_dict['N'],
 
63
  @classmethod
64
  def load_from_file(cls, file_path):
65
  # Load model
66
+ checkpoint = torch.load(file_path, map_location='cpu')
67
  para_dict = checkpoint['model_parameters']
68
  vocab_size = para_dict['vocab_size']
69
  model = EncoderDecoder.make_model(vocab_size, vocab_size, para_dict['N'],