ethanNeuralImage's picture
fix GPU usage to be optional
6fa3e0e
raw
history blame
No virus
7.81 kB
import math
import torch
from torch import nn
import copy
from argparse import Namespace
from models.hyperstyle.encoders.psp import pSp
from models.stylegan2.model import Generator
from models.hyperstyle.configs.paths_config import model_paths
from models.hyperstyle.hypernetworks.hypernetwork import SharedWeightsHyperNetResNet, SharedWeightsHyperNetResNetSeparable
from models.hyperstyle.utils.resnet_mapping import RESNET_MAPPING
class HyperStyle(nn.Module):
def __init__(self, opts):
super(HyperStyle, self).__init__()
self.set_opts(opts)
self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
# Define architecture
self.hypernet = self.set_hypernet()
self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2)
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
# Load weights if needed
self.load_weights()
if self.opts.load_w_encoder:
self.w_encoder.eval()
def set_hypernet(self):
if self.opts.output_size == 1024:
self.opts.n_hypernet_outputs = 26
elif self.opts.output_size == 512:
self.opts.n_hypernet_outputs = 23
elif self.opts.output_size == 256:
self.opts.n_hypernet_outputs = 20
else:
raise ValueError(f"Invalid Output Size! Support sizes: [1024, 512, 256]!")
networks = {
"SharedWeightsHyperNetResNet": SharedWeightsHyperNetResNet(opts=self.opts),
"SharedWeightsHyperNetResNetSeparable": SharedWeightsHyperNetResNetSeparable(opts=self.opts),
}
return networks[self.opts.encoder_type]
def load_weights(self):
if self.opts.checkpoint_path is not None:
print(f'Loading HyperStyle from checkpoint: {self.opts.checkpoint_path}')
ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
self.hypernet.load_state_dict(self.__get_keys(ckpt, 'hypernet'), strict=True)
self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True)
self.__load_latent_avg(ckpt)
if self.opts.load_w_encoder:
self.w_encoder = self.__get_pretrained_w_encoder()
else:
hypernet_ckpt = self.__get_hypernet_checkpoint()
self.hypernet.load_state_dict(hypernet_ckpt, strict=False)
print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}')
ckpt = torch.load(self.opts.stylegan_weights)
self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
self.__load_latent_avg(ckpt, repeat=self.n_styles)
if self.opts.load_w_encoder:
self.w_encoder = self.__get_pretrained_w_encoder()
def forward(self, x, resize=True, input_code=False, randomize_noise=True, return_latents=False,
return_weight_deltas_and_codes=False, weights_deltas=None, y_hat=None, codes=None):
if input_code:
codes = x
else:
if y_hat is None:
assert self.opts.load_w_encoder, "Cannot infer latent code when e4e isn't loaded."
y_hat, codes = self.__get_initial_inversion(x, resize=True)
# concatenate original input with w-reconstruction or current reconstruction
x_input = torch.cat([x, y_hat], dim=1)
# pass through hypernet to get per-layer deltas
hypernet_outputs = self.hypernet(x_input)
if weights_deltas is None:
weights_deltas = hypernet_outputs
else:
weights_deltas = [weights_deltas[i] + hypernet_outputs[i] if weights_deltas[i] is not None else None
for i in range(len(hypernet_outputs))]
input_is_latent = (not input_code)
images, result_latent, _ = self.decoder([codes],
weights_deltas=weights_deltas,
input_is_latent=input_is_latent,
randomize_noise=randomize_noise,
return_latents=return_latents)
if resize:
images = self.face_pool(images)
if return_latents and return_weight_deltas_and_codes:
return images, result_latent, weights_deltas, codes, y_hat
elif return_latents:
return images, result_latent
elif return_weight_deltas_and_codes:
return images, weights_deltas, codes
else:
return images
def set_opts(self, opts):
self.opts = opts
def __load_latent_avg(self, ckpt, repeat=None):
if 'latent_avg' in ckpt:
self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
if repeat is not None:
self.latent_avg = self.latent_avg.repeat(repeat, 1)
else:
self.latent_avg = None
def __get_hypernet_checkpoint(self):
print('Loading hypernet weights from resnet34!')
hypernet_ckpt = torch.load(model_paths['resnet34'])
# Transfer the RGB input of the resnet34 network to the first 3 input channels of hypernet
if self.opts.input_nc != 3:
shape = hypernet_ckpt['conv1.weight'].shape
altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32)
altered_input_layer[:, :3, :, :] = hypernet_ckpt['conv1.weight']
hypernet_ckpt['conv1.weight'] = altered_input_layer
mapped_hypernet_ckpt = dict(hypernet_ckpt)
for p, v in hypernet_ckpt.items():
for original_name, net_name in RESNET_MAPPING.items():
if original_name in p:
mapped_hypernet_ckpt[p.replace(original_name, net_name)] = v
mapped_hypernet_ckpt.pop(p)
return hypernet_ckpt
@staticmethod
def __get_keys(d, name):
if 'state_dict' in d:
d = d['state_dict']
d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
return d_filt
def __get_pretrained_w_encoder(self):
print("Loading pretrained W encoder...")
opts_w_encoder = vars(copy.deepcopy(self.opts))
opts_w_encoder['checkpoint_path'] = self.opts.w_encoder_checkpoint_path
opts_w_encoder['encoder_type'] = self.opts.w_encoder_type
opts_w_encoder['input_nc'] = 3
opts_w_encoder = Namespace(**opts_w_encoder)
w_net = pSp(opts_w_encoder)
w_net = w_net.encoder
w_net.eval()
w_net.to(self.opts.device)
return w_net
def __get_initial_inversion(self, x, resize=True):
# get initial inversion and reconstruction of batch
with torch.no_grad():
return self.__get_w_inversion(x, resize)
def __get_w_inversion(self, x, resize=True):
if self.w_encoder.training:
self.w_encoder.eval()
codes = self.w_encoder.forward(x)
if codes.ndim == 2:
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
else:
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
y_hat, _, _ = self.decoder([codes],
weights_deltas=None,
input_is_latent=True,
randomize_noise=False,
return_latents=False)
if resize:
y_hat = self.face_pool(y_hat)
if "cars" in self.opts.dataset_type:
y_hat = y_hat[:, :, 32:224, :]
return y_hat, codes
def w_invert(self, x, resize=True):
with torch.no_grad():
return self.__get_w_inversion(x, resize)