import math
import torch
from torch import nn
import copy
from argparse import Namespace

from models.hyperstyle.encoders.psp import pSp
from models.stylegan2.model import Generator
from models.hyperstyle.configs.paths_config import model_paths
from models.hyperstyle.hypernetworks.hypernetwork import SharedWeightsHyperNetResNet, SharedWeightsHyperNetResNetSeparable
from models.hyperstyle.utils.resnet_mapping import RESNET_MAPPING


class HyperStyle(nn.Module):
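    """HyperStyle inversion network: a hypernetwork predicts per-layer weight deltas
    for a pretrained StyleGAN2 generator, refining the reconstruction obtained from
    an initial W-space inversion produced by a pretrained W encoder (e.g. e4e)."""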

    def __init__(self, opts):
        super(HyperStyle, self).__init__()
        self.set_opts(opts)
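        # number of style vectors in W+ for this resolution: 2*log2(output_size) - 2 (18 for 1024)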
        self.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
        # Define architecture
        self.hypernet = self.set_hypernet()
        self.decoder = Generator(self.opts.output_size, 512, 8, channel_multiplier=2)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()
        if self.opts.load_w_encoder:
            self.w_encoder.eval()

    def set_hypernet(self):
        # each output size corresponds to a different number of modulated generator layers
        if self.opts.output_size == 1024:
            self.opts.n_hypernet_outputs = 26
        elif self.opts.output_size == 512:
            self.opts.n_hypernet_outputs = 23
        elif self.opts.output_size == 256:
            self.opts.n_hypernet_outputs = 20
        else:
            raise ValueError("Invalid output size! Supported sizes: [1024, 512, 256]")
        networks = {
            "SharedWeightsHyperNetResNet": SharedWeightsHyperNetResNet,
            "SharedWeightsHyperNetResNetSeparable": SharedWeightsHyperNetResNetSeparable,
        }
        if self.opts.encoder_type not in networks:
            raise ValueError(f"Unknown encoder type: {self.opts.encoder_type}")
        # instantiate only the requested network rather than building both and discarding one
        return networks[self.opts.encoder_type](opts=self.opts)

    def load_weights(self):
        if self.opts.checkpoint_path is not None:
            print(f'Loading HyperStyle from checkpoint: {self.opts.checkpoint_path}')
            ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.hypernet.load_state_dict(self.__get_keys(ckpt, 'hypernet'), strict=True)
            self.decoder.load_state_dict(self.__get_keys(ckpt, 'decoder'), strict=True)
            self.__load_latent_avg(ckpt)
        else:
            hypernet_ckpt = self.__get_hypernet_checkpoint()
            self.hypernet.load_state_dict(hypernet_ckpt, strict=False)
            print(f'Loading decoder weights from pretrained path: {self.opts.stylegan_weights}')
            ckpt = torch.load(self.opts.stylegan_weights, map_location='cpu')
            self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
            self.__load_latent_avg(ckpt, repeat=self.n_styles)
        # the pretrained W encoder is needed in both cases to compute the initial inversion
        if self.opts.load_w_encoder:
            self.w_encoder = self.__get_pretrained_w_encoder()

    def forward(self, x, resize=True, input_code=False, randomize_noise=True, return_latents=False,
                return_weight_deltas_and_codes=False, weights_deltas=None, y_hat=None, codes=None):
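        """Invert x into the generator: predict per-layer weight deltas from (x, y_hat),
        apply them to the decoder, and synthesize the refined reconstruction. Passing
        weights_deltas from a previous call accumulates deltas across refinement steps;
        with input_code=True, x is treated as latent codes and the hypernet is skipped."""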

        if input_code:
            codes = x
        else:
            if y_hat is None:
                assert self.opts.load_w_encoder, "Cannot infer latent code when the pretrained W encoder isn't loaded."
                y_hat, codes = self.__get_initial_inversion(x, resize=True)

            # concatenate original input with w-reconstruction or current reconstruction
            x_input = torch.cat([x, y_hat], dim=1)

            # pass through hypernet to get per-layer deltas
            hypernet_outputs = self.hypernet(x_input)
            if weights_deltas is None:
                weights_deltas = hypernet_outputs
            else:
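                # accumulate: add this pass's predicted deltas onto those from earlier refinement steps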
                weights_deltas = [weights_deltas[i] + hypernet_outputs[i] if weights_deltas[i] is not None else None
                                  for i in range(len(hypernet_outputs))]

        input_is_latent = (not input_code)
        images, result_latent, _ = self.decoder([codes],
                                                weights_deltas=weights_deltas,
                                                input_is_latent=input_is_latent,
                                                randomize_noise=randomize_noise,
                                                return_latents=return_latents)

        if resize:
            images = self.face_pool(images)

        if return_latents and return_weight_deltas_and_codes:
            return images, result_latent, weights_deltas, codes, y_hat
        elif return_latents:
            return images, result_latent
        elif return_weight_deltas_and_codes:
            return images, weights_deltas, codes
        else:
            return images

    def set_opts(self, opts):
        self.opts = opts

    def __load_latent_avg(self, ckpt, repeat=None):
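        # the generator's average W latent, used to recenter encoder outputs during inversion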
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None

    def __get_hypernet_checkpoint(self):
        print('Loading hypernet weights from resnet34!')
        hypernet_ckpt = torch.load(model_paths['resnet34'], map_location='cpu')
        # Transfer the RGB input of the resnet34 network to the first 3 input channels of hypernet
        if self.opts.input_nc != 3:
            shape = hypernet_ckpt['conv1.weight'].shape
            altered_input_layer = torch.randn(shape[0], self.opts.input_nc, shape[2], shape[3], dtype=torch.float32)
            altered_input_layer[:, :3, :, :] = hypernet_ckpt['conv1.weight']
            hypernet_ckpt['conv1.weight'] = altered_input_layer
        # remap resnet34 parameter names to the hypernet backbone's naming scheme
        mapped_hypernet_ckpt = dict(hypernet_ckpt)
        for p, v in hypernet_ckpt.items():
            for original_name, net_name in RESNET_MAPPING.items():
                if original_name in p:
                    mapped_hypernet_ckpt[p.replace(original_name, net_name)] = v
                    mapped_hypernet_ckpt.pop(p)
        return mapped_hypernet_ckpt

    @staticmethod
    def __get_keys(d, name):
        if 'state_dict' in d:
            d = d['state_dict']
        # keep only keys under the given submodule prefix (e.g. 'hypernet.' or 'decoder.')
        d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k.startswith(f'{name}.')}
        return d_filt

    def __get_pretrained_w_encoder(self):
        print("Loading pretrained W encoder...")
        opts_w_encoder = vars(copy.deepcopy(self.opts))
        opts_w_encoder['checkpoint_path'] = self.opts.w_encoder_checkpoint_path
        opts_w_encoder['encoder_type'] = self.opts.w_encoder_type
        opts_w_encoder['input_nc'] = 3
        opts_w_encoder = Namespace(**opts_w_encoder)
        w_net = pSp(opts_w_encoder)
        w_net = w_net.encoder
        w_net.eval()
        w_net.to(self.opts.device)
        return w_net

    def __get_initial_inversion(self, x, resize=True):
        # get initial inversion and reconstruction of batch
        with torch.no_grad():
            return self.__get_w_inversion(x, resize)

    def __get_w_inversion(self, x, resize=True):
        if self.w_encoder.training:
            self.w_encoder.eval()
        codes = self.w_encoder.forward(x)
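        # the W encoder predicts an offset from the average latent, so recenter the codes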
        if codes.ndim == 2:
            codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
        else:
            codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
        y_hat, _, _ = self.decoder([codes],
                                   weights_deltas=None,
                                   input_is_latent=True,
                                   randomize_noise=False,
                                   return_latents=False)
        if resize:
            y_hat = self.face_pool(y_hat)
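        # the cars domain uses non-square images: crop the vertical padding (256x256 -> 192x256)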
        if "cars" in self.opts.dataset_type:
            y_hat = y_hat[:, :, 32:224, :]
        return y_hat, codes
    
    def w_invert(self, x, resize=True):
        with torch.no_grad():
            return self.__get_w_inversion(x, resize)
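

# A minimal usage sketch. The `opts` fields below mirror the options this class reads;
# the checkpoint paths and option values are illustrative assumptions, not shipped defaults:
#
#   from argparse import Namespace
#   opts = Namespace(output_size=1024, encoder_type='SharedWeightsHyperNetResNetSeparable',
#                    checkpoint_path='pretrained_models/hyperstyle_ffhq.pt',
#                    load_w_encoder=True,
#                    w_encoder_checkpoint_path='pretrained_models/faces_w_encoder.pt',
#                    w_encoder_type='WEncoder', input_nc=6, device='cuda',
#                    dataset_type='ffhq_hypernet', stylegan_weights=None)
#   net = HyperStyle(opts).to(opts.device).eval()
#   x = torch.randn(1, 3, 256, 256).to(opts.device)  # aligned input batch
#   with torch.no_grad():
#       images, latents, deltas, codes, y_hat = net(x, return_latents=True,
#                                                   return_weight_deltas_and_codes=True)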