File size: 2,713 Bytes
4d9fdb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
from torch import nn
from models.StyleCLIP.mapper import latent_mappers
from models.StyleCLIP.models.stylegan2.model import Generator


def get_keys(d, name):
    """Extract the entries of a state dict that live under the ``name`` prefix.

    Args:
        d: A mapping of parameter names to values. If it contains a
            ``'state_dict'`` entry, that nested mapping is used instead.
        name: Prefix of the sub-module whose entries should be kept
            (for example ``'mapper'``).

    Returns:
        A new dict holding every entry whose key starts with ``name + '.'``,
        with that prefix removed from the key. Empty if nothing matches.
    """
    if 'state_dict' in d:
        d = d['state_dict']
    # Match on the full "<name>." prefix: comparing only the first len(name)
    # characters would also accept siblings such as "<name>_ema.weight" and
    # emit them under a mangled key.
    prefix = name + '.'
    return {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}


class StyleCLIPMapper(nn.Module):
    """Latent mapper paired with an externally supplied image decoder.

    The mapper network is built from ``opts.mapper_type`` and, when
    ``opts.checkpoint_path`` is set, initialised from that checkpoint.
    The decoder is not created here; it must be attached with ``set_G``
    before ``forward`` is called.
    """

    def __init__(self, opts, run_id):
        """Build the mapper and optionally load its weights.

        Args:
            opts: Options object. This class reads ``opts.mapper_type`` and
                ``opts.checkpoint_path``; the whole object is also handed to
                the mapper constructor.
            run_id: Identifier stored on the instance as ``self.run_id``.
        """
        super().__init__()
        self.opts = opts
        # Define architecture
        self.mapper = self.set_mapper()
        self.run_id = run_id
        # The decoder is attached later through set_G(); start with an
        # explicit placeholder so its absence can be reported clearly.
        self.decoder = None

        # Pools decoder output to 256x256 when forward(resize=True).
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights()

    def set_mapper(self):
        """Instantiate the mapper selected by ``opts.mapper_type``.

        Returns:
            A ``latent_mappers.SingleMapper`` or ``latent_mappers.LevelsMapper``.

        Raises:
            ValueError: If ``opts.mapper_type`` is neither ``'SingleMapper'``
                nor ``'LevelsMapper'``.
        """
        mapper_type = self.opts.mapper_type
        if mapper_type == 'SingleMapper':
            return latent_mappers.SingleMapper(self.opts)
        if mapper_type == 'LevelsMapper':
            return latent_mappers.LevelsMapper(self.opts)
        raise ValueError('{} is not a valid mapper'.format(mapper_type))

    def load_weights(self):
        """Load mapper weights from ``opts.checkpoint_path`` if it is set.

        Only entries under the ``'mapper'`` prefix are used, and they must
        match the mapper exactly (``strict=True``).
        """
        if self.opts.checkpoint_path is None:
            return
        print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
        # NOTE: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
        self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True)

    def set_G(self, new_G):
        """Attach (or replace) the decoder used by ``forward``.

        Args:
            new_G: Object exposing ``synthesis(codes, noise_mode=...)``.
        """
        self.decoder = new_G

    def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None):
        """Map ``x`` to latent codes and decode them into images.

        Args:
            x: Mapper input, or the latent codes themselves when
                ``input_code`` is true.
            resize: If true, pool the decoded images with ``self.face_pool``.
            latent_mask: Optional iterable of indices along dimension 1 of
                the codes to overwrite before decoding.
            input_code: If true, skip the mapper and use ``x`` as the codes.
            randomize_noise: Accepted for interface compatibility; not used
                by the current decoder call.
            inject_latent: Optional codes used to replace the masked
                positions. When ``None``, masked positions are zeroed.
            return_latents: If true, return ``(images, result_latent)``.
            alpha: Optional blend weight; masked positions become
                ``alpha * inject_latent + (1 - alpha) * codes``.

        Returns:
            The decoded images, or ``(images, None)`` when ``return_latents``
            is true.

        Raises:
            RuntimeError: If no decoder has been attached via ``set_G``.
        """
        decoder = getattr(self, 'decoder', None)
        if decoder is None:
            raise RuntimeError(
                'StyleCLIPMapper has no decoder; call set_G() before forward().')

        codes = x if input_code else self.mapper(x)

        if latent_mask is not None:
            # Work on a copy: with input_code=True `codes` is the caller's own
            # tensor, and the masking below writes in place.
            codes = codes.clone()
            for i in latent_mask:
                if inject_latent is None:
                    codes[:, i] = 0
                elif alpha is None:
                    codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]

        images = decoder.synthesis(codes, noise_mode='const')
        # NOTE(review): the synthesis call yields no latents, so callers asking
        # for return_latents=True receive None here - confirm this is intended.
        result_latent = None

        if resize:
            images = self.face_pool(images)

        if return_latents:
            return images, result_latent
        return images