ethanNeuralImage committed on
Commit 7d75862 · 1 parent: 822dd00
mapper/__init__.py ADDED
File without changes
mapper/latent_mappers.py ADDED
@@ -0,0 +1,163 @@
+ import torch
+ from torch import nn
+ from torch.nn import Module
+
+ from models.stylegan2.model import EqualLinear, PixelNorm
+
+ from models.hyperstyle.hypernetworks.refinement_blocks import PARAMETERS as HYPERSTYLE_PARAMETERS
+
+ # Channel widths of the 26 StyleGAN2 style-space vectors of a 1024px generator.
+ STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128, 128] + [64, 64, 64] + [32, 32]
+
+
+ class Mapper(Module):
+     """A 4-layer MLP (PixelNorm + EqualLinear stack) acting on one latent vector."""
+
+     def __init__(self, opts, latent_dim=512):
+         super(Mapper, self).__init__()
+
+         self.opts = opts
+         layers = [PixelNorm()]
+
+         for i in range(4):
+             layers.append(
+                 EqualLinear(
+                     latent_dim, latent_dim, lr_mul=0.01, activation='fused_lrelu'
+                 )
+             )
+
+         self.mapping = nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.mapping(x)
+         return x
+
+
+ class SingleMapper(Module):
+     """Applies one shared Mapper to the whole latent."""
+
+     def __init__(self, opts):
+         super(SingleMapper, self).__init__()
+
+         self.opts = opts
+         self.mapping = Mapper(opts)
+
+     def forward(self, x):
+         out = self.mapping(x)
+         return out
+
+
+ class LevelsMapper(Module):
+     """Separate mappers for the coarse (0-3), medium (4-7) and fine (8+) W+ entries."""
+
+     def __init__(self, opts):
+         super(LevelsMapper, self).__init__()
+
+         self.opts = opts
+
+         if not opts.no_coarse_mapper:
+             self.coarse_mapping = Mapper(opts)
+         if not opts.no_medium_mapper:
+             self.medium_mapping = Mapper(opts)
+         if not opts.no_fine_mapper:
+             self.fine_mapping = Mapper(opts)
+
+     def forward(self, x):
+         x_coarse = x[:, :4, :]
+         x_medium = x[:, 4:8, :]
+         x_fine = x[:, 8:, :]
+
+         if not self.opts.no_coarse_mapper:
+             x_coarse = self.coarse_mapping(x_coarse)
+         else:
+             x_coarse = torch.zeros_like(x_coarse)
+         if not self.opts.no_medium_mapper:
+             x_medium = self.medium_mapping(x_medium)
+         else:
+             x_medium = torch.zeros_like(x_medium)
+         if not self.opts.no_fine_mapper:
+             x_fine = self.fine_mapping(x_fine)
+         else:
+             x_fine = torch.zeros_like(x_fine)
+
+         out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
+
+         return out
+
+
+ class FullStyleSpaceMapper(Module):
+     """One Mapper per style-space vector, including the toRGB layers."""
+
+     def __init__(self, opts):
+         super(FullStyleSpaceMapper, self).__init__()
+
+         self.opts = opts
+
+         for c, c_dim in enumerate(STYLESPACE_DIMENSIONS):
+             setattr(self, f"mapper_{c}", Mapper(opts, latent_dim=c_dim))
+
+     def forward(self, x):
+         out = []
+         for c, x_c in enumerate(x):
+             curr_mapper = getattr(self, f"mapper_{c}")
+             x_c_res = curr_mapper(x_c.view(x_c.shape[0], -1)).view(x_c.shape)
+             out.append(x_c_res)
+
+         return out
+
+
+ class WithoutToRGBStyleSpaceMapper(Module):
+     """Style-space mapper that skips the toRGB layers (indices 1, 4, 7, ...)."""
+
+     def __init__(self, opts):
+         super(WithoutToRGBStyleSpaceMapper, self).__init__()
+
+         self.opts = opts
+
+         torgb_indices = list(range(1, len(STYLESPACE_DIMENSIONS), 3))
+         self.STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in torgb_indices]
+
+         for c in self.STYLESPACE_INDICES_WITHOUT_TORGB:
+             setattr(self, f"mapper_{c}", Mapper(opts, latent_dim=STYLESPACE_DIMENSIONS[c]))
+
+     def forward(self, x):
+         out = []
+         for c in range(len(STYLESPACE_DIMENSIONS)):
+             x_c = x[c]
+             if c in self.STYLESPACE_INDICES_WITHOUT_TORGB:
+                 curr_mapper = getattr(self, f"mapper_{c}")
+                 x_c_res = curr_mapper(x_c.view(x_c.shape[0], -1)).view(x_c.shape)
+             else:
+                 # toRGB entries get a zero delta.
+                 x_c_res = torch.zeros_like(x_c)
+             out.append(x_c_res)
+
+         return out
+
+
+ class WeightDeltasMapper(Module):
+     """One Mapper per HyperStyle weight-delta layer listed in opts.layers_to_tune."""
+
+     def __init__(self, opts):
+         super(WeightDeltasMapper, self).__init__()
+
+         self.opts = opts
+         self.weight_deltas_indices = [int(idx) for idx in opts.layers_to_tune.split(',')]
+
+         for c in self.weight_deltas_indices:
+             _, _, latent_dim = HYPERSTYLE_PARAMETERS[c]
+             setattr(self, f"mapper_{c}", Mapper(opts, latent_dim=latent_dim))
+
+     def forward(self, x):
+         out = []
+         for c in range(len(STYLESPACE_DIMENSIONS)):
+             x_c = x[c]
+             if c in self.weight_deltas_indices:
+                 curr_mapper = getattr(self, f"mapper_{c}")
+                 x_c_res = curr_mapper(x_c.view(x_c.shape[0], -1)).view(x_c.shape)
+             else:
+                 # Untuned layers contribute no delta.
+                 x_c_res = None
+             out.append(x_c_res)
+
+         return out
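
For orientation, a minimal usage sketch of LevelsMapper follows. It is not part of the commit: the opts namespace and the 0.1 edit strength are illustrative assumptions (StyleCLIP-style mappers typically apply their output as a small scaled residual), and the real options object comes from the repo's argument parser.

from types import SimpleNamespace
import torch
from mapper.latent_mappers import LevelsMapper

# Hypothetical options; only the three flags LevelsMapper reads are set here.
opts = SimpleNamespace(no_coarse_mapper=False, no_medium_mapper=False, no_fine_mapper=True)
mapper = LevelsMapper(opts)

w_plus = torch.randn(2, 18, 512)  # batch of two W+ codes: 18 style vectors of width 512
delta = mapper(w_plus)            # same shape as w_plus; rows 8+ are zeros since the fine mapper is off
edited = w_plus + 0.1 * delta     # the delta is applied to the original code as a scaled offset
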
mapper/styleclip_mapper.py ADDED
@@ -0,0 +1,85 @@
+ import torch
+ from torch import nn
+ from mapper import latent_mappers
+ from models.stylegan2.model import Generator
+
+
+ def get_keys(d, name):
+     """Returns the sub-dict of checkpoint d whose keys start with 'name.', with the prefix stripped."""
+     if 'state_dict' in d:
+         d = d['state_dict']
+     d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
+     return d_filt
+
+
+ class StyleCLIPMapper(nn.Module):
+
+     def __init__(self, opts):
+         super(StyleCLIPMapper, self).__init__()
+         self.opts = opts
+         # Define architecture
+         self.mapper = self.set_mapper()
+         self.decoder = Generator(self.opts.stylegan_size, 512, 8)
+         self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
+         if self.opts.use_weight_delta_mapper:
+             self.delta_mapper = latent_mappers.WeightDeltasMapper(self.opts)
+         # Load weights if needed
+         self.load_weights()
+
+     def set_mapper(self):
+         if self.opts.work_in_stylespace:
+             mapper = latent_mappers.WithoutToRGBStyleSpaceMapper(self.opts)
+         elif self.opts.mapper_type == 'SingleMapper':
+             mapper = latent_mappers.SingleMapper(self.opts)
+         elif self.opts.mapper_type == 'LevelsMapper':
+             mapper = latent_mappers.LevelsMapper(self.opts)
+         else:
+             raise ValueError('{} is not a valid mapper'.format(self.opts.mapper_type))
+         return mapper
+
+     def load_weights(self):
+         if self.opts.checkpoint_path is not None:
+             print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
+             ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
+             self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True)
+             self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
+             if self.opts.use_weight_delta_mapper:
+                 self.delta_mapper.load_state_dict(get_keys(ckpt, 'delta_mapper'), strict=True)
+         else:
+             print('Loading decoder weights from pretrained!')
+             ckpt = torch.load(self.opts.stylegan_weights, map_location='cpu')
+             self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
+
+     def forward(self, x, weights_deltas=None, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
+                 inject_latent=None, return_latents=False, alpha=None):
+         if input_code:
+             codes = x
+         else:
+             codes = self.mapper(x)
+
+         if weights_deltas is not None and self.opts.use_weight_delta_mapper:
+             weights_deltas = self.delta_mapper(weights_deltas)
+
+         if latent_mask is not None:
+             for i in latent_mask:
+                 if inject_latent is not None:
+                     if alpha is not None:
+                         codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
+                     else:
+                         codes[:, i] = inject_latent[:, i]
+                 else:
+                     codes[:, i] = 0
+
+         input_is_latent = not input_code
+         images, result_latent = self.decoder([codes], weights_deltas=weights_deltas,
+                                              input_is_latent=input_is_latent,
+                                              randomize_noise=randomize_noise,
+                                              return_latents=return_latents)
+
+         if resize:
+             images = self.face_pool(images)
+
+         if return_latents:
+             return images, result_latent
+         else:
+             return images
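
As a quick check of the get_keys helper above, here is how it carves a composite checkpoint into per-submodule state dicts (the keys below are made up for illustration):

ckpt = {'state_dict': {'mapper.mapping.1.weight': 'w0', 'decoder.conv1.weight': 'w1'}}
print(get_keys(ckpt, 'mapper'))   # {'mapping.1.weight': 'w0'}
print(get_keys(ckpt, 'decoder'))  # {'conv1.weight': 'w1'}
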
models/__init__.py ADDED
File without changes