Spaces:
Runtime error
Runtime error
ethanNeuralImage
commited on
Commit
•
7d75862
1
Parent(s):
822dd00
mapper
Browse files- mapper/__init__.py +0 -0
- mapper/latent_mappers.py +156 -0
- mapper/styleclip_mapper.py +84 -0
- models/__init__.py +0 -0
mapper/__init__.py
ADDED
File without changes
|
mapper/latent_mappers.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import Module
|
4 |
+
|
5 |
+
from models.stylegan2.model import EqualLinear, PixelNorm
|
6 |
+
|
7 |
+
from models.hyperstyle.hypernetworks.refinement_blocks import PARAMETERS as HYPERSTYLE_PARAMETERS
|
8 |
+
|
9 |
+
STYLESPACE_DIMENSIONS = [512 for _ in range(15)] + [256, 256, 256] + [128, 128, 128] + [64, 64, 64] + [32, 32]
|
10 |
+
|
11 |
+
|
12 |
+
class Mapper(Module):
|
13 |
+
|
14 |
+
def __init__(self, opts, latent_dim=512):
|
15 |
+
super(Mapper, self).__init__()
|
16 |
+
|
17 |
+
self.opts = opts
|
18 |
+
layers = [PixelNorm()]
|
19 |
+
|
20 |
+
for i in range(4):
|
21 |
+
layers.append(
|
22 |
+
EqualLinear(
|
23 |
+
latent_dim, latent_dim, lr_mul=0.01, activation='fused_lrelu'
|
24 |
+
)
|
25 |
+
)
|
26 |
+
|
27 |
+
self.mapping = nn.Sequential(*layers)
|
28 |
+
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
x = self.mapping(x)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class SingleMapper(Module):
|
36 |
+
|
37 |
+
def __init__(self, opts):
|
38 |
+
super(SingleMapper, self).__init__()
|
39 |
+
|
40 |
+
self.opts = opts
|
41 |
+
|
42 |
+
self.mapping = Mapper(opts)
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
out = self.mapping(x)
|
46 |
+
return out
|
47 |
+
|
48 |
+
|
49 |
+
class LevelsMapper(Module):
|
50 |
+
|
51 |
+
def __init__(self, opts):
|
52 |
+
super(LevelsMapper, self).__init__()
|
53 |
+
|
54 |
+
self.opts = opts
|
55 |
+
|
56 |
+
if not opts.no_coarse_mapper:
|
57 |
+
self.course_mapping = Mapper(opts)
|
58 |
+
if not opts.no_medium_mapper:
|
59 |
+
self.medium_mapping = Mapper(opts)
|
60 |
+
if not opts.no_fine_mapper:
|
61 |
+
self.fine_mapping = Mapper(opts)
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
x_coarse = x[:, :4, :]
|
65 |
+
x_medium = x[:, 4:8, :]
|
66 |
+
x_fine = x[:, 8:, :]
|
67 |
+
|
68 |
+
if not self.opts.no_coarse_mapper:
|
69 |
+
x_coarse = self.course_mapping(x_coarse)
|
70 |
+
else:
|
71 |
+
x_coarse = torch.zeros_like(x_coarse)
|
72 |
+
if not self.opts.no_medium_mapper:
|
73 |
+
x_medium = self.medium_mapping(x_medium)
|
74 |
+
else:
|
75 |
+
x_medium = torch.zeros_like(x_medium)
|
76 |
+
if not self.opts.no_fine_mapper:
|
77 |
+
x_fine = self.fine_mapping(x_fine)
|
78 |
+
else:
|
79 |
+
x_fine = torch.zeros_like(x_fine)
|
80 |
+
|
81 |
+
|
82 |
+
out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
|
83 |
+
|
84 |
+
return out
|
85 |
+
|
86 |
+
class FullStyleSpaceMapper(Module):
|
87 |
+
|
88 |
+
def __init__(self, opts):
|
89 |
+
super(FullStyleSpaceMapper, self).__init__()
|
90 |
+
|
91 |
+
self.opts = opts
|
92 |
+
|
93 |
+
for c, c_dim in enumerate(STYLESPACE_DIMENSIONS):
|
94 |
+
setattr(self, f"mapper_{c}", Mapper(opts, latent_dim=c_dim))
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
out = []
|
98 |
+
for c, x_c in enumerate(x):
|
99 |
+
curr_mapper = getattr(self, f"mapper_{c}")
|
100 |
+
x_c_res = curr_mapper(x_c.view(x_c.shape[0], -1)).view(x_c.shape)
|
101 |
+
out.append(x_c_res)
|
102 |
+
|
103 |
+
return out
|
104 |
+
|
105 |
+
|
106 |
+
class WithoutToRGBStyleSpaceMapper(Module):
|
107 |
+
|
108 |
+
def __init__(self, opts):
|
109 |
+
super(WithoutToRGBStyleSpaceMapper, self).__init__()
|
110 |
+
|
111 |
+
self.opts = opts
|
112 |
+
|
113 |
+
indices_without_torgb = list(range(1, len(STYLESPACE_DIMENSIONS), 3))
|
114 |
+
self.STYLESPACE_INDICES_WITHOUT_TORGB = [i for i in range(len(STYLESPACE_DIMENSIONS)) if i not in indices_without_torgb]
|
115 |
+
|
116 |
+
for c in self.STYLESPACE_INDICES_WITHOUT_TORGB:
|
117 |
+
setattr(self, f"mapper_{c}", Mapper(opts, latent_dim=STYLESPACE_DIMENSIONS[c]))
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
out = []
|
121 |
+
for c in range(len(STYLESPACE_DIMENSIONS)):
|
122 |
+
x_c = x[c]
|
123 |
+
if c in self.STYLESPACE_INDICES_WITHOUT_TORGB:
|
124 |
+
curr_mapper = getattr(self, f"mapper_{c}")
|
125 |
+
x_c_res = curr_mapper(x_c.view(x_c.shape[0], -1)).view(x_c.shape)
|
126 |
+
else:
|
127 |
+
x_c_res = torch.zeros_like(x_c)
|
128 |
+
out.append(x_c_res)
|
129 |
+
|
130 |
+
return out
|
131 |
+
|
132 |
+
|
133 |
+
class WeightDeltasMapper(Module):
|
134 |
+
|
135 |
+
def __init__(self, opts):
|
136 |
+
super(WeightDeltasMapper, self).__init__()
|
137 |
+
|
138 |
+
self.opts = opts
|
139 |
+
self.weight_deltas_indicies = [int(l) for l in opts.layers_to_tune.split(',')]
|
140 |
+
|
141 |
+
for c in self.weight_deltas_indicies:
|
142 |
+
_, _, latent_dim = HYPERSTYLE_PARAMETERS[c]
|
143 |
+
setattr(self, f"mapper_{c}", Mapper(opts, latent_dim=latent_dim))
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
out = []
|
147 |
+
for c in range(len(STYLESPACE_DIMENSIONS)):
|
148 |
+
x_c = x[c]
|
149 |
+
if c in self.weight_deltas_indicies:
|
150 |
+
curr_mapper = getattr(self, f"mapper_{c}")
|
151 |
+
x_c_res = curr_mapper(x_c.view(x_c.shape[0], -1)).view(x_c.shape)
|
152 |
+
else:
|
153 |
+
x_c_res = None
|
154 |
+
out.append(x_c_res)
|
155 |
+
|
156 |
+
return out
|
mapper/styleclip_mapper.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from mapper import latent_mappers
|
4 |
+
from models.stylegan2.model import Generator
|
5 |
+
|
6 |
+
|
7 |
+
def get_keys(d, name):
|
8 |
+
if 'state_dict' in d:
|
9 |
+
d = d['state_dict']
|
10 |
+
d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
|
11 |
+
return d_filt
|
12 |
+
|
13 |
+
|
14 |
+
class StyleCLIPMapper(nn.Module):
|
15 |
+
|
16 |
+
def __init__(self, opts):
|
17 |
+
super(StyleCLIPMapper, self).__init__()
|
18 |
+
self.opts = opts
|
19 |
+
# Define architecture
|
20 |
+
self.mapper = self.set_mapper()
|
21 |
+
self.decoder = Generator(self.opts.stylegan_size, 512, 8)
|
22 |
+
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
23 |
+
if self.opts.use_weight_delta_mapper:
|
24 |
+
self.delta_mapper = latent_mappers.WeightDeltasMapper(self.opts)
|
25 |
+
# Load weights if needed
|
26 |
+
self.load_weights()
|
27 |
+
|
28 |
+
def set_mapper(self):
|
29 |
+
if self.opts.work_in_stylespace:
|
30 |
+
mapper = latent_mappers.WithoutToRGBStyleSpaceMapper(self.opts)
|
31 |
+
elif self.opts.mapper_type == 'SingleMapper':
|
32 |
+
mapper = latent_mappers.SingleMapper(self.opts)
|
33 |
+
elif self.opts.mapper_type == 'LevelsMapper':
|
34 |
+
mapper = latent_mappers.LevelsMapper(self.opts)
|
35 |
+
else:
|
36 |
+
raise Exception('{} is not a valid mapper'.format(self.opts.mapper_type))
|
37 |
+
return mapper
|
38 |
+
|
39 |
+
def load_weights(self):
|
40 |
+
if self.opts.checkpoint_path is not None:
|
41 |
+
print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
|
42 |
+
ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
|
43 |
+
self.mapper.load_state_dict(get_keys(ckpt, 'mapper'), strict=True)
|
44 |
+
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
|
45 |
+
if self.opts.use_weight_delta_mapper:
|
46 |
+
self.delta_mapper.load_state_dict(get_keys(ckpt, 'delta_mapper'), strict=True)
|
47 |
+
else:
|
48 |
+
print('Loading decoder weights from pretrained!')
|
49 |
+
ckpt = torch.load(self.opts.stylegan_weights)
|
50 |
+
self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
|
51 |
+
|
52 |
+
def forward(self, x, weights_deltas=None, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
|
53 |
+
inject_latent=None, return_latents=False, alpha=None):
|
54 |
+
if input_code:
|
55 |
+
codes = x
|
56 |
+
else:
|
57 |
+
codes = self.mapper(x)
|
58 |
+
|
59 |
+
if weights_deltas is not None and self.opts.use_weight_delta_mapper:
|
60 |
+
weights_deltas = self.weight_deltas_mapper(weights_deltas)
|
61 |
+
|
62 |
+
if latent_mask is not None:
|
63 |
+
for i in latent_mask:
|
64 |
+
if inject_latent is not None:
|
65 |
+
if alpha is not None:
|
66 |
+
codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
|
67 |
+
else:
|
68 |
+
codes[:, i] = inject_latent[:, i]
|
69 |
+
else:
|
70 |
+
codes[:, i] = 0
|
71 |
+
|
72 |
+
input_is_latent = not input_code
|
73 |
+
images, result_latent = self.decoder([codes], weights_deltas=weights_deltas,
|
74 |
+
input_is_latent=input_is_latent,
|
75 |
+
randomize_noise=randomize_noise,
|
76 |
+
return_latents=return_latents)
|
77 |
+
|
78 |
+
if resize:
|
79 |
+
images = self.face_pool(images)
|
80 |
+
|
81 |
+
if return_latents:
|
82 |
+
return images, result_latent
|
83 |
+
else:
|
84 |
+
return images
|
models/__init__.py
ADDED
File without changes
|