PKUWilliamYang's picture
V1
983684c
raw history blame
No virus
4.51 kB
"""
This file defines the core research contribution
"""
import matplotlib
matplotlib.use('Agg')
import math
import torch
from torch import nn
from model.encoder.encoders import psp_encoders
from model.stylegan.model import Generator
def get_keys(d, name):
    """Extract the sub-state-dict stored under one top-level prefix.

    If ``d`` has a ``'state_dict'`` entry, that nested mapping is searched
    instead of ``d`` itself.

    Args:
        d: A checkpoint mapping, either a state dict or a dict that holds one
            under the ``'state_dict'`` key.
        name: Top-level prefix to select, e.g. ``'encoder'`` or ``'decoder'``.

    Returns:
        A new dict containing every entry whose key starts with
        ``name + '.'``, re-keyed with that prefix removed.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Match on the full "<name>." prefix. A bare textual prefix test would also
    # pick up sibling keys such as "encoderx.w" and the slice below would then
    # turn them into ".w".
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
class pSp(nn.Module):
    """Image encoder paired with a StyleGAN ``Generator`` decoder.

    The encoder is chosen by ``opts.encoder_type``. It maps an input batch to
    style codes, which are decoded by ``Generator``. The output can optionally
    be average-pooled to 256x256.

    Options read in this class: ``output_size``, ``encoder_type``,
    ``checkpoint_path``, ``device``, ``start_from_latent_avg`` and
    ``learn_in_w``. ``opts.n_styles`` is written during construction before the
    encoder is built. The encoders receive ``opts``, so they presumably read it
    (verify in ``psp_encoders``).
    """

    def __init__(self, opts):
        super(pSp, self).__init__()
        self.set_opts(opts)
        # Number of style inputs derived from the output resolution.
        # math.log2 is exact for powers of two, whereas log(x, 2) divides two
        # rounded logarithms.
        self.opts.n_styles = int(math.log2(self.opts.output_size)) * 2 - 2
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.output_size, 512, 8)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed (also defines self.latent_avg)
        self.load_weights()

    def set_encoder(self):
        """Build the encoder class named by ``opts.encoder_type``.

        The class is looked up in ``psp_encoders`` and constructed with
        ``(50, 'ir_se', opts)``.

        Raises:
            Exception: if ``opts.encoder_type`` is not a supported encoder name.
        """
        supported = (
            'GradualStyleEncoder',
            'BackboneEncoderUsingLastLayerIntoW',
            'BackboneEncoderUsingLastLayerIntoWPlus',
        )
        if self.opts.encoder_type not in supported:
            raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
        encoder_cls = getattr(psp_encoders, self.opts.encoder_type)
        return encoder_cls(50, 'ir_se', self.opts)

    def load_weights(self):
        """Load encoder, decoder and ``latent_avg`` from ``opts.checkpoint_path``.

        Both submodules are loaded with ``strict=True`` from the
        ``'encoder.*'`` and ``'decoder.*'`` entries of the checkpoint. When
        ``opts.checkpoint_path`` is None, no weights are loaded and
        ``self.latent_avg`` is set to None, so the attribute always exists.
        """
        if self.opts.checkpoint_path is None:
            # Initialising from separately pretrained backbone / StyleGAN
            # weights is not implemented in this version.
            self.latent_avg = None
            return
        print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
        self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
        self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
        self.__load_latent_avg(ckpt)

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None, z_plus_latent=False,
                return_z_plus_latent=True):
        """Encode ``x`` (or take it as codes directly) and decode it.

        Args:
            x: Encoder input batch, or the codes themselves when ``input_code``.
            resize: If True, pool the generated images through ``face_pool``
                (``AdaptiveAvgPool2d((256, 256))``).
            latent_mask: Optional iterable of indices along dim 1 of the codes
                to overwrite. Each index is set from ``inject_latent`` (blended
                with ``alpha`` when given), or zeroed if ``inject_latent`` is None.
            input_code: If True, ``x`` is used as the codes. The encoder and the
                latent-average offset are skipped, and the input tensor is not
                modified.
            randomize_noise: Forwarded to the decoder.
            inject_latent: Codes copied into the ``latent_mask`` indices.
            return_latents: If True, also return a latent (see Returns).
            alpha: Blend weight, giving ``alpha * inject + (1 - alpha) * codes``.
            z_plus_latent: Forwarded to the decoder; also forces
                ``input_is_latent=False``.
            return_z_plus_latent: With ``z_plus_latent`` and ``return_latents``,
                return the codes fed to the decoder instead of its latent output.

        Returns:
            ``images``, or ``(images, latent)`` when ``return_latents`` is True.
            ``latent`` is the decoder input ``codes`` when
            ``z_plus_latent and return_z_plus_latent``; otherwise it is the
            decoder's second output.

        Raises:
            RuntimeError: if ``opts.start_from_latent_avg`` applies but no
                ``latent_avg`` was loaded.
        """
        if input_code:
            codes = x
        else:
            codes = self.encoder(x)
            # normalize with respect to the center of an average face
            if self.opts.start_from_latent_avg:
                latent_avg = getattr(self, 'latent_avg', None)
                if latent_avg is None:
                    raise RuntimeError(
                        'opts.start_from_latent_avg is set but no latent_avg '
                        'was loaded from the checkpoint')
                if self.opts.learn_in_w:
                    codes = codes + latent_avg.repeat(codes.shape[0], 1)
                else:
                    codes = codes + latent_avg.repeat(codes.shape[0], 1, 1)

        if latent_mask is not None:
            if input_code:
                # codes aliases the caller's tensor; copy it before the in-place
                # writes below so the caller's input is left untouched.
                codes = codes.clone()
            for i in latent_mask:
                if inject_latent is None:
                    codes[:, i] = 0
                elif alpha is not None:
                    codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                else:
                    codes[:, i] = inject_latent[:, i]

        input_is_latent = not input_code and not z_plus_latent
        images, result_latent = self.decoder([codes],
                                             input_is_latent=input_is_latent,
                                             randomize_noise=randomize_noise,
                                             return_latents=return_latents,
                                             z_plus_latent=z_plus_latent)
        if resize:
            images = self.face_pool(images)

        if not return_latents:
            return images
        if z_plus_latent and return_z_plus_latent:
            return images, codes
        return images, result_latent

    def set_opts(self, opts):
        """Store the options object on the module."""
        self.opts = opts

    def __load_latent_avg(self, ckpt, repeat=None):
        """Set ``self.latent_avg`` from ``ckpt['latent_avg']``, or None if absent.

        The tensor is moved to ``opts.device``. If ``repeat`` is given, it is
        tiled with ``.repeat(repeat, 1)``.
        """
        if 'latent_avg' not in ckpt:
            self.latent_avg = None
            return
        latent_avg = ckpt['latent_avg'].to(self.opts.device)
        if repeat is not None:
            latent_avg = latent_avg.repeat(repeat, 1)
        self.latent_avg = latent_avg