PKUWilliamYang commited on
Commit
d6e1852
1 Parent(s): 1039991

Update models/psp.py

Browse files
Files changed (1) hide show
  1. models/psp.py +5 -4
models/psp.py CHANGED
@@ -21,7 +21,7 @@ def get_keys(d, name):
21
 
22
  class pSp(nn.Module):
23
 
24
- def __init__(self, opts):
25
  super(pSp, self).__init__()
26
  self.set_opts(opts)
27
  # compute number of style inputs based on the output resolution
@@ -31,7 +31,7 @@ class pSp(nn.Module):
31
  self.decoder = Generator(self.opts.output_size, 512, 8)
32
  self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
33
  # Load weights if needed
34
- self.load_weights()
35
 
36
  def set_encoder(self):
37
  if self.opts.encoder_type == 'GradualStyleEncoder':
@@ -44,10 +44,11 @@ class pSp(nn.Module):
44
  raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
45
  return encoder
46
 
47
- def load_weights(self):
48
  if self.opts.checkpoint_path is not None:
49
  print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
50
- ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
 
51
  self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=False)
52
  self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=False)
53
  self.__load_latent_avg(ckpt)
21
 
22
  class pSp(nn.Module):
23
 
24
+ def __init__(self, opts, ckpt=None):
25
  super(pSp, self).__init__()
26
  self.set_opts(opts)
27
  # compute number of style inputs based on the output resolution
31
  self.decoder = Generator(self.opts.output_size, 512, 8)
32
  self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
33
  # Load weights if needed
34
+ self.load_weights(ckpt)
35
 
36
  def set_encoder(self):
37
  if self.opts.encoder_type == 'GradualStyleEncoder':
44
  raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
45
  return encoder
46
 
47
+ def load_weights(self, ckpt=None):
48
  if self.opts.checkpoint_path is not None:
49
  print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
50
+ if ckpt is None:
51
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
52
  self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=False)
53
  self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=False)
54
  self.__load_latent_avg(ckpt)