import math
import torch
from torch import nn
import copy
from argparse import Namespace
from models.hyperstyle.encoders.psp import pSp
from models.stylegan2.model import Generator
from models.hyperstyle.configs.paths_config import model_paths
from models.hyperstyle.hypernetworks.hypernetwork import SharedWeightsHyperNetResNet, SharedWeightsHyperNetResNetSeparable
from models.hyperstyle.utils.resnet_mapping import RESNET_MAPPING
class HyperStyle(nn.Module):
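    """HyperStyle inversion network.

    A hypernetwork receives an input image together with its current
    reconstruction and predicts per-layer weight deltas that are applied to a
    pretrained StyleGAN2 generator, iteratively refining the inversion.
    """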
def __init__(self, opts):
super(HyperStyle, self).__init__()
self.set_opts(opts)
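        # Number of W+ style vectors for the target resolution:
        # 2 * log2(output_size) - 2, e.g. 18 for a 1024x1024 generator.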
self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
# Define architecture
self.hypernet = self.set_hypernet()
self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2)
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
# Load weights if needed
self.load_weights()
if self.opts.load_w_encoder:
self.w_encoder.eval()
def set_hypernet(self):
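        # The number of hypernet outputs equals the number of generator weight
        # tensors to be refined, which grows with the output resolution.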
if self.opts.output_size == 1024:
self.opts.n_hypernet_outputs = 26
elif self.opts.output_size == 512:
self.opts.n_hypernet_outputs = 23
elif self.opts.output_size == 256:
self.opts.n_hypernet_outputs = 20
else:
            raise ValueError("Invalid output size! Supported sizes: [1024, 512, 256].")
networks = {
"SharedWeightsHyperNetResNet": SharedWeightsHyperNetResNet(opts=self.opts),
"SharedWeightsHyperNetResNetSeparable": SharedWeightsHyperNetResNetSeparable(opts=self.opts),
}
return networks[self.opts.encoder_type]
def load_weights(self):
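        # Either resume from a full HyperStyle checkpoint (hypernet + decoder),
        # or initialize the hypernet from ResNet34 weights and the decoder from
        # pretrained StyleGAN2 weights.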
if self.opts.checkpoint_path is not None:
print(f'Loading HyperStyle from checkpoint: {self.opts.checkpoint_path}')
ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
self.hypernet.load_state_dict(self.__get_keys(ckpt, 'hypernet'), strict=True)
self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True)
self.__load_latent_avg(ckpt)
if self.opts.load_w_encoder:
self.w_encoder = self.__get_pretrained_w_encoder()
else:
hypernet_ckpt = self.__get_hypernet_checkpoint()
self.hypernet.load_state_dict(hypernet_ckpt, strict=False)
print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}')
            ckpt = torch.load(self.opts.stylegan_weights, map_location='cpu')
self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
self.__load_latent_avg(ckpt, repeat=self.n_styles)
if self.opts.load_w_encoder:
self.w_encoder = self.__get_pretrained_w_encoder()
def forward(self, x, resize=True, input_code=False, randomize_noise=True, return_latents=False,
return_weight_deltas_and_codes=False, weights_deltas=None, y_hat=None, codes=None):
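        """Predict weight deltas from (x, y_hat) and synthesize a refined image.

        If `y_hat` is None, an initial inversion is obtained from the pretrained
        W encoder; `weights_deltas` from earlier refinement steps, when given,
        are accumulated with the deltas predicted in this step.
        """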
if input_code:
codes = x
else:
if y_hat is None:
assert self.opts.load_w_encoder, "Cannot infer latent code when e4e isn't loaded."
y_hat, codes = self.__get_initial_inversion(x, resize=True)
# concatenate original input with w-reconstruction or current reconstruction
x_input = torch.cat([x, y_hat], dim=1)
# pass through hypernet to get per-layer deltas
hypernet_outputs = self.hypernet(x_input)
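        # Accumulate deltas across refinement iterations; entries the hypernet
        # does not refine stay None.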
if weights_deltas is None:
weights_deltas = hypernet_outputs
else:
weights_deltas = [weights_deltas[i] + hypernet_outputs[i] if weights_deltas[i] is not None else None
for i in range(len(hypernet_outputs))]
input_is_latent = (not input_code)
images, result_latent, _ = self.decoder([codes],
weights_deltas=weights_deltas,
input_is_latent=input_is_latent,
randomize_noise=randomize_noise,
return_latents=return_latents)
if resize:
images = self.face_pool(images)
if return_latents and return_weight_deltas_and_codes:
return images, result_latent, weights_deltas, codes, y_hat
elif return_latents:
return images, result_latent
elif return_weight_deltas_and_codes:
return images, weights_deltas, codes
else:
return images
def set_opts(self, opts):
self.opts = opts
def __load_latent_avg(self, ckpt, repeat=None):
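        # Cache the generator's average latent; `repeat` broadcasts the single
        # W vector across all n_styles W+ entries.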
if 'latent_avg' in ckpt:
self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
if repeat is not None:
self.latent_avg = self.latent_avg.repeat(repeat, 1)
else:
self.latent_avg = None
def __get_hypernet_checkpoint(self):
print('Loading hypernet weights from resnet34!')
hypernet_ckpt = torch.load(model_paths['resnet34'])
# Transfer the RGB input of the resnet34 network to the first 3 input channels of hypernet
if self.opts.input_nc != 3:
shape = hypernet_ckpt['conv1.weight'].shape
altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32)
altered_input_layer[:, :3, :, :] = hypernet_ckpt['conv1.weight']
hypernet_ckpt['conv1.weight'] = altered_input_layer
        # Rename the raw ResNet34 parameter keys to the hypernet's naming scheme
        mapped_hypernet_ckpt = dict(hypernet_ckpt)
        for p, v in hypernet_ckpt.items():
            for original_name, net_name in RESNET_MAPPING.items():
                if original_name in p:
                    mapped_hypernet_ckpt[p.replace(original_name, net_name)] = v
                    mapped_hypernet_ckpt.pop(p)
        # Return the remapped checkpoint (the unmapped dict would leave the
        # hypernet uninitialized under strict=False loading)
        return mapped_hypernet_ckpt
@staticmethod
def __get_keys(d, name):
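        # Filter a checkpoint down to the sub-module `name`, stripping the
        # "name." prefix from each matching key.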
if 'state_dict' in d:
d = d['state_dict']
d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
return d_filt
def __get_pretrained_w_encoder(self):
print("Loading pretrained W encoder...")
opts_w_encoder = vars(copy.deepcopy(self.opts))
opts_w_encoder['checkpoint_path'] = self.opts.w_encoder_checkpoint_path
opts_w_encoder['encoder_type'] = self.opts.w_encoder_type
opts_w_encoder['input_nc'] = 3
opts_w_encoder = Namespace(**opts_w_encoder)
w_net = pSp(opts_w_encoder)
w_net = w_net.encoder
w_net.eval()
w_net.to(self.opts.device)
return w_net
def __get_initial_inversion(self, x, resize=True):
# get initial inversion and reconstruction of batch
with torch.no_grad():
return self.__get_w_inversion(x, resize)
def __get_w_inversion(self, x, resize=True):
if self.w_encoder.training:
self.w_encoder.eval()
codes = self.w_encoder.forward(x)
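        # The W encoder predicts an offset from the average latent: a 2D output
        # is a single W vector, a 3D output is a full W+ code.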
if codes.ndim == 2:
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
else:
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
y_hat, _, _ = self.decoder([codes],
weights_deltas=None,
input_is_latent=True,
randomize_noise=False,
return_latents=False)
if resize:
y_hat = self.face_pool(y_hat)
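        # Car images carry vertical padding in the square generator output;
        # keep only the 192-row content region.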
if "cars" in self.opts.dataset_type:
y_hat = y_hat[:, :, 32:224, :]
return y_hat, codes
def w_invert(self, x, resize=True):
with torch.no_grad():
return self.__get_w_inversion(x, resize)
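

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). All paths and
# several opts values below are placeholder assumptions inferred from the
# attribute accesses above; depending on the repository version, further
# fields (e.g. layers_to_tune) may also be required.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    example_opts = Namespace(
        output_size=1024,
        encoder_type='SharedWeightsHyperNetResNetSeparable',
        input_nc=6,  # input image stacked channel-wise with the reconstruction
        checkpoint_path='pretrained/hyperstyle_ffhq.pt',      # assumed path
        stylegan_weights='pretrained/stylegan2-ffhq.pt',      # assumed path
        load_w_encoder=True,
        w_encoder_checkpoint_path='pretrained/w_encoder.pt',  # assumed path
        w_encoder_type='WEncoder',  # assumed value; must name a pSp encoder
        dataset_type='ffhq_encode',
        device='cuda',
    )
    net = HyperStyle(example_opts).to(example_opts.device).eval()

    # A dummy 256x256 batch stands in for real aligned, normalized images.
    x = torch.randn(1, 3, 256, 256, device=example_opts.device)
    with torch.no_grad():
        y_hat, codes, deltas = None, None, None
        for _ in range(3):  # a few refinement iterations
            y_hat, latent, deltas, codes, _ = net(x,
                                                  y_hat=y_hat,
                                                  codes=codes,
                                                  weights_deltas=deltas,
                                                  return_latents=True,
                                                  randomize_noise=False,
                                                  return_weight_deltas_and_codes=True)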