mjdolan commited on
Commit
6957cac
1 Parent(s): 4a57b56

Update e4e/models/psp.py

Browse files
Files changed (1) hide show
  1. e4e/models/psp.py +3 -1
e4e/models/psp.py CHANGED
@@ -40,9 +40,11 @@ class pSp(nn.Module):
40
  def load_weights(self):
41
  if self.opts.checkpoint_path is not None:
42
  print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
43
- ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
44
  self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
 
45
  self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
 
46
  self.__load_latent_avg(ckpt)
47
  else:
48
  print('Loading encoders weights from irse50!')
 
40
  def load_weights(self):
41
  if self.opts.checkpoint_path is not None:
42
  print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
43
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cuda:0' if torch.cuda.is_available() else "cpu")
44
  self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
45
+ self.encoder.to(self.device)
46
  self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
47
+ self.decoder.to(self.device)
48
  self.__load_latent_avg(ckpt)
49
  else:
50
  print('Loading encoders weights from irse50!')