"""
This file defines the core research contribution
"""
import math

import torch
from torch import nn

from models.stylegan2.model import Generator
from models.hyperstyle.configs.paths_config import model_paths
from models.hyperstyle.encoders import restyle_e4e_encoders
from models.hyperstyle.utils.resnet_mapping import RESNET_MAPPING

class e4e(nn.Module):

    def __init__(self, opts):
        super(e4e, self).__init__()
        self.set_opts(opts)
        # one style vector per generator layer: 2 * log2(output_size) - 2 (e.g. 18 for 1024)
        self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()
    def set_encoder(self):
        if self.opts.encoder_type == 'ProgressiveBackboneEncoder':
            encoder = restyle_e4e_encoders.ProgressiveBackboneEncoder(50, 'ir_se', self.n_styles, self.opts)
        elif self.opts.encoder_type == 'ResNetProgressiveBackboneEncoder':
            encoder = restyle_e4e_encoders.ResNetProgressiveBackboneEncoder(self.n_styles, self.opts)
        else:
            raise ValueError(f'{self.opts.encoder_type} is not a valid encoder type')
        return encoder
    def load_weights(self):
        if self.opts.checkpoint_path is not None:
            print(f'Loading ReStyle e4e from checkpoint: {self.opts.checkpoint_path}')
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.encoder.load_state_dict(self.__get_keys(ckpt, 'encoder'), strict=True)
            self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            encoder_ckpt = self.__get_encoder_checkpoint()
            self.encoder.load_state_dict(encoder_ckpt, strict=False)
            print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}')
            ckpt = torch.load(self.opts.stylegan_weights, map_location='cpu')
            self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
            self.__load_latent_avg(ckpt, repeat=self.n_styles)
    def forward(self, x, latent=None, resize=True, input_code=False, randomize_noise=True,
                return_latents=False, average_code=False, input_is_full=False):
        if input_code:
            codes = x
        else:
            codes = self.encoder(x)
            # residual step
            if x.shape[1] == 6 and latent is not None:
                # learn the error with respect to the previous iteration
                codes = codes + latent
            else:
                # first iteration is with respect to the avg latent code
                codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)

        if average_code:
            input_is_latent = True
        else:
            input_is_latent = (not input_code) or input_is_full

        images, result_latent = self.decoder([codes],
                                             input_is_latent=input_is_latent,
                                             randomize_noise=randomize_noise,
                                             return_latents=return_latents)

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        else:
            return images
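    # A note on the residual step above: when the input has 6 channels (the
    # current image stacked with the previous reconstruction) and a `latent`
    # from the previous pass is given, the encoder output refines that latent;
    # on the first pass the output is an offset from the average latent code.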
    def set_opts(self, opts):
        self.opts = opts
    def __load_latent_avg(self, ckpt, repeat=None):
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None
    def __get_encoder_checkpoint(self):
        if "ffhq" in self.opts.dataset_type:
            print('Loading encoder weights from irse50!')
            encoder_ckpt = torch.load(model_paths['ir_se50'], map_location='cpu')
            # Transfer the RGB input of the irse50 network to the first 3 input channels of pSp's encoder
            if self.opts.input_nc != 3:
                shape = encoder_ckpt['input_layer.0.weight'].shape
                altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32)
                altered_input_layer[:, :3, :, :] = encoder_ckpt['input_layer.0.weight']
                encoder_ckpt['input_layer.0.weight'] = altered_input_layer
            return encoder_ckpt
        else:
            print('Loading encoder weights from resnet34!')
            encoder_ckpt = torch.load(model_paths['resnet34'], map_location='cpu')
            # Transfer the RGB input of the resnet34 network to the first 3 input channels of pSp's encoder
            if self.opts.input_nc != 3:
                shape = encoder_ckpt['conv1.weight'].shape
                altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32)
                altered_input_layer[:, :3, :, :] = encoder_ckpt['conv1.weight']
                encoder_ckpt['conv1.weight'] = altered_input_layer
            # Rename torchvision resnet34 keys to the pSp encoder's naming scheme
            mapped_encoder_ckpt = dict(encoder_ckpt)
            for p, v in encoder_ckpt.items():
                for original_name, psp_name in RESNET_MAPPING.items():
                    if original_name in p:
                        mapped_encoder_ckpt[p.replace(original_name, psp_name)] = v
                        mapped_encoder_ckpt.pop(p)
            # return the renamed checkpoint (the unmapped keys would silently fail to load)
            return mapped_encoder_ckpt
    @staticmethod
    def __get_keys(d, name):
        # filter a checkpoint down to the sub-module `name`, stripping its prefix,
        # e.g. 'encoder.input_layer.0.weight' -> 'input_layer.0.weight'
        if 'state_dict' in d:
            d = d['state_dict']
        d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
        return d_filt
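

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original training code).
# It assumes a trained ReStyle e4e checkpoint saved by this codebase; the
# option values and paths below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    opts = Namespace(
        output_size=1024,                  # 1024x1024 StyleGAN2 -> 18 style vectors
        encoder_type='ProgressiveBackboneEncoder',
        input_nc=6,                        # image + previous reconstruction (3 + 3 channels)
        dataset_type='ffhq_encode',
        device='cuda' if torch.cuda.is_available() else 'cpu',
        checkpoint_path='pretrained_models/restyle_e4e_ffhq.pt',  # hypothetical path
        stylegan_weights=None,             # unused when checkpoint_path is given
    )
    net = e4e(opts).to(opts.device).eval()

    # Iterative refinement: each pass feeds the image stacked with the previous
    # reconstruction, plus the previous latent. ReStyle initializes y_hat with
    # the generator's average image; zeros are used here only to keep the
    # sketch self-contained.
    image = torch.randn(1, 3, 256, 256, device=opts.device)  # stand-in for an aligned face
    y_hat, latent = torch.zeros_like(image), None
    with torch.no_grad():
        for _ in range(3):
            x = torch.cat([image, y_hat], dim=1)
            y_hat, latent = net(x, latent=latent, return_latents=True)
    print(y_hat.shape, latent.shape)  # (1, 3, 256, 256), (1, 18, 512)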