PKUWilliamYang committed on
Commit
ac4ce84
1 Parent(s): d5073e2

Upload 50 files

Files changed (50)
  1. models/__init__.py +0 -0
  2. models/bisenet/LICENSE +21 -0
  3. models/bisenet/README.md +68 -0
  4. models/bisenet/model.py +283 -0
  5. models/bisenet/resnet.py +109 -0
  6. models/encoders/__init__.py +0 -0
  7. models/encoders/helpers.py +119 -0
  8. models/encoders/model_irse.py +84 -0
  9. models/encoders/psp_encoders.py +357 -0
  10. models/mtcnn/__init__.py +0 -0
  11. models/mtcnn/mtcnn.py +156 -0
  12. models/mtcnn/mtcnn_pytorch/__init__.py +0 -0
  13. models/mtcnn/mtcnn_pytorch/src/__init__.py +2 -0
  14. models/mtcnn/mtcnn_pytorch/src/align_trans.py +304 -0
  15. models/mtcnn/mtcnn_pytorch/src/box_utils.py +238 -0
  16. models/mtcnn/mtcnn_pytorch/src/detector.py +126 -0
  17. models/mtcnn/mtcnn_pytorch/src/first_stage.py +101 -0
  18. models/mtcnn/mtcnn_pytorch/src/get_nets.py +171 -0
  19. models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py +350 -0
  20. models/mtcnn/mtcnn_pytorch/src/visualization_utils.py +31 -0
  21. models/mtcnn/mtcnn_pytorch/src/weights/onet.npy +3 -0
  22. models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy +3 -0
  23. models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy +3 -0
  24. models/psp.py +147 -0
  25. models/stylegan2/__init__.py +0 -0
  26. models/stylegan2/lpips/__init__.py +161 -0
  27. models/stylegan2/lpips/base_model.py +58 -0
  28. models/stylegan2/lpips/dist_model.py +284 -0
  29. models/stylegan2/lpips/networks_basic.py +187 -0
  30. models/stylegan2/lpips/pretrained_networks.py +181 -0
  31. models/stylegan2/lpips/weights/v0.0/alex.pth +3 -0
  32. models/stylegan2/lpips/weights/v0.0/squeeze.pth +3 -0
  33. models/stylegan2/lpips/weights/v0.0/vgg.pth +3 -0
  34. models/stylegan2/lpips/weights/v0.1/alex.pth +3 -0
  35. models/stylegan2/lpips/weights/v0.1/squeeze.pth +3 -0
  36. models/stylegan2/lpips/weights/v0.1/vgg.pth +3 -0
  37. models/stylegan2/model.py +768 -0
  38. models/stylegan2/op/__init__.py +2 -0
  39. models/stylegan2/op/conv2d_gradfix.py +227 -0
  40. models/stylegan2/op/fused_act.py +34 -0
  41. models/stylegan2/op/readme.md +12 -0
  42. models/stylegan2/op/upfirdn2d.py +61 -0
  43. models/stylegan2/op_ori/__init__.py +2 -0
  44. models/stylegan2/op_ori/fused_act.py +85 -0
  45. models/stylegan2/op_ori/fused_bias_act.cpp +21 -0
  46. models/stylegan2/op_ori/fused_bias_act_kernel.cu +99 -0
  47. models/stylegan2/op_ori/upfirdn2d.cpp +23 -0
  48. models/stylegan2/op_ori/upfirdn2d.py +184 -0
  49. models/stylegan2/op_ori/upfirdn2d_kernel.cu +272 -0
  50. models/stylegan2/simple_augment.py +478 -0
models/__init__.py ADDED
File without changes
models/bisenet/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2019 zll
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
models/bisenet/README.md ADDED
@@ -0,0 +1,68 @@
+ # face-parsing.PyTorch
+
+ <p align="center">
+     <a href="https://github.com/zllrunning/face-parsing.PyTorch">
+         <img class="page-image" src="https://github.com/zllrunning/face-parsing.PyTorch/blob/master/6.jpg" >
+     </a>
+ </p>
+
+ ### Contents
+ - [Training](#training)
+ - [Demo](#demo)
+ - [References](#references)
+
+ ## Training
+
+ 1. Prepare the training data:
+    - download the [CelebAMask-HQ dataset](https://github.com/switchablenorms/CelebAMask-HQ)
+
+    - change the file paths in `prepropess_data.py` and run
+ ```Shell
+ python prepropess_data.py
+ ```
+
+ 2. Train the model on the CelebAMask-HQ dataset:
+    Just run the training script:
+ ```Shell
+ $ CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
+ ```
+
+ If you do not wish to train the model, you can download [our pre-trained model](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812) and save it in `res/cp`.
+
+
+ ## Demo
+ 1. Evaluate the trained model:
+ ```Shell
+ # evaluate using GPU
+ python test.py
+ ```
+
+ ## Face makeup using parsing maps
+ [**face-makeup.PyTorch**](https://github.com/zllrunning/face-makeup.PyTorch)
+ <table>
+
+ <tr>
+     <th>&nbsp;</th>
+     <th>Hair</th>
+     <th>Lip</th>
+ </tr>
+
+ <!-- Row 1: Original Input -->
+ <tr>
+     <td><em>Original Input</em></td>
+     <td><img src="makeup/116_ori.png" height="256" width="256" alt="Original Input"></td>
+     <td><img src="makeup/116_lip_ori.png" height="256" width="256" alt="Original Input"></td>
+ </tr>
+
+ <!-- Row 2: Color -->
+ <tr>
+     <td>Color</td>
+     <td><img src="makeup/116_1.png" height="256" width="256" alt="Color"></td>
+     <td><img src="makeup/116_3.png" height="256" width="256" alt="Color"></td>
+ </tr>
+
+ </table>
+
+
+ ## References
+ - [BiSeNet](https://github.com/CoinCheung/BiSeNet)
models/bisenet/model.py ADDED
@@ -0,0 +1,283 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+
+ from models.bisenet.resnet import Resnet18
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
+
+
+ class ConvBNReLU(nn.Module):
+     def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
+         super(ConvBNReLU, self).__init__()
+         self.conv = nn.Conv2d(in_chan,
+                               out_chan,
+                               kernel_size=ks,
+                               stride=stride,
+                               padding=padding,
+                               bias=False)
+         self.bn = nn.BatchNorm2d(out_chan)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = F.relu(self.bn(x))
+         return x
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+ class BiSeNetOutput(nn.Module):
+     def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
+         super(BiSeNetOutput, self).__init__()
+         self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
+         self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.conv_out(x)
+         return x
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class AttentionRefinementModule(nn.Module):
+     def __init__(self, in_chan, out_chan, *args, **kwargs):
+         super(AttentionRefinementModule, self).__init__()
+         self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
+         self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
+         self.bn_atten = nn.BatchNorm2d(out_chan)
+         self.sigmoid_atten = nn.Sigmoid()
+         self.init_weight()
+
+     def forward(self, x):
+         feat = self.conv(x)
+         atten = F.avg_pool2d(feat, feat.size()[2:])
+         atten = self.conv_atten(atten)
+         atten = self.bn_atten(atten)
+         atten = self.sigmoid_atten(atten)
+         out = torch.mul(feat, atten)
+         return out
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+
+ class ContextPath(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super(ContextPath, self).__init__()
+         self.resnet = Resnet18()
+         self.arm16 = AttentionRefinementModule(256, 128)
+         self.arm32 = AttentionRefinementModule(512, 128)
+         self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+         self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+         self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+
+         self.init_weight()
+
+     def forward(self, x):
+         H0, W0 = x.size()[2:]
+         feat8, feat16, feat32 = self.resnet(x)
+         H8, W8 = feat8.size()[2:]
+         H16, W16 = feat16.size()[2:]
+         H32, W32 = feat32.size()[2:]
+
+         avg = F.avg_pool2d(feat32, feat32.size()[2:])
+         avg = self.conv_avg(avg)
+         avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
+
+         feat32_arm = self.arm32(feat32)
+         feat32_sum = feat32_arm + avg_up
+         feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
+         feat32_up = self.conv_head32(feat32_up)
+
+         feat16_arm = self.arm16(feat16)
+         feat16_sum = feat16_arm + feat32_up
+         feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
+         feat16_up = self.conv_head16(feat16_up)
+
+         return feat8, feat16_up, feat32_up  # x8, x8, x16
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ ### This is not used, since I replace this with the resnet feature with the same size
+ class SpatialPath(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super(SpatialPath, self).__init__()
+         self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
+         self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
+         self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
+         self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
+         self.init_weight()
+
+     def forward(self, x):
+         feat = self.conv1(x)
+         feat = self.conv2(feat)
+         feat = self.conv3(feat)
+         feat = self.conv_out(feat)
+         return feat
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class FeatureFusionModule(nn.Module):
+     def __init__(self, in_chan, out_chan, *args, **kwargs):
+         super(FeatureFusionModule, self).__init__()
+         self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+         self.conv1 = nn.Conv2d(out_chan,
+                                out_chan//4,
+                                kernel_size=1,
+                                stride=1,
+                                padding=0,
+                                bias=False)
+         self.conv2 = nn.Conv2d(out_chan//4,
+                                out_chan,
+                                kernel_size=1,
+                                stride=1,
+                                padding=0,
+                                bias=False)
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+         self.init_weight()
+
+     def forward(self, fsp, fcp):
+         fcat = torch.cat([fsp, fcp], dim=1)
+         feat = self.convblk(fcat)
+         atten = F.avg_pool2d(feat, feat.size()[2:])
+         atten = self.conv1(atten)
+         atten = self.relu(atten)
+         atten = self.conv2(atten)
+         atten = self.sigmoid(atten)
+         feat_atten = torch.mul(feat, atten)
+         feat_out = feat_atten + feat
+         return feat_out
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class BiSeNet(nn.Module):
+     def __init__(self, n_classes, *args, **kwargs):
+         super(BiSeNet, self).__init__()
+         self.cp = ContextPath()
+         ## here self.sp is deleted
+         self.ffm = FeatureFusionModule(256, 256)
+         self.conv_out = BiSeNetOutput(256, 256, n_classes)
+         self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
+         self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
+         self.init_weight()
+
+     def forward(self, x):
+         H, W = x.size()[2:]
+         feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
+         feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
+         feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+         feat_out = self.conv_out(feat_fuse)
+         feat_out16 = self.conv_out16(feat_cp8)
+         feat_out32 = self.conv_out32(feat_cp16)
+
+         feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
+         feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
+         feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
+         return feat_out, feat_out16, feat_out32
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if not ly.bias is None: nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
+         for name, child in self.named_children():
+             child_wd_params, child_nowd_params = child.get_params()
+             if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
+                 lr_mul_wd_params += child_wd_params
+                 lr_mul_nowd_params += child_nowd_params
+             else:
+                 wd_params += child_wd_params
+                 nowd_params += child_nowd_params
+         return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
+
+
+ if __name__ == "__main__":
+     net = BiSeNet(19)
+     net.cuda()
+     net.eval()
+     in_ten = torch.randn(16, 3, 640, 480).cuda()
+     out, out16, out32 = net(in_ten)
+     print(out.shape)
+
+     net.get_params()
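A minimal inference sketch for the BiSeNet parsing network defined above. The checkpoint filename, the 512x512 input size, and the ImageNet normalization are assumptions (they follow the usual face-parsing.PyTorch setup) rather than part of this commit:

```python
import torch
from PIL import Image
from torchvision import transforms

from models.bisenet.model import BiSeNet

# Hypothetical checkpoint path; use the pre-trained weights linked in the README above.
net = BiSeNet(n_classes=19)
net.load_state_dict(torch.load('res/cp/79999_iter.pth', map_location='cpu'))
net.eval()

preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

img = Image.open('face.jpg').convert('RGB')   # placeholder input photo
x = preprocess(img).unsqueeze(0)              # [1, 3, 512, 512]
with torch.no_grad():
    out = net(x)[0]                           # main head, [1, 19, 512, 512]
parsing = out.argmax(1).squeeze(0)            # per-pixel class index map
print(parsing.shape, parsing.unique())
```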
models/bisenet/resnet.py ADDED
@@ -0,0 +1,109 @@
+ #!/usr/bin/python
+ # -*- encoding: utf-8 -*-
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.model_zoo as modelzoo
+
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
+
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     def __init__(self, in_chan, out_chan, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(in_chan, out_chan, stride)
+         self.bn1 = nn.BatchNorm2d(out_chan)
+         self.conv2 = conv3x3(out_chan, out_chan)
+         self.bn2 = nn.BatchNorm2d(out_chan)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = None
+         if in_chan != out_chan or stride != 1:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_chan, out_chan,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(out_chan),
+             )
+
+     def forward(self, x):
+         residual = self.conv1(x)
+         residual = F.relu(self.bn1(residual))
+         residual = self.conv2(residual)
+         residual = self.bn2(residual)
+
+         shortcut = x
+         if self.downsample is not None:
+             shortcut = self.downsample(x)
+
+         out = shortcut + residual
+         out = self.relu(out)
+         return out
+
+
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+     layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+     for i in range(bnum-1):
+         layers.append(BasicBlock(out_chan, out_chan, stride=1))
+     return nn.Sequential(*layers)
+
+
+ class Resnet18(nn.Module):
+     def __init__(self):
+         super(Resnet18, self).__init__()
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+         self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+         self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+         self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = F.relu(self.bn1(x))
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         feat8 = self.layer2(x)  # 1/8
+         feat16 = self.layer3(feat8)  # 1/16
+         feat32 = self.layer4(feat16)  # 1/32
+         return feat8, feat16, feat32
+
+     def init_weight(self):
+         state_dict = modelzoo.load_url(resnet18_url)
+         self_state_dict = self.state_dict()
+         for k, v in state_dict.items():
+             if 'fc' in k: continue
+             self_state_dict.update({k: v})
+         self.load_state_dict(self_state_dict)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if not module.bias is None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ if __name__ == "__main__":
+     net = Resnet18()
+     x = torch.randn(16, 3, 224, 224)
+     out = net(x)
+     print(out[0].size())
+     print(out[1].size())
+     print(out[2].size())
+     net.get_params()
models/encoders/__init__.py ADDED
File without changes
models/encoders/helpers.py ADDED
@@ -0,0 +1,119 @@
+ from collections import namedtuple
+ import torch
+ from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
+
+ """
+ ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+ """
+
+
+ class Flatten(Module):
+     def forward(self, input):
+         return input.view(input.size(0), -1)
+
+
+ def l2_norm(input, axis=1):
+     norm = torch.norm(input, 2, axis, True)
+     output = torch.div(input, norm)
+     return output
+
+
+ class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
+     """ A named tuple describing a ResNet block. """
+
+
+ def get_block(in_channel, depth, num_units, stride=2):
+     return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
+
+
+ def get_blocks(num_layers):
+     if num_layers == 50:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=4),
+             get_block(in_channel=128, depth=256, num_units=14),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     elif num_layers == 100:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=13),
+             get_block(in_channel=128, depth=256, num_units=30),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     elif num_layers == 152:
+         blocks = [
+             get_block(in_channel=64, depth=64, num_units=3),
+             get_block(in_channel=64, depth=128, num_units=8),
+             get_block(in_channel=128, depth=256, num_units=36),
+             get_block(in_channel=256, depth=512, num_units=3)
+         ]
+     else:
+         raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
+     return blocks
+
+
+ class SEModule(Module):
+     def __init__(self, channels, reduction):
+         super(SEModule, self).__init__()
+         self.avg_pool = AdaptiveAvgPool2d(1)
+         self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
+         self.relu = ReLU(inplace=True)
+         self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
+         self.sigmoid = Sigmoid()
+
+     def forward(self, x):
+         module_input = x
+         x = self.avg_pool(x)
+         x = self.fc1(x)
+         x = self.relu(x)
+         x = self.fc2(x)
+         x = self.sigmoid(x)
+         return module_input * x
+
+
+ class bottleneck_IR(Module):
+     def __init__(self, in_channel, depth, stride):
+         super(bottleneck_IR, self).__init__()
+         if in_channel == depth:
+             self.shortcut_layer = MaxPool2d(1, stride)
+         else:
+             self.shortcut_layer = Sequential(
+                 Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+                 BatchNorm2d(depth)
+             )
+         self.res_layer = Sequential(
+             BatchNorm2d(in_channel),
+             Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
+             Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
+         )
+
+     def forward(self, x):
+         shortcut = self.shortcut_layer(x)
+         res = self.res_layer(x)
+         return res + shortcut
+
+
+ class bottleneck_IR_SE(Module):
+     def __init__(self, in_channel, depth, stride):
+         super(bottleneck_IR_SE, self).__init__()
+         if in_channel == depth:
+             self.shortcut_layer = MaxPool2d(1, stride)
+         else:
+             self.shortcut_layer = Sequential(
+                 Conv2d(in_channel, depth, (1, 1), stride, bias=False),
+                 BatchNorm2d(depth)
+             )
+         self.res_layer = Sequential(
+             BatchNorm2d(in_channel),
+             Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
+             PReLU(depth),
+             Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
+             BatchNorm2d(depth),
+             SEModule(depth, 16)
+         )
+
+     def forward(self, x):
+         shortcut = self.shortcut_layer(x)
+         res = self.res_layer(x)
+         return res + shortcut
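As a quick illustration (not part of the commit), the block configurations above can be inspected directly: `get_blocks` returns four stages of `(in_channel, depth, stride)` bottleneck configs that `model_irse.py` and `psp_encoders.py` expand into residual units:

```python
from models.encoders.helpers import get_blocks, bottleneck_IR_SE

# IR-50 layout: 3 + 4 + 14 + 3 = 24 bottleneck units across four stages.
for stage_idx, stage in enumerate(get_blocks(50)):
    first = stage[0]
    print(f'stage {stage_idx}: {len(stage)} units, '
          f'{first.in_channel}->{first.depth} channels, first stride {first.stride}')

# Each config tuple maps one-to-one onto a residual unit.
unit = bottleneck_IR_SE(first.in_channel, first.depth, first.stride)
```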
models/encoders/model_irse.py ADDED
@@ -0,0 +1,84 @@
+ from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
+
+ """
+ Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
+ """
+
+
+ class Backbone(Module):
+     def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
+         super(Backbone, self).__init__()
+         assert input_size in [112, 224], "input_size should be 112 or 224"
+         assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
+         assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
+         blocks = get_blocks(num_layers)
+         if mode == 'ir':
+             unit_module = bottleneck_IR
+         elif mode == 'ir_se':
+             unit_module = bottleneck_IR_SE
+         self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+                                       BatchNorm2d(64),
+                                       PReLU(64))
+         if input_size == 112:
+             self.output_layer = Sequential(BatchNorm2d(512),
+                                            Dropout(drop_ratio),
+                                            Flatten(),
+                                            Linear(512 * 7 * 7, 512),
+                                            BatchNorm1d(512, affine=affine))
+         else:
+             self.output_layer = Sequential(BatchNorm2d(512),
+                                            Dropout(drop_ratio),
+                                            Flatten(),
+                                            Linear(512 * 14 * 14, 512),
+                                            BatchNorm1d(512, affine=affine))
+
+         modules = []
+         for block in blocks:
+             for bottleneck in block:
+                 modules.append(unit_module(bottleneck.in_channel,
+                                            bottleneck.depth,
+                                            bottleneck.stride))
+         self.body = Sequential(*modules)
+
+     def forward(self, x):
+         x = self.input_layer(x)
+         x = self.body(x)
+         x = self.output_layer(x)
+         return l2_norm(x)
+
+
+ def IR_50(input_size):
+     """Constructs a ir-50 model."""
+     model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_101(input_size):
+     """Constructs a ir-101 model."""
+     model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_152(input_size):
+     """Constructs a ir-152 model."""
+     model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_SE_50(input_size):
+     """Constructs a ir_se-50 model."""
+     model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_SE_101(input_size):
+     """Constructs a ir_se-101 model."""
+     model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
+     return model
+
+
+ def IR_SE_152(input_size):
+     """Constructs a ir_se-152 model."""
+     model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
+     return model
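A small shape check for the IR-SE backbone above (randomly initialized here; in practice these backbones are usually loaded with pretrained ArcFace weights):

```python
import torch
from models.encoders.model_irse import IR_SE_50

# 112x112 aligned face crops in, L2-normalized 512-D identity embeddings out.
net = IR_SE_50(112).eval()
with torch.no_grad():
    emb = net(torch.randn(2, 3, 112, 112))
print(emb.shape)          # torch.Size([2, 512])
print(emb.norm(dim=1))    # each row has unit norm because of l2_norm()
```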
models/encoders/psp_encoders.py ADDED
@@ -0,0 +1,357 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module
+
+ from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE
+ from models.stylegan2.model import EqualLinear
+
+
+ class GradualStyleBlock(Module):
+     def __init__(self, in_c, out_c, spatial, max_pooling=False):
+         super(GradualStyleBlock, self).__init__()
+         self.out_c = out_c
+         self.spatial = spatial
+         self.max_pooling = max_pooling
+         num_pools = int(np.log2(spatial))
+         modules = []
+         modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
+                     nn.LeakyReLU()]
+         for i in range(num_pools - 1):
+             modules += [
+                 Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
+                 nn.LeakyReLU()
+             ]
+         self.convs = nn.Sequential(*modules)
+         self.linear = EqualLinear(out_c, out_c, lr_mul=1)
+
+     def forward(self, x):
+         x = self.convs(x)
+         # To make E accept more general H*W images, we add global average pooling to
+         # resize all features to 1*1*512 before mapping to latent codes
+         if self.max_pooling:
+             x = F.adaptive_max_pool2d(x, 1)  ##### modified
+         else:
+             x = F.adaptive_avg_pool2d(x, 1)  ##### modified
+         x = x.view(-1, self.out_c)
+         x = self.linear(x)
+         return x
+
+ class AdaptiveInstanceNorm(nn.Module):
+     def __init__(self, fin, style_dim=512):
+         super().__init__()
+
+         self.norm = nn.InstanceNorm2d(fin, affine=False)
+         self.style = nn.Linear(style_dim, fin * 2)
+
+         self.style.bias.data[:fin] = 1
+         self.style.bias.data[fin:] = 0
+
+     def forward(self, input, style):
+         style = self.style(style).unsqueeze(2).unsqueeze(3)
+         gamma, beta = style.chunk(2, 1)
+         out = self.norm(input)
+         out = gamma * out + beta
+         return out
+
+
+ class FusionLayer(Module):  ##### modified
+     def __init__(self, inchannel, outchannel, use_skip_torgb=True, use_att=0):
+         super(FusionLayer, self).__init__()
+
+         self.transform = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=1, padding=1),
+                                        nn.LeakyReLU())
+         self.fusion_out = nn.Conv2d(outchannel*2, outchannel, kernel_size=3, stride=1, padding=1)
+         self.fusion_out.weight.data *= 0.01
+         self.fusion_out.weight[:,0:outchannel,1,1].data += torch.eye(outchannel)
+
+         self.use_skip_torgb = use_skip_torgb
+         if use_skip_torgb:
+             self.fusion_skip = nn.Conv2d(3+outchannel, 3, kernel_size=3, stride=1, padding=1)
+             self.fusion_skip.weight.data *= 0.01
+             self.fusion_skip.weight[:,0:3,1,1].data += torch.eye(3)
+
+         self.use_att = use_att
+         if use_att:
+             modules = []
+             modules.append(nn.Linear(512, outchannel))
+             for _ in range(use_att):
+                 modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
+                 modules.append(nn.Linear(outchannel, outchannel))
+             modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
+             self.linear = Sequential(*modules)
+             self.norm = AdaptiveInstanceNorm(outchannel*2, outchannel)
+             self.conv = nn.Conv2d(outchannel*2, 1, 3, 1, 1, bias=True)
+
+     def forward(self, feat, out, skip, editing_w=None):
+         x = self.transform(feat)
+         # similar to VToonify, use editing vector as condition
+         # fuse encoder feature and decoder feature with a predicted attention mask m_E
+         # if self.use_att = False, just fuse them with a simple conv layer
+         if self.use_att and editing_w is not None:
+             label = self.linear(editing_w)
+             m_E = (F.relu(self.conv(self.norm(torch.cat([out, abs(out-x)], dim=1), label)))).tanh()
+             x = x * m_E
+         out = self.fusion_out(torch.cat((out, x), dim=1))
+         if self.use_skip_torgb:
+             skip = self.fusion_skip(torch.cat((skip, x), dim=1))
+         return out, skip
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, dim):
+         super(ResnetBlock, self).__init__()
+
+         self.conv_block = nn.Sequential(Conv2d(dim, dim, 3, 1, 1),
+                                         nn.LeakyReLU(),
+                                         Conv2d(dim, dim, 3, 1, 1))
+         self.relu = nn.LeakyReLU()
+
+     def forward(self, x):
+         out = x + self.conv_block(x)
+         return self.relu(out)
+
+ # trainable light-weight translation network T
+ # for sketch/mask-to-face translation,
+ # we add a trainable T to map y to an intermediate domain where E can more easily extract features.
+ class ResnetGenerator(nn.Module):
+     def __init__(self, in_channel=19, res_num=2):
+         super(ResnetGenerator, self).__init__()
+
+         modules = []
+         modules.append(Conv2d(in_channel, 16, 3, 2, 1))
+         modules.append(nn.LeakyReLU())
+         modules.append(Conv2d(16, 16, 3, 2, 1))
+         modules.append(nn.LeakyReLU())
+         for _ in range(res_num):
+             modules.append(ResnetBlock(16))
+         for _ in range(2):
+             modules.append(nn.ConvTranspose2d(16, 16, 3, 2, 1, output_padding=1))
+             modules.append(nn.LeakyReLU())
+         modules.append(Conv2d(16, 64, 3, 1, 1, bias=False))
+         modules.append(BatchNorm2d(64))
+         modules.append(PReLU(64))
+         self.model = Sequential(*modules)
+
+     def forward(self, input):
+         return self.model(input)
+
+ class GradualStyleEncoder(Module):
+     def __init__(self, num_layers, mode='ir', opts=None):
+         super(GradualStyleEncoder, self).__init__()
+         assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
+         assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+         blocks = get_blocks(num_layers)
+         if mode == 'ir':
+             unit_module = bottleneck_IR
+         elif mode == 'ir_se':
+             unit_module = bottleneck_IR_SE
+
+         # for sketch/mask-to-face translation, add a new network T
+         if opts.input_nc != 3:
+             self.input_label_layer = ResnetGenerator(opts.input_nc, opts.res_num)
+
+         self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
+                                       BatchNorm2d(64),
+                                       PReLU(64))
+         modules = []
+         for block in blocks:
+             for bottleneck in block:
+                 modules.append(unit_module(bottleneck.in_channel,
+                                            bottleneck.depth,
+                                            bottleneck.stride))
+         self.body = Sequential(*modules)
+
+         self.styles = nn.ModuleList()
+         self.style_count = opts.n_styles
+         self.coarse_ind = 3
+         self.middle_ind = 7
+         for i in range(self.style_count):
+             if i < self.coarse_ind:
+                 style = GradualStyleBlock(512, 512, 16, 'max_pooling' in opts and opts.max_pooling)
+             elif i < self.middle_ind:
+                 style = GradualStyleBlock(512, 512, 32, 'max_pooling' in opts and opts.max_pooling)
+             else:
+                 style = GradualStyleBlock(512, 512, 64, 'max_pooling' in opts and opts.max_pooling)
+             self.styles.append(style)
+         self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
+         self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
+
+         # we concatenate pSp features in the middle layers and
+         # add a convolution layer to map the concatenated features to the first-layer input feature f of G.
+         self.featlayer = nn.Conv2d(768, 512, kernel_size=1, stride=1, padding=0)  ##### modified
+         self.skiplayer = nn.Conv2d(768, 3, kernel_size=1, stride=1, padding=0)  ##### modified
+
+         # skip connection
+         if 'use_skip' in opts and opts.use_skip:  ##### modified
+             self.fusion = nn.ModuleList()
+             channels = [[256,512], [256,512], [256,512], [256,512], [128,512], [64,256], [64,128]]
+             # opts.skip_max_layer: how many layers are skipped to the decoder
+             for inc, outc in channels[:max(1, min(7, opts.skip_max_layer))]:  # from 4 to 256
+                 self.fusion.append(FusionLayer(inc, outc, opts.use_skip_torgb, opts.use_att))
+
+     def _upsample_add(self, x, y):
+         '''Upsample and add two feature maps.
+         Args:
+           x: (Variable) top feature map to be upsampled.
+           y: (Variable) lateral feature map.
+         Returns:
+           (Variable) added feature map.
+         Note in PyTorch, when input size is odd, the upsampled feature map
+         with `F.upsample(..., scale_factor=2, mode='nearest')`
+         maybe not equal to the lateral feature map size.
+         e.g.
+         original input size: [N,_,15,15] ->
+         conv2d feature map size: [N,_,8,8] ->
+         upsampled feature map size: [N,_,16,16]
+         So we choose bilinear upsample which supports arbitrary output sizes.
+         '''
+         _, _, H, W = y.size()
+         return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
+
+     # return_feat: return f
+     # return_full: return f and the skipped encoder features
+     # return [out, feats]
+     # out is the style latent code w+
+     # feats[0] is f for the 1st conv layer, feats[1] is f for the 1st torgb layer
+     # feats[2-8] is the skipped encoder features
+     def forward(self, x, return_feat=False, return_full=False):  ##### modified
+         if x.shape[1] != 3:
+             x = self.input_label_layer(x)
+         else:
+             x = self.input_layer(x)
+         c256 = x  ##### modified
+
+         latents = []
+         modulelist = list(self.body._modules.values())
+         for i, l in enumerate(modulelist):
+             x = l(x)
+             if i == 2:  ##### modified
+                 c128 = x
+             elif i == 6:
+                 c1 = x
+             elif i == 10:  ##### modified
+                 c21 = x  ##### modified
+             elif i == 15:  ##### modified
+                 c22 = x  ##### modified
+             elif i == 20:
+                 c2 = x
+             elif i == 23:
+                 c3 = x
+
+         for j in range(self.coarse_ind):
+             latents.append(self.styles[j](c3))
+
+         p2 = self._upsample_add(c3, self.latlayer1(c2))
+         for j in range(self.coarse_ind, self.middle_ind):
+             latents.append(self.styles[j](p2))
+
+         p1 = self._upsample_add(p2, self.latlayer2(c1))
+         for j in range(self.middle_ind, self.style_count):
+             latents.append(self.styles[j](p1))
+
+         out = torch.stack(latents, dim=1)
+
+         if not return_feat:
+             return out
+
+         feats = [self.featlayer(torch.cat((c21, c22, c2), dim=1)), self.skiplayer(torch.cat((c21, c22, c2), dim=1))]
+
+         if return_full:  ##### modified
+             feats += [c2, c2, c22, c21, c1, c128, c256]
+
+         return out, feats
+
+
+     # only compute the first-layer feature f
+     # E_F in the paper
+     def get_feat(self, x):  ##### modified
+         # for sketch/mask-to-face translation
+         # use a trainable light-weight translation network T
+         if x.shape[1] != 3:
+             x = self.input_label_layer(x)
+         else:
+             x = self.input_layer(x)
+
+         latents = []
+         modulelist = list(self.body._modules.values())
+         for i, l in enumerate(modulelist):
+             x = l(x)
+             if i == 10:  ##### modified
+                 c21 = x  ##### modified
+             elif i == 15:  ##### modified
+                 c22 = x  ##### modified
+             elif i == 20:
+                 c2 = x
+                 break
+         return self.featlayer(torch.cat((c21, c22, c2), dim=1))
+
+ class BackboneEncoderUsingLastLayerIntoW(Module):
+     def __init__(self, num_layers, mode='ir', opts=None):
+         super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
+         print('Using BackboneEncoderUsingLastLayerIntoW')
+         assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
+         assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+         blocks = get_blocks(num_layers)
+         if mode == 'ir':
+             unit_module = bottleneck_IR
+         elif mode == 'ir_se':
+             unit_module = bottleneck_IR_SE
+         self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
+                                       BatchNorm2d(64),
+                                       PReLU(64))
+         self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
+         self.linear = EqualLinear(512, 512, lr_mul=1)
+         modules = []
+         for block in blocks:
+             for bottleneck in block:
+                 modules.append(unit_module(bottleneck.in_channel,
+                                            bottleneck.depth,
+                                            bottleneck.stride))
+         self.body = Sequential(*modules)
+
+     def forward(self, x):
+         x = self.input_layer(x)
+         x = self.body(x)
+         x = self.output_pool(x)
+         x = x.view(-1, 512)
+         x = self.linear(x)
+         return x
+
+
+ class BackboneEncoderUsingLastLayerIntoWPlus(Module):
+     def __init__(self, num_layers, mode='ir', opts=None):
+         super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__()
+         print('Using BackboneEncoderUsingLastLayerIntoWPlus')
+         assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
+         assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
+         blocks = get_blocks(num_layers)
+         if mode == 'ir':
+             unit_module = bottleneck_IR
+         elif mode == 'ir_se':
+             unit_module = bottleneck_IR_SE
+         self.n_styles = opts.n_styles
+         self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
+                                       BatchNorm2d(64),
+                                       PReLU(64))
+         self.output_layer_2 = Sequential(BatchNorm2d(512),
+                                          torch.nn.AdaptiveAvgPool2d((7, 7)),
+                                          Flatten(),
+                                          Linear(512 * 7 * 7, 512))
+         self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1)
+         modules = []
+         for block in blocks:
+             for bottleneck in block:
+                 modules.append(unit_module(bottleneck.in_channel,
+                                            bottleneck.depth,
+                                            bottleneck.stride))
+         self.body = Sequential(*modules)
+
+     def forward(self, x):
+         x = self.input_layer(x)
+         x = self.body(x)
+         x = self.output_layer_2(x)
+         x = self.linear(x)
+         x = x.view(-1, self.n_styles, 512)
+         return x
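A rough sketch of driving the encoder on its own. The `opts` fields below (`input_nc`, `n_styles`) are hypothetical stand-ins for the training options object that the framework normally builds from argparse, and the bundled `models.stylegan2` ops are assumed to import cleanly (they are needed for `EqualLinear`):

```python
from argparse import Namespace

import torch

from models.encoders.psp_encoders import GradualStyleEncoder

# Hypothetical options: RGB input, 18 style vectors (a 1024px StyleGAN2 decoder).
opts = Namespace(input_nc=3, n_styles=18)
encoder = GradualStyleEncoder(num_layers=50, mode='ir_se', opts=opts).eval()

with torch.no_grad():
    w_plus = encoder(torch.randn(1, 3, 256, 256))
print(w_plus.shape)   # torch.Size([1, 18, 512]): one latent code per generator layer
```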
models/mtcnn/__init__.py ADDED
File without changes
models/mtcnn/mtcnn.py ADDED
@@ -0,0 +1,156 @@
+ import numpy as np
+ import torch
+ from PIL import Image
+ from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet
+ from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
+ from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage
+ from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face
+
+ device = 'cuda:0'
+
+
+ class MTCNN():
+     def __init__(self):
+         print(device)
+         self.pnet = PNet().to(device)
+         self.rnet = RNet().to(device)
+         self.onet = ONet().to(device)
+         self.pnet.eval()
+         self.rnet.eval()
+         self.onet.eval()
+         self.refrence = get_reference_facial_points(default_square=True)
+
+     def align(self, img):
+         _, landmarks = self.detect_faces(img)
+         if len(landmarks) == 0:
+             return None, None
+         facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)]
+         warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
+         return Image.fromarray(warped_face), tfm
+
+     def align_multi(self, img, limit=None, min_face_size=30.0):
+         boxes, landmarks = self.detect_faces(img, min_face_size)
+         if limit:
+             boxes = boxes[:limit]
+             landmarks = landmarks[:limit]
+         faces = []
+         tfms = []
+         for landmark in landmarks:
+             facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)]
+             warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
+             faces.append(Image.fromarray(warped_face))
+             tfms.append(tfm)
+         return boxes, faces, tfms
+
+     def detect_faces(self, image, min_face_size=20.0,
+                      thresholds=[0.15, 0.25, 0.35],
+                      nms_thresholds=[0.7, 0.7, 0.7]):
+         """
+         Arguments:
+             image: an instance of PIL.Image.
+             min_face_size: a float number.
+             thresholds: a list of length 3.
+             nms_thresholds: a list of length 3.
+
+         Returns:
+             two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
+             bounding boxes and facial landmarks.
+         """
+
+         # BUILD AN IMAGE PYRAMID
+         width, height = image.size
+         min_length = min(height, width)
+
+         min_detection_size = 12
+         factor = 0.707  # sqrt(0.5)
+
+         # scales for scaling the image
+         scales = []
+
+         # scales the image so that
+         # minimum size that we can detect equals to
+         # minimum face size that we want to detect
+         m = min_detection_size / min_face_size
+         min_length *= m
+
+         factor_count = 0
+         while min_length > min_detection_size:
+             scales.append(m * factor ** factor_count)
+             min_length *= factor
+             factor_count += 1
+
+         # STAGE 1
+
+         # it will be returned
+         bounding_boxes = []
+
+         with torch.no_grad():
+             # run P-Net on different scales
+             for s in scales:
+                 boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0])
+                 bounding_boxes.append(boxes)
+
+             # collect boxes (and offsets, and scores) from different scales
+             bounding_boxes = [i for i in bounding_boxes if i is not None]
+             bounding_boxes = np.vstack(bounding_boxes)
+
+             keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
+             bounding_boxes = bounding_boxes[keep]
+
+             # use offsets predicted by pnet to transform bounding boxes
+             bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
+             # shape [n_boxes, 5]
+
+             bounding_boxes = convert_to_square(bounding_boxes)
+             bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
+
+             # STAGE 2
+
+             img_boxes = get_image_boxes(bounding_boxes, image, size=24)
+             img_boxes = torch.FloatTensor(img_boxes).to(device)
+
+             output = self.rnet(img_boxes)
+             offsets = output[0].cpu().data.numpy()  # shape [n_boxes, 4]
+             probs = output[1].cpu().data.numpy()  # shape [n_boxes, 2]
+
+             keep = np.where(probs[:, 1] > thresholds[1])[0]
+             bounding_boxes = bounding_boxes[keep]
+             bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
+             offsets = offsets[keep]
+
+             keep = nms(bounding_boxes, nms_thresholds[1])
+             bounding_boxes = bounding_boxes[keep]
+             bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
+             bounding_boxes = convert_to_square(bounding_boxes)
+             bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
+
+             # STAGE 3
+
+             img_boxes = get_image_boxes(bounding_boxes, image, size=48)
+             if len(img_boxes) == 0:
+                 return [], []
+             img_boxes = torch.FloatTensor(img_boxes).to(device)
+             output = self.onet(img_boxes)
+             landmarks = output[0].cpu().data.numpy()  # shape [n_boxes, 10]
+             offsets = output[1].cpu().data.numpy()  # shape [n_boxes, 4]
+             probs = output[2].cpu().data.numpy()  # shape [n_boxes, 2]
+
+             keep = np.where(probs[:, 1] > thresholds[2])[0]
+             bounding_boxes = bounding_boxes[keep]
+             bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
+             offsets = offsets[keep]
+             landmarks = landmarks[keep]
+
+             # compute landmark points
+             width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
+             height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
+             xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
+             landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
+             landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]
+
+             bounding_boxes = calibrate_box(bounding_boxes, offsets)
+             keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
+             bounding_boxes = bounding_boxes[keep]
+             landmarks = landmarks[keep]
+
+         return bounding_boxes, landmarks
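A short usage sketch for the detector/aligner above. Note that mtcnn.py hard-codes `device = 'cuda:0'`, so a GPU is required as written; the image paths are placeholders:

```python
from PIL import Image

from models.mtcnn.mtcnn import MTCNN

mtcnn = MTCNN()                                   # loads P-Net, R-Net and O-Net onto cuda:0
img = Image.open('photo.jpg').convert('RGB')      # placeholder input image

# Single best face: returns a 112x112 aligned crop and the similarity transform used.
aligned, tfm = mtcnn.align(img)
if aligned is not None:
    aligned.save('aligned_112.jpg')

# Multiple faces: bounding boxes, aligned crops and per-face transforms.
boxes, faces, tfms = mtcnn.align_multi(img, limit=4)
print(len(faces), 'face(s) aligned')
```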
models/mtcnn/mtcnn_pytorch/__init__.py ADDED
File without changes
models/mtcnn/mtcnn_pytorch/src/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .visualization_utils import show_bboxes
+ from .detector import detect_faces
models/mtcnn/mtcnn_pytorch/src/align_trans.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Mon Apr 24 15:43:29 2017
4
+ @author: zhaoy
5
+ """
6
+ import numpy as np
7
+ import cv2
8
+
9
+ # from scipy.linalg import lstsq
10
+ # from scipy.ndimage import geometric_transform # , map_coordinates
11
+
12
+ from models.mtcnn.mtcnn_pytorch.src.matlab_cp2tform import get_similarity_transform_for_cv2
13
+
14
+ # reference facial points, a list of coordinates (x,y)
15
+ REFERENCE_FACIAL_POINTS = [
16
+ [30.29459953, 51.69630051],
17
+ [65.53179932, 51.50139999],
18
+ [48.02519989, 71.73660278],
19
+ [33.54930115, 92.3655014],
20
+ [62.72990036, 92.20410156]
21
+ ]
22
+
23
+ DEFAULT_CROP_SIZE = (96, 112)
24
+
25
+
26
+ class FaceWarpException(Exception):
27
+ def __str__(self):
28
+ return 'In File {}:{}'.format(
29
+ __file__, super.__str__(self))
30
+
31
+
32
+ def get_reference_facial_points(output_size=None,
33
+ inner_padding_factor=0.0,
34
+ outer_padding=(0, 0),
35
+ default_square=False):
36
+ """
37
+ Function:
38
+ ----------
39
+ get reference 5 key points according to crop settings:
40
+ 0. Set default crop_size:
41
+ if default_square:
42
+ crop_size = (112, 112)
43
+ else:
44
+ crop_size = (96, 112)
45
+ 1. Pad the crop_size by inner_padding_factor in each side;
46
+ 2. Resize crop_size into (output_size - outer_padding*2),
47
+ pad into output_size with outer_padding;
48
+ 3. Output reference_5point;
49
+ Parameters:
50
+ ----------
51
+ @output_size: (w, h) or None
52
+ size of aligned face image
53
+ @inner_padding_factor: (w_factor, h_factor)
54
+ padding factor for inner (w, h)
55
+ @outer_padding: (w_pad, h_pad)
56
+ each row is a pair of coordinates (x, y)
57
+ @default_square: True or False
58
+ if True:
59
+ default crop_size = (112, 112)
60
+ else:
61
+ default crop_size = (96, 112);
62
+ !!! make sure, if output_size is not None:
63
+ (output_size - outer_padding)
64
+ = some_scale * (default crop_size * (1.0 + inner_padding_factor))
65
+ Returns:
66
+ ----------
67
+ @reference_5point: 5x2 np.array
68
+ each row is a pair of transformed coordinates (x, y)
69
+ """
70
+ # print('\n===> get_reference_facial_points():')
71
+
72
+ # print('---> Params:')
73
+ # print(' output_size: ', output_size)
74
+ # print(' inner_padding_factor: ', inner_padding_factor)
75
+ # print(' outer_padding:', outer_padding)
76
+ # print(' default_square: ', default_square)
77
+
78
+ tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
79
+ tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
80
+
81
+ # 0) make the inner region a square
82
+ if default_square:
83
+ size_diff = max(tmp_crop_size) - tmp_crop_size
84
+ tmp_5pts += size_diff / 2
85
+ tmp_crop_size += size_diff
86
+
87
+ # print('---> default:')
88
+ # print(' crop_size = ', tmp_crop_size)
89
+ # print(' reference_5pts = ', tmp_5pts)
90
+
91
+ if (output_size and
92
+ output_size[0] == tmp_crop_size[0] and
93
+ output_size[1] == tmp_crop_size[1]):
94
+ # print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
95
+ return tmp_5pts
96
+
97
+ if (inner_padding_factor == 0 and
98
+ outer_padding == (0, 0)):
99
+ if output_size is None:
100
+ # print('No paddings to do: return default reference points')
101
+ return tmp_5pts
102
+ else:
103
+ raise FaceWarpException(
104
+ 'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
105
+
106
+ # check output size
107
+ if not (0 <= inner_padding_factor <= 1.0):
108
+ raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
109
+
110
+ if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
111
+ and output_size is None):
112
+ output_size = tmp_crop_size * \
113
+ (1 + inner_padding_factor * 2).astype(np.int32)
114
+ output_size += np.array(outer_padding)
115
+ # print(' deduced from paddings, output_size = ', output_size)
116
+
117
+ if not (outer_padding[0] < output_size[0]
118
+ and outer_padding[1] < output_size[1]):
119
+ raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
120
+ 'and outer_padding[1] < output_size[1])')
121
+
122
+ # 1) pad the inner region according inner_padding_factor
123
+ # print('---> STEP1: pad the inner region according inner_padding_factor')
124
+ if inner_padding_factor > 0:
125
+ size_diff = tmp_crop_size * inner_padding_factor * 2
126
+ tmp_5pts += size_diff / 2
127
+ tmp_crop_size += np.round(size_diff).astype(np.int32)
128
+
129
+ # print(' crop_size = ', tmp_crop_size)
130
+ # print(' reference_5pts = ', tmp_5pts)
131
+
132
+ # 2) resize the padded inner region
133
+ # print('---> STEP2: resize the padded inner region')
134
+ size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
135
+ # print(' crop_size = ', tmp_crop_size)
136
+ # print(' size_bf_outer_pad = ', size_bf_outer_pad)
137
+
138
+ if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
139
+ raise FaceWarpException('Must have (output_size - outer_padding)'
140
+ '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
141
+
142
+ scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
143
+ # print(' resize scale_factor = ', scale_factor)
144
+ tmp_5pts = tmp_5pts * scale_factor
145
+ # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
146
+ # tmp_5pts = tmp_5pts + size_diff / 2
147
+ tmp_crop_size = size_bf_outer_pad
148
+ # print(' crop_size = ', tmp_crop_size)
149
+ # print(' reference_5pts = ', tmp_5pts)
150
+
151
+ # 3) add outer_padding to make output_size
152
+ reference_5point = tmp_5pts + np.array(outer_padding)
153
+ tmp_crop_size = output_size
154
+ # print('---> STEP3: add outer_padding to make output_size')
155
+ # print(' crop_size = ', tmp_crop_size)
156
+ # print(' reference_5pts = ', tmp_5pts)
157
+
158
+ # print('===> end get_reference_facial_points\n')
159
+
160
+ return reference_5point
161
+
162
+
163
+ def get_affine_transform_matrix(src_pts, dst_pts):
164
+ """
165
+ Function:
166
+ ----------
167
+ get affine transform matrix 'tfm' from src_pts to dst_pts
168
+ Parameters:
169
+ ----------
170
+ @src_pts: Kx2 np.array
171
+ source points matrix, each row is a pair of coordinates (x, y)
172
+ @dst_pts: Kx2 np.array
173
+ destination points matrix, each row is a pair of coordinates (x, y)
174
+ Returns:
175
+ ----------
176
+ @tfm: 2x3 np.array
177
+ transform matrix from src_pts to dst_pts
178
+ """
179
+
180
+ tfm = np.float32([[1, 0, 0], [0, 1, 0]])
181
+ n_pts = src_pts.shape[0]
182
+ ones = np.ones((n_pts, 1), src_pts.dtype)
183
+ src_pts_ = np.hstack([src_pts, ones])
184
+ dst_pts_ = np.hstack([dst_pts, ones])
185
+
186
+ # #print(('src_pts_:\n' + str(src_pts_))
187
+ # #print(('dst_pts_:\n' + str(dst_pts_))
188
+
189
+ A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
190
+
191
+ # #print(('np.linalg.lstsq return A: \n' + str(A))
192
+ # #print(('np.linalg.lstsq return res: \n' + str(res))
193
+ # #print(('np.linalg.lstsq return rank: \n' + str(rank))
194
+ # #print(('np.linalg.lstsq return s: \n' + str(s))
195
+
196
+ if rank == 3:
197
+ tfm = np.float32([
198
+ [A[0, 0], A[1, 0], A[2, 0]],
199
+ [A[0, 1], A[1, 1], A[2, 1]]
200
+ ])
201
+ elif rank == 2:
202
+ tfm = np.float32([
203
+ [A[0, 0], A[1, 0], 0],
204
+ [A[0, 1], A[1, 1], 0]
205
+ ])
206
+
207
+ return tfm
208
+
209
+
210
+ def warp_and_crop_face(src_img,
211
+ facial_pts,
212
+ reference_pts=None,
213
+ crop_size=(96, 112),
214
+ align_type='smilarity'):
215
+ """
216
+ Function:
217
+ ----------
218
+ apply affine transform 'trans' to uv
219
+ Parameters:
220
+ ----------
221
+ @src_img: 3x3 np.array
222
+ input image
223
+ @facial_pts: could be
224
+ 1)a list of K coordinates (x,y)
225
+ or
226
+ 2) Kx2 or 2xK np.array
227
+ each row or col is a pair of coordinates (x, y)
228
+ @reference_pts: could be
229
+ 1) a list of K coordinates (x,y)
230
+ or
231
+ 2) Kx2 or 2xK np.array
232
+ each row or col is a pair of coordinates (x, y)
233
+ or
234
+ 3) None
235
+ if None, use default reference facial points
236
+ @crop_size: (w, h)
237
+ output face image size
238
+ @align_type: transform type, could be one of
239
+ 1) 'similarity': use similarity transform
240
+ 2) 'cv2_affine': use the first 3 points to do affine transform,
241
+ by calling cv2.getAffineTransform()
242
+ 3) 'affine': use all points to do affine transform
243
+ Returns:
244
+ ----------
245
+ @face_img: output face image with size (w, h) = @crop_size
246
+ """
247
+
248
+ if reference_pts is None:
249
+ if crop_size[0] == 96 and crop_size[1] == 112:
250
+ reference_pts = REFERENCE_FACIAL_POINTS
251
+ else:
252
+ default_square = False
253
+ inner_padding_factor = 0
254
+ outer_padding = (0, 0)
255
+ output_size = crop_size
256
+
257
+ reference_pts = get_reference_facial_points(output_size,
258
+ inner_padding_factor,
259
+ outer_padding,
260
+ default_square)
261
+
262
+ ref_pts = np.float32(reference_pts)
263
+ ref_pts_shp = ref_pts.shape
264
+ if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
265
+ raise FaceWarpException(
266
+ 'reference_pts.shape must be (K,2) or (2,K) and K>2')
267
+
268
+ if ref_pts_shp[0] == 2:
269
+ ref_pts = ref_pts.T
270
+
271
+ src_pts = np.float32(facial_pts)
272
+ src_pts_shp = src_pts.shape
273
+ if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
274
+ raise FaceWarpException(
275
+ 'facial_pts.shape must be (K,2) or (2,K) and K>2')
276
+
277
+ if src_pts_shp[0] == 2:
278
+ src_pts = src_pts.T
279
+
280
+ # #print('--->src_pts:\n', src_pts
281
+ # #print('--->ref_pts\n', ref_pts
282
+
283
+ if src_pts.shape != ref_pts.shape:
284
+ raise FaceWarpException(
285
+ 'facial_pts and reference_pts must have the same shape')
286
+
287
+ if align_type == 'cv2_affine':
288
+ tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
289
+ # #print(('cv2.getAffineTransform() returns tfm=\n' + str(tfm))
290
+ elif align_type == 'affine':
291
+ tfm = get_affine_transform_matrix(src_pts, ref_pts)
292
+ # #print(('get_affine_transform_matrix() returns tfm=\n' + str(tfm))
293
+ else:
294
+ tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
295
+ # #print(('get_similarity_transform_for_cv2() returns tfm=\n' + str(tfm))
296
+
297
+ # #print('--->Transform matrix: '
298
+ # #print(('type(tfm):' + str(type(tfm)))
299
+ # #print(('tfm.dtype:' + str(tfm.dtype))
300
+ # #print( tfm
301
+
302
+ face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
303
+
304
+ return face_img, tfm
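Usage sketch for the alignment helpers above. The image path and the five landmark coordinates are made-up placeholders (any 5-point detector output works); passing default_square=True pads the 96x112 reference layout to a square so it can be scaled to a 112x112 crop.

import cv2
import numpy as np
from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face

src_img = cv2.imread('face.jpg')  # hypothetical input image
facial_pts = np.float32([[331, 244], [404, 242],   # left eye, right eye
                         [368, 286],               # nose tip
                         [339, 324], [400, 322]])  # mouth corners

# args: output_size, inner_padding_factor, outer_padding, default_square
ref_pts = get_reference_facial_points((112, 112), 0, (0, 0), True)

aligned, tfm = warp_and_crop_face(src_img, facial_pts,
                                  reference_pts=ref_pts,
                                  crop_size=(112, 112),
                                  align_type='similarity')
cv2.imwrite('face_aligned.jpg', aligned)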
models/mtcnn/mtcnn_pytorch/src/box_utils.py ADDED
@@ -0,0 +1,238 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+
4
+
5
+ def nms(boxes, overlap_threshold=0.5, mode='union'):
6
+ """Non-maximum suppression.
7
+
8
+ Arguments:
9
+ boxes: a float numpy array of shape [n, 5],
10
+ where each row is (xmin, ymin, xmax, ymax, score).
11
+ overlap_threshold: a float number.
12
+ mode: 'union' or 'min'.
13
+
14
+ Returns:
15
+ list with indices of the selected boxes
16
+ """
17
+
18
+ # if there are no boxes, return the empty list
19
+ if len(boxes) == 0:
20
+ return []
21
+
22
+ # list of picked indices
23
+ pick = []
24
+
25
+ # grab the coordinates of the bounding boxes
26
+ x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)]
27
+
28
+ area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
29
+ ids = np.argsort(score) # in increasing order
30
+
31
+ while len(ids) > 0:
32
+
33
+ # grab index of the largest value
34
+ last = len(ids) - 1
35
+ i = ids[last]
36
+ pick.append(i)
37
+
38
+ # compute intersections
39
+ # of the box with the largest score
40
+ # with the rest of boxes
41
+
42
+ # left top corner of intersection boxes
43
+ ix1 = np.maximum(x1[i], x1[ids[:last]])
44
+ iy1 = np.maximum(y1[i], y1[ids[:last]])
45
+
46
+ # right bottom corner of intersection boxes
47
+ ix2 = np.minimum(x2[i], x2[ids[:last]])
48
+ iy2 = np.minimum(y2[i], y2[ids[:last]])
49
+
50
+ # width and height of intersection boxes
51
+ w = np.maximum(0.0, ix2 - ix1 + 1.0)
52
+ h = np.maximum(0.0, iy2 - iy1 + 1.0)
53
+
54
+ # intersections' areas
55
+ inter = w * h
56
+ if mode == 'min':
57
+ overlap = inter / np.minimum(area[i], area[ids[:last]])
58
+ elif mode == 'union':
59
+ # intersection over union (IoU)
60
+ overlap = inter / (area[i] + area[ids[:last]] - inter)
61
+
62
+ # delete all boxes where overlap is too big
63
+ ids = np.delete(
64
+ ids,
65
+ np.concatenate([[last], np.where(overlap > overlap_threshold)[0]])
66
+ )
67
+
68
+ return pick
69
+
70
+
71
+ def convert_to_square(bboxes):
72
+ """Convert bounding boxes to a square form.
73
+
74
+ Arguments:
75
+ bboxes: a float numpy array of shape [n, 5].
76
+
77
+ Returns:
78
+ a float numpy array of shape [n, 5],
79
+ squared bounding boxes.
80
+ """
81
+
82
+ square_bboxes = np.zeros_like(bboxes)
83
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
84
+ h = y2 - y1 + 1.0
85
+ w = x2 - x1 + 1.0
86
+ max_side = np.maximum(h, w)
87
+ square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
88
+ square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
89
+ square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
90
+ square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
91
+ return square_bboxes
92
+
93
+
94
+ def calibrate_box(bboxes, offsets):
95
+ """Transform bounding boxes to be more like true bounding boxes.
96
+ 'offsets' is one of the outputs of the nets.
97
+
98
+ Arguments:
99
+ bboxes: a float numpy array of shape [n, 5].
100
+ offsets: a float numpy array of shape [n, 4].
101
+
102
+ Returns:
103
+ a float numpy array of shape [n, 5].
104
+ """
105
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
106
+ w = x2 - x1 + 1.0
107
+ h = y2 - y1 + 1.0
108
+ w = np.expand_dims(w, 1)
109
+ h = np.expand_dims(h, 1)
110
+
111
+ # this is what is happening here:
112
+ # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)]
113
+ # x1_true = x1 + tx1*w
114
+ # y1_true = y1 + ty1*h
115
+ # x2_true = x2 + tx2*w
116
+ # y2_true = y2 + ty2*h
117
+ # below is just more compact form of this
118
+
119
+ # are offsets always such that
120
+ # x1 < x2 and y1 < y2 ?
121
+
122
+ translation = np.hstack([w, h, w, h]) * offsets
123
+ bboxes[:, 0:4] = bboxes[:, 0:4] + translation
124
+ return bboxes
125
+
126
+
127
+ def get_image_boxes(bounding_boxes, img, size=24):
128
+ """Cut out boxes from the image.
129
+
130
+ Arguments:
131
+ bounding_boxes: a float numpy array of shape [n, 5].
132
+ img: an instance of PIL.Image.
133
+ size: an integer, size of cutouts.
134
+
135
+ Returns:
136
+ a float numpy array of shape [n, 3, size, size].
137
+ """
138
+
139
+ num_boxes = len(bounding_boxes)
140
+ width, height = img.size
141
+
142
+ [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height)
143
+ img_boxes = np.zeros((num_boxes, 3, size, size), 'float32')
144
+
145
+ for i in range(num_boxes):
146
+ img_box = np.zeros((h[i], w[i], 3), 'uint8')
147
+
148
+ img_array = np.asarray(img, 'uint8')
149
+ img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \
150
+ img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]
151
+
152
+ # resize
153
+ img_box = Image.fromarray(img_box)
154
+ img_box = img_box.resize((size, size), Image.BILINEAR)
155
+ img_box = np.asarray(img_box, 'float32')
156
+
157
+ img_boxes[i, :, :, :] = _preprocess(img_box)
158
+
159
+ return img_boxes
160
+
161
+
162
+ def correct_bboxes(bboxes, width, height):
163
+ """Crop boxes that are too big and get coordinates
164
+ with respect to cutouts.
165
+
166
+ Arguments:
167
+ bboxes: a float numpy array of shape [n, 5],
168
+ where each row is (xmin, ymin, xmax, ymax, score).
169
+ width: a float number.
170
+ height: a float number.
171
+
172
+ Returns:
173
+ dy, dx, edy, edx: int numpy arrays of shape [n],
174
+ coordinates of the boxes with respect to the cutouts.
175
+ y, x, ey, ex: int numpy arrays of shape [n],
176
+ corrected ymin, xmin, ymax, xmax.
177
+ h, w: int numpy arrays of shape [n],
178
+ just heights and widths of boxes.
179
+
180
+ in the following order:
181
+ [dy, edy, dx, edx, y, ey, x, ex, w, h].
182
+ """
183
+
184
+ x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
185
+ w, h = x2 - x1 + 1.0, y2 - y1 + 1.0
186
+ num_boxes = bboxes.shape[0]
187
+
188
+ # 'e' stands for end
189
+ # (x, y) -> (ex, ey)
190
+ x, y, ex, ey = x1, y1, x2, y2
191
+
192
+ # we need to cut out a box from the image.
193
+ # (x, y, ex, ey) are corrected coordinates of the box
194
+ # in the image.
195
+ # (dx, dy, edx, edy) are coordinates of the box in the cutout
196
+ # from the image.
197
+ dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,))
198
+ edx, edy = w.copy() - 1.0, h.copy() - 1.0
199
+
200
+ # if box's bottom right corner is too far right
201
+ ind = np.where(ex > width - 1.0)[0]
202
+ edx[ind] = w[ind] + width - 2.0 - ex[ind]
203
+ ex[ind] = width - 1.0
204
+
205
+ # if box's bottom right corner is too low
206
+ ind = np.where(ey > height - 1.0)[0]
207
+ edy[ind] = h[ind] + height - 2.0 - ey[ind]
208
+ ey[ind] = height - 1.0
209
+
210
+ # if box's top left corner is too far left
211
+ ind = np.where(x < 0.0)[0]
212
+ dx[ind] = 0.0 - x[ind]
213
+ x[ind] = 0.0
214
+
215
+ # if box's top left corner is too high
216
+ ind = np.where(y < 0.0)[0]
217
+ dy[ind] = 0.0 - y[ind]
218
+ y[ind] = 0.0
219
+
220
+ return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h]
221
+ return_list = [i.astype('int32') for i in return_list]
222
+
223
+ return return_list
224
+
225
+
226
+ def _preprocess(img):
227
+ """Preprocessing step before feeding the network.
228
+
229
+ Arguments:
230
+ img: a float numpy array of shape [h, w, c].
231
+
232
+ Returns:
233
+ a float numpy array of shape [1, c, h, w].
234
+ """
235
+ img = img.transpose((2, 0, 1))
236
+ img = np.expand_dims(img, 0)
237
+ img = (img - 127.5) * 0.0078125
238
+ return img
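A quick sanity check of the NMS helper above on toy boxes (values are made up): the second box overlaps the first with IoU ≈ 0.83 and is suppressed, the third survives, and convert_to_square grows it into a 61x61 square centred on the original box.

import numpy as np
from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, convert_to_square

boxes = np.float32([
    [10, 10, 50, 50, 0.9],      # (xmin, ymin, xmax, ymax, score)
    [12, 12, 52, 52, 0.8],      # IoU with the first box > 0.5 -> suppressed
    [100, 100, 140, 160, 0.7],
])
keep = nms(boxes, overlap_threshold=0.5, mode='union')
print(keep)                           # [0, 2]
print(convert_to_square(boxes[keep]))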
models/mtcnn/mtcnn_pytorch/src/detector.py ADDED
@@ -0,0 +1,126 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.autograd import Variable
4
+ from .get_nets import PNet, RNet, ONet
5
+ from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
6
+ from .first_stage import run_first_stage, device
7
+
8
+
9
+ def detect_faces(image, min_face_size=20.0,
10
+ thresholds=[0.6, 0.7, 0.8],
11
+ nms_thresholds=[0.7, 0.7, 0.7]):
12
+ """
13
+ Arguments:
14
+ image: an instance of PIL.Image.
15
+ min_face_size: a float number.
16
+ thresholds: a list of length 3.
17
+ nms_thresholds: a list of length 3.
18
+
19
+ Returns:
20
+ two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
21
+ bounding boxes and facial landmarks.
22
+ """
23
+
24
+ # LOAD MODELS
25
+ pnet = PNet().to(device)  # keep P-Net on the same device run_first_stage moves its inputs to
26
+ rnet = RNet()
27
+ onet = ONet()
28
+ onet.eval()
29
+
30
+ # BUILD AN IMAGE PYRAMID
31
+ width, height = image.size
32
+ min_length = min(height, width)
33
+
34
+ min_detection_size = 12
35
+ factor = 0.707 # sqrt(0.5)
36
+
37
+ # scales for scaling the image
38
+ scales = []
39
+
40
+ # scales the image so that
41
+ # minimum size that we can detect equals to
42
+ # minimum face size that we want to detect
43
+ m = min_detection_size / min_face_size
44
+ min_length *= m
45
+
46
+ factor_count = 0
47
+ while min_length > min_detection_size:
48
+ scales.append(m * factor ** factor_count)
49
+ min_length *= factor
50
+ factor_count += 1
51
+
52
+ # STAGE 1
53
+
54
+ # it will be returned
55
+ bounding_boxes = []
56
+
57
+ with torch.no_grad():
58
+ # run P-Net on different scales
59
+ for s in scales:
60
+ boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0])
61
+ bounding_boxes.append(boxes)
62
+
63
+ # collect boxes (and offsets, and scores) from different scales
64
+ bounding_boxes = [i for i in bounding_boxes if i is not None]
65
+ bounding_boxes = np.vstack(bounding_boxes)
66
+
67
+ keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
68
+ bounding_boxes = bounding_boxes[keep]
69
+
70
+ # use offsets predicted by pnet to transform bounding boxes
71
+ bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
72
+ # shape [n_boxes, 5]
73
+
74
+ bounding_boxes = convert_to_square(bounding_boxes)
75
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
76
+
77
+ # STAGE 2
78
+
79
+ img_boxes = get_image_boxes(bounding_boxes, image, size=24)
80
+ img_boxes = torch.FloatTensor(img_boxes)
81
+
82
+ output = rnet(img_boxes)
83
+ offsets = output[0].data.numpy() # shape [n_boxes, 4]
84
+ probs = output[1].data.numpy() # shape [n_boxes, 2]
85
+
86
+ keep = np.where(probs[:, 1] > thresholds[1])[0]
87
+ bounding_boxes = bounding_boxes[keep]
88
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
89
+ offsets = offsets[keep]
90
+
91
+ keep = nms(bounding_boxes, nms_thresholds[1])
92
+ bounding_boxes = bounding_boxes[keep]
93
+ bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
94
+ bounding_boxes = convert_to_square(bounding_boxes)
95
+ bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
96
+
97
+ # STAGE 3
98
+
99
+ img_boxes = get_image_boxes(bounding_boxes, image, size=48)
100
+ if len(img_boxes) == 0:
101
+ return [], []
102
+ img_boxes = torch.FloatTensor(img_boxes)
103
+ output = onet(img_boxes)
104
+ landmarks = output[0].data.numpy() # shape [n_boxes, 10]
105
+ offsets = output[1].data.numpy() # shape [n_boxes, 4]
106
+ probs = output[2].data.numpy() # shape [n_boxes, 2]
107
+
108
+ keep = np.where(probs[:, 1] > thresholds[2])[0]
109
+ bounding_boxes = bounding_boxes[keep]
110
+ bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
111
+ offsets = offsets[keep]
112
+ landmarks = landmarks[keep]
113
+
114
+ # compute landmark points
115
+ width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
116
+ height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
117
+ xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
118
+ landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
119
+ landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]
120
+
121
+ bounding_boxes = calibrate_box(bounding_boxes, offsets)
122
+ keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
123
+ bounding_boxes = bounding_boxes[keep]
124
+ landmarks = landmarks[keep]
125
+
126
+ return bounding_boxes, landmarks
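End-to-end sketch of the three-stage pipeline above. It assumes a CUDA device (run_first_stage pins its tensors to 'cuda:0') and that the MTCNN weight files referenced in configs/paths_config.py are in place; the image path is a placeholder.

from PIL import Image
from models.mtcnn.mtcnn_pytorch.src.detector import detect_faces

img = Image.open('photo.jpg').convert('RGB')  # hypothetical input image
boxes, landmarks = detect_faces(img, min_face_size=20.0)
# boxes: [n, 5] = (xmin, ymin, xmax, ymax, score); landmarks: [n, 10] = 5 x-coords then 5 y-coords
for b in boxes:
    print('face (%.0f, %.0f)-(%.0f, %.0f), score %.2f' % tuple(b[:5]))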
models/mtcnn/mtcnn_pytorch/src/first_stage.py ADDED
@@ -0,0 +1,101 @@
1
+ import torch
2
+ from torch.autograd import Variable
3
+ import math
4
+ from PIL import Image
5
+ import numpy as np
6
+ from .box_utils import nms, _preprocess
7
+
8
+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
9
+ device = 'cuda:0'
10
+
11
+
12
+ def run_first_stage(image, net, scale, threshold):
13
+ """Run P-Net, generate bounding boxes, and do NMS.
14
+
15
+ Arguments:
16
+ image: an instance of PIL.Image.
17
+ net: an instance of pytorch's nn.Module, P-Net.
18
+ scale: a float number,
19
+ scale width and height of the image by this number.
20
+ threshold: a float number,
21
+ threshold on the probability of a face when generating
22
+ bounding boxes from predictions of the net.
23
+
24
+ Returns:
25
+ a float numpy array of shape [n_boxes, 9],
26
+ bounding boxes with scores and offsets (4 + 1 + 4).
27
+ """
28
+
29
+ # scale the image and convert it to a float array
30
+ width, height = image.size
31
+ sw, sh = math.ceil(width * scale), math.ceil(height * scale)
32
+ img = image.resize((sw, sh), Image.BILINEAR)
33
+ img = np.asarray(img, 'float32')
34
+
35
+ img = torch.FloatTensor(_preprocess(img)).to(device)
36
+ with torch.no_grad():
37
+ output = net(img)
38
+ probs = output[1].cpu().data.numpy()[0, 1, :, :]
39
+ offsets = output[0].cpu().data.numpy()
40
+ # probs: probability of a face at each sliding window
41
+ # offsets: transformations to true bounding boxes
42
+
43
+ boxes = _generate_bboxes(probs, offsets, scale, threshold)
44
+ if len(boxes) == 0:
45
+ return None
46
+
47
+ keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
48
+ return boxes[keep]
49
+
50
+
51
+ def _generate_bboxes(probs, offsets, scale, threshold):
52
+ """Generate bounding boxes at places
53
+ where there is probably a face.
54
+
55
+ Arguments:
56
+ probs: a float numpy array of shape [n, m].
57
+ offsets: a float numpy array of shape [1, 4, n, m].
58
+ scale: a float number,
59
+ width and height of the image were scaled by this number.
60
+ threshold: a float number.
61
+
62
+ Returns:
63
+ a float numpy array of shape [n_boxes, 9]
64
+ """
65
+
66
+ # applying P-Net is equivalent, in some sense, to
67
+ # moving 12x12 window with stride 2
68
+ stride = 2
69
+ cell_size = 12
70
+
71
+ # indices of boxes where there is probably a face
72
+ inds = np.where(probs > threshold)
73
+
74
+ if inds[0].size == 0:
75
+ return np.array([])
76
+
77
+ # transformations of bounding boxes
78
+ tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)]
79
+ # they are defined as:
80
+ # w = x2 - x1 + 1
81
+ # h = y2 - y1 + 1
82
+ # x1_true = x1 + tx1*w
83
+ # x2_true = x2 + tx2*w
84
+ # y1_true = y1 + ty1*h
85
+ # y2_true = y2 + ty2*h
86
+
87
+ offsets = np.array([tx1, ty1, tx2, ty2])
88
+ score = probs[inds[0], inds[1]]
89
+
90
+ # P-Net is applied to scaled images
91
+ # so we need to rescale bounding boxes back
92
+ bounding_boxes = np.vstack([
93
+ np.round((stride * inds[1] + 1.0) / scale),
94
+ np.round((stride * inds[0] + 1.0) / scale),
95
+ np.round((stride * inds[1] + 1.0 + cell_size) / scale),
96
+ np.round((stride * inds[0] + 1.0 + cell_size) / scale),
97
+ score, offsets
98
+ ])
99
+ # why one is added?
100
+
101
+ return bounding_boxes.T
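Sketch of running the first stage in isolation, e.g. to inspect raw P-Net proposals at a single scale. Same assumptions as the detector example: a CUDA device and the P-Net weights from configs/paths_config.py; the path and the scale value are placeholders.

from PIL import Image
from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet
from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage

img = Image.open('photo.jpg').convert('RGB')   # hypothetical input image
pnet = PNet().to('cuda:0')
boxes = run_first_stage(img, pnet, scale=0.3, threshold=0.6)
# None if nothing passes the threshold, otherwise [n_boxes, 9]: 4 coords + score + 4 offsets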
models/mtcnn/mtcnn_pytorch/src/get_nets.py ADDED
@@ -0,0 +1,171 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from collections import OrderedDict
5
+ import numpy as np
6
+
7
+ from configs.paths_config import model_paths
8
+ PNET_PATH = model_paths["mtcnn_pnet"]
9
+ ONET_PATH = model_paths["mtcnn_onet"]
10
+ RNET_PATH = model_paths["mtcnn_rnet"]
11
+
12
+
13
+ class Flatten(nn.Module):
14
+
15
+ def __init__(self):
16
+ super(Flatten, self).__init__()
17
+
18
+ def forward(self, x):
19
+ """
20
+ Arguments:
21
+ x: a float tensor with shape [batch_size, c, h, w].
22
+ Returns:
23
+ a float tensor with shape [batch_size, c*h*w].
24
+ """
25
+
26
+ # without this pretrained model isn't working
27
+ x = x.transpose(3, 2).contiguous()
28
+
29
+ return x.view(x.size(0), -1)
30
+
31
+
32
+ class PNet(nn.Module):
33
+
34
+ def __init__(self):
35
+ super().__init__()
36
+
37
+ # suppose we have input with size HxW, then
38
+ # after first layer: H - 2,
39
+ # after pool: ceil((H - 2)/2),
40
+ # after second conv: ceil((H - 2)/2) - 2,
41
+ # after last conv: ceil((H - 2)/2) - 4,
42
+ # and the same for W
43
+
44
+ self.features = nn.Sequential(OrderedDict([
45
+ ('conv1', nn.Conv2d(3, 10, 3, 1)),
46
+ ('prelu1', nn.PReLU(10)),
47
+ ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),
48
+
49
+ ('conv2', nn.Conv2d(10, 16, 3, 1)),
50
+ ('prelu2', nn.PReLU(16)),
51
+
52
+ ('conv3', nn.Conv2d(16, 32, 3, 1)),
53
+ ('prelu3', nn.PReLU(32))
54
+ ]))
55
+
56
+ self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
57
+ self.conv4_2 = nn.Conv2d(32, 4, 1, 1)
58
+
59
+ weights = np.load(PNET_PATH, allow_pickle=True)[()]
60
+ for n, p in self.named_parameters():
61
+ p.data = torch.FloatTensor(weights[n])
62
+
63
+ def forward(self, x):
64
+ """
65
+ Arguments:
66
+ x: a float tensor with shape [batch_size, 3, h, w].
67
+ Returns:
68
+ b: a float tensor with shape [batch_size, 4, h', w'].
69
+ a: a float tensor with shape [batch_size, 2, h', w'].
70
+ """
71
+ x = self.features(x)
72
+ a = self.conv4_1(x)
73
+ b = self.conv4_2(x)
74
+ a = F.softmax(a, dim=1)  # softmax over the 2 classes; a has shape [batch_size, 2, h', w']
75
+ return b, a
76
+
77
+
78
+ class RNet(nn.Module):
79
+
80
+ def __init__(self):
81
+ super().__init__()
82
+
83
+ self.features = nn.Sequential(OrderedDict([
84
+ ('conv1', nn.Conv2d(3, 28, 3, 1)),
85
+ ('prelu1', nn.PReLU(28)),
86
+ ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
87
+
88
+ ('conv2', nn.Conv2d(28, 48, 3, 1)),
89
+ ('prelu2', nn.PReLU(48)),
90
+ ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
91
+
92
+ ('conv3', nn.Conv2d(48, 64, 2, 1)),
93
+ ('prelu3', nn.PReLU(64)),
94
+
95
+ ('flatten', Flatten()),
96
+ ('conv4', nn.Linear(576, 128)),
97
+ ('prelu4', nn.PReLU(128))
98
+ ]))
99
+
100
+ self.conv5_1 = nn.Linear(128, 2)
101
+ self.conv5_2 = nn.Linear(128, 4)
102
+
103
+ weights = np.load(RNET_PATH, allow_pickle=True)[()]
104
+ for n, p in self.named_parameters():
105
+ p.data = torch.FloatTensor(weights[n])
106
+
107
+ def forward(self, x):
108
+ """
109
+ Arguments:
110
+ x: a float tensor with shape [batch_size, 3, h, w].
111
+ Returns:
112
+ b: a float tensor with shape [batch_size, 4].
113
+ a: a float tensor with shape [batch_size, 2].
114
+ """
115
+ x = self.features(x)
116
+ a = self.conv5_1(x)
117
+ b = self.conv5_2(x)
118
+ a = F.softmax(a, dim=-1)
119
+ return b, a
120
+
121
+
122
+ class ONet(nn.Module):
123
+
124
+ def __init__(self):
125
+ super().__init__()
126
+
127
+ self.features = nn.Sequential(OrderedDict([
128
+ ('conv1', nn.Conv2d(3, 32, 3, 1)),
129
+ ('prelu1', nn.PReLU(32)),
130
+ ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
131
+
132
+ ('conv2', nn.Conv2d(32, 64, 3, 1)),
133
+ ('prelu2', nn.PReLU(64)),
134
+ ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
135
+
136
+ ('conv3', nn.Conv2d(64, 64, 3, 1)),
137
+ ('prelu3', nn.PReLU(64)),
138
+ ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),
139
+
140
+ ('conv4', nn.Conv2d(64, 128, 2, 1)),
141
+ ('prelu4', nn.PReLU(128)),
142
+
143
+ ('flatten', Flatten()),
144
+ ('conv5', nn.Linear(1152, 256)),
145
+ ('drop5', nn.Dropout(0.25)),
146
+ ('prelu5', nn.PReLU(256)),
147
+ ]))
148
+
149
+ self.conv6_1 = nn.Linear(256, 2)
150
+ self.conv6_2 = nn.Linear(256, 4)
151
+ self.conv6_3 = nn.Linear(256, 10)
152
+
153
+ weights = np.load(ONET_PATH, allow_pickle=True)[()]
154
+ for n, p in self.named_parameters():
155
+ p.data = torch.FloatTensor(weights[n])
156
+
157
+ def forward(self, x):
158
+ """
159
+ Arguments:
160
+ x: a float tensor with shape [batch_size, 3, h, w].
161
+ Returns:
162
+ c: a float tensor with shape [batch_size, 10].
163
+ b: a float tensor with shape [batch_size, 4].
164
+ a: a float tensor with shape [batch_size, 2].
165
+ """
166
+ x = self.features(x)
167
+ a = self.conv6_1(x)
168
+ b = self.conv6_2(x)
169
+ c = self.conv6_3(x)
170
+ a = F.softmax(a, dim=-1)
171
+ return c, b, a
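Shape check for the three networks above on dummy inputs at their nominal resolutions (12, 24 and 48 pixels). It needs the .npy weight files referenced in configs/paths_config.py, since the constructors load them eagerly.

import torch
from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet

with torch.no_grad():
    b, a = PNet()(torch.zeros(1, 3, 12, 12))             # b: [1, 4, 1, 1], a: [1, 2, 1, 1]
    b, a = RNet()(torch.zeros(1, 3, 24, 24))             # b: [1, 4],       a: [1, 2]
    c, b, a = ONet().eval()(torch.zeros(1, 3, 48, 48))   # c: [1, 10], b: [1, 4], a: [1, 2]
    print(c.shape, b.shape, a.shape)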
models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py ADDED
@@ -0,0 +1,350 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Tue Jul 11 06:54:28 2017
4
+
5
+ @author: zhaoyafei
6
+ """
7
+
8
+ import numpy as np
9
+ from numpy.linalg import inv, norm, lstsq
10
+ from numpy.linalg import matrix_rank as rank
11
+
12
+
13
+ class MatlabCp2tormException(Exception):
14
+ def __str__(self):
15
+ return 'In File {}:{}'.format(
16
+ __file__, super.__str__(self))
17
+
18
+
19
+ def tformfwd(trans, uv):
20
+ """
21
+ Function:
22
+ ----------
23
+ apply affine transform 'trans' to uv
24
+
25
+ Parameters:
26
+ ----------
27
+ @trans: 3x3 np.array
28
+ transform matrix
29
+ @uv: Kx2 np.array
30
+ each row is a pair of coordinates (x, y)
31
+
32
+ Returns:
33
+ ----------
34
+ @xy: Kx2 np.array
35
+ each row is a pair of transformed coordinates (x, y)
36
+ """
37
+ uv = np.hstack((
38
+ uv, np.ones((uv.shape[0], 1))
39
+ ))
40
+ xy = np.dot(uv, trans)
41
+ xy = xy[:, 0:-1]
42
+ return xy
43
+
44
+
45
+ def tforminv(trans, uv):
46
+ """
47
+ Function:
48
+ ----------
49
+ apply the inverse of affine transform 'trans' to uv
50
+
51
+ Parameters:
52
+ ----------
53
+ @trans: 3x3 np.array
54
+ transform matrix
55
+ @uv: Kx2 np.array
56
+ each row is a pair of coordinates (x, y)
57
+
58
+ Returns:
59
+ ----------
60
+ @xy: Kx2 np.array
61
+ each row is a pair of inverse-transformed coordinates (x, y)
62
+ """
63
+ Tinv = inv(trans)
64
+ xy = tformfwd(Tinv, uv)
65
+ return xy
66
+
67
+
68
+ def findNonreflectiveSimilarity(uv, xy, options=None):
69
+ options = {'K': 2}
70
+
71
+ K = options['K']
72
+ M = xy.shape[0]
73
+ x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
74
+ y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
75
+ # print('--->x, y:\n', x, y
76
+
77
+ tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
78
+ tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
79
+ X = np.vstack((tmp1, tmp2))
80
+ # print('--->X.shape: ', X.shape
81
+ # print('X:\n', X
82
+
83
+ u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
84
+ v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
85
+ U = np.vstack((u, v))
86
+ # print('--->U.shape: ', U.shape
87
+ # print('U:\n', U
88
+
89
+ # We know that X * r = U
90
+ if rank(X) >= 2 * K:
91
+ r, _, _, _ = lstsq(X, U, rcond=None) # Make sure this is what I want
92
+ r = np.squeeze(r)
93
+ else:
94
+ raise Exception('cp2tform:twoUniquePointsReq')
95
+
96
+ # print('--->r:\n', r
97
+
98
+ sc = r[0]
99
+ ss = r[1]
100
+ tx = r[2]
101
+ ty = r[3]
102
+
103
+ Tinv = np.array([
104
+ [sc, -ss, 0],
105
+ [ss, sc, 0],
106
+ [tx, ty, 1]
107
+ ])
108
+
109
+ # print('--->Tinv:\n', Tinv
110
+
111
+ T = inv(Tinv)
112
+ # print('--->T:\n', T
113
+
114
+ T[:, 2] = np.array([0, 0, 1])
115
+
116
+ return T, Tinv
117
+
118
+
119
+ def findSimilarity(uv, xy, options=None):
120
+ options = {'K': 2}
121
+
122
+ # uv = np.array(uv)
123
+ # xy = np.array(xy)
124
+
125
+ # Solve for trans1
126
+ trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
127
+
128
+ # Solve for trans2
129
+
130
+ # manually reflect the xy data across the Y-axis
131
+ xyR = xy.copy()  # work on a copy so the caller's xy is not modified by the reflection below
132
+ xyR[:, 0] = -1 * xyR[:, 0]
133
+
134
+ trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
135
+
136
+ # manually reflect the tform to undo the reflection done on xyR
137
+ TreflectY = np.array([
138
+ [-1, 0, 0],
139
+ [0, 1, 0],
140
+ [0, 0, 1]
141
+ ])
142
+
143
+ trans2 = np.dot(trans2r, TreflectY)
144
+
145
+ # Figure out if trans1 or trans2 is better
146
+ xy1 = tformfwd(trans1, uv)
147
+ norm1 = norm(xy1 - xy)
148
+
149
+ xy2 = tformfwd(trans2, uv)
150
+ norm2 = norm(xy2 - xy)
151
+
152
+ if norm1 <= norm2:
153
+ return trans1, trans1_inv
154
+ else:
155
+ trans2_inv = inv(trans2)
156
+ return trans2, trans2_inv
157
+
158
+
159
+ def get_similarity_transform(src_pts, dst_pts, reflective=True):
160
+ """
161
+ Function:
162
+ ----------
163
+ Find Similarity Transform Matrix 'trans':
164
+ u = src_pts[:, 0]
165
+ v = src_pts[:, 1]
166
+ x = dst_pts[:, 0]
167
+ y = dst_pts[:, 1]
168
+ [x, y, 1] = [u, v, 1] * trans
169
+
170
+ Parameters:
171
+ ----------
172
+ @src_pts: Kx2 np.array
173
+ source points, each row is a pair of coordinates (x, y)
174
+ @dst_pts: Kx2 np.array
175
+ destination points, each row is a pair of transformed
176
+ coordinates (x, y)
177
+ @reflective: True or False
178
+ if True:
179
+ use reflective similarity transform
180
+ else:
181
+ use non-reflective similarity transform
182
+
183
+ Returns:
184
+ ----------
185
+ @trans: 3x3 np.array
186
+ transform matrix from uv to xy
187
+ trans_inv: 3x3 np.array
188
+ inverse of trans, transform matrix from xy to uv
189
+ """
190
+
191
+ if reflective:
192
+ trans, trans_inv = findSimilarity(src_pts, dst_pts)
193
+ else:
194
+ trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
195
+
196
+ return trans, trans_inv
197
+
198
+
199
+ def cvt_tform_mat_for_cv2(trans):
200
+ """
201
+ Function:
202
+ ----------
203
+ Convert Transform Matrix 'trans' into 'cv2_trans' which could be
204
+ directly used by cv2.warpAffine():
205
+ u = src_pts[:, 0]
206
+ v = src_pts[:, 1]
207
+ x = dst_pts[:, 0]
208
+ y = dst_pts[:, 1]
209
+ [x, y].T = cv_trans * [u, v, 1].T
210
+
211
+ Parameters:
212
+ ----------
213
+ @trans: 3x3 np.array
214
+ transform matrix from uv to xy
215
+
216
+ Returns:
217
+ ----------
218
+ @cv2_trans: 2x3 np.array
219
+ transform matrix from src_pts to dst_pts, could be directly used
220
+ for cv2.warpAffine()
221
+ """
222
+ cv2_trans = trans[:, 0:2].T
223
+
224
+ return cv2_trans
225
+
226
+
227
+ def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
228
+ """
229
+ Function:
230
+ ----------
231
+ Find Similarity Transform Matrix 'cv2_trans' which could be
232
+ directly used by cv2.warpAffine():
233
+ u = src_pts[:, 0]
234
+ v = src_pts[:, 1]
235
+ x = dst_pts[:, 0]
236
+ y = dst_pts[:, 1]
237
+ [x, y].T = cv_trans * [u, v, 1].T
238
+
239
+ Parameters:
240
+ ----------
241
+ @src_pts: Kx2 np.array
242
+ source points, each row is a pair of coordinates (x, y)
243
+ @dst_pts: Kx2 np.array
244
+ destination points, each row is a pair of transformed
245
+ coordinates (x, y)
246
+ reflective: True or False
247
+ if True:
248
+ use reflective similarity transform
249
+ else:
250
+ use non-reflective similarity transform
251
+
252
+ Returns:
253
+ ----------
254
+ @cv2_trans: 2x3 np.array
255
+ transform matrix from src_pts to dst_pts, could be directly used
256
+ for cv2.warpAffine()
257
+ """
258
+ trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
259
+ cv2_trans = cvt_tform_mat_for_cv2(trans)
260
+
261
+ return cv2_trans
262
+
263
+
264
+ if __name__ == '__main__':
265
+ """
266
+ u = [0, 6, -2]
267
+ v = [0, 3, 5]
268
+ x = [-1, 0, 4]
269
+ y = [-1, -10, 4]
270
+
271
+ # In Matlab, run:
272
+ #
273
+ # uv = [u'; v'];
274
+ # xy = [x'; y'];
275
+ # tform_sim=cp2tform(uv,xy,'similarity');
276
+ #
277
+ # trans = tform_sim.tdata.T
278
+ # ans =
279
+ # -0.0764 -1.6190 0
280
+ # 1.6190 -0.0764 0
281
+ # -3.2156 0.0290 1.0000
282
+ # trans_inv = tform_sim.tdata.Tinv
283
+ # ans =
284
+ #
285
+ # -0.0291 0.6163 0
286
+ # -0.6163 -0.0291 0
287
+ # -0.0756 1.9826 1.0000
288
+ # xy_m=tformfwd(tform_sim, u,v)
289
+ #
290
+ # xy_m =
291
+ #
292
+ # -3.2156 0.0290
293
+ # 1.1833 -9.9143
294
+ # 5.0323 2.8853
295
+ # uv_m=tforminv(tform_sim, x,y)
296
+ #
297
+ # uv_m =
298
+ #
299
+ # 0.5698 1.3953
300
+ # 6.0872 2.2733
301
+ # -2.6570 4.3314
302
+ """
303
+ u = [0, 6, -2]
304
+ v = [0, 3, 5]
305
+ x = [-1, 0, 4]
306
+ y = [-1, -10, 4]
307
+
308
+ uv = np.array((u, v)).T
309
+ xy = np.array((x, y)).T
310
+
311
+ print('\n--->uv:')
312
+ print(uv)
313
+ print('\n--->xy:')
314
+ print(xy)
315
+
316
+ trans, trans_inv = get_similarity_transform(uv, xy)
317
+
318
+ print('\n--->trans matrix:')
319
+ print(trans)
320
+
321
+ print('\n--->trans_inv matrix:')
322
+ print(trans_inv)
323
+
324
+ print('\n---> apply transform to uv')
325
+ print('\nxy_m = uv_augmented * trans')
326
+ uv_aug = np.hstack((
327
+ uv, np.ones((uv.shape[0], 1))
328
+ ))
329
+ xy_m = np.dot(uv_aug, trans)
330
+ print(xy_m)
331
+
332
+ print('\nxy_m = tformfwd(trans, uv)')
333
+ xy_m = tformfwd(trans, uv)
334
+ print(xy_m)
335
+
336
+ print('\n---> apply inverse transform to xy')
337
+ print('\nuv_m = xy_augmented * trans_inv')
338
+ xy_aug = np.hstack((
339
+ xy, np.ones((xy.shape[0], 1))
340
+ ))
341
+ uv_m = np.dot(xy_aug, trans_inv)
342
+ print(uv_m)
343
+
344
+ print('\nuv_m = tformfwd(trans_inv, xy)')
345
+ uv_m = tformfwd(trans_inv, xy)
346
+ print(uv_m)
347
+
348
+ uv_m = tforminv(trans, xy)
349
+ print('\nuv_m = tforminv(trans, xy)')
350
+ print(uv_m)
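The 2x3 matrix returned by get_similarity_transform_for_cv2 plugs directly into cv2.warpAffine; a short sketch with made-up landmark and reference coordinates:

import cv2
import numpy as np
from models.mtcnn.mtcnn_pytorch.src.matlab_cp2tform import get_similarity_transform_for_cv2

src_pts = np.float32([[30, 40], [70, 38], [50, 60], [35, 80], [68, 78]])   # detected landmarks
dst_pts = np.float32([[30.3, 51.7], [65.5, 51.5], [48.0, 71.7],
                      [33.5, 92.4], [62.7, 92.2]])                         # reference layout
tfm = get_similarity_transform_for_cv2(src_pts, dst_pts)   # 2x3 float matrix
img = cv2.imread('face.jpg')                               # hypothetical input image
warped = cv2.warpAffine(img, tfm, (96, 112))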
models/mtcnn/mtcnn_pytorch/src/visualization_utils.py ADDED
@@ -0,0 +1,31 @@
1
+ from PIL import ImageDraw
2
+
3
+
4
+ def show_bboxes(img, bounding_boxes, facial_landmarks=[]):
5
+ """Draw bounding boxes and facial landmarks.
6
+
7
+ Arguments:
8
+ img: an instance of PIL.Image.
9
+ bounding_boxes: a float numpy array of shape [n, 5].
10
+ facial_landmarks: a float numpy array of shape [n, 10].
11
+
12
+ Returns:
13
+ an instance of PIL.Image.
14
+ """
15
+
16
+ img_copy = img.copy()
17
+ draw = ImageDraw.Draw(img_copy)
18
+
19
+ for b in bounding_boxes:
20
+ draw.rectangle([
21
+ (b[0], b[1]), (b[2], b[3])
22
+ ], outline='white')
23
+
24
+ for p in facial_landmarks:
25
+ for i in range(5):
26
+ draw.ellipse([
27
+ (p[i] - 1.0, p[i + 5] - 1.0),
28
+ (p[i] + 1.0, p[i + 5] + 1.0)
29
+ ], outline='blue')
30
+
31
+ return img_copy
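Sketch tying the detector and this helper together (same CUDA and weight-file assumptions as the detector example; the path is a placeholder):

from PIL import Image
from models.mtcnn.mtcnn_pytorch.src.detector import detect_faces
from models.mtcnn.mtcnn_pytorch.src.visualization_utils import show_bboxes

img = Image.open('photo.jpg').convert('RGB')   # hypothetical input image
boxes, landmarks = detect_faces(img)
show_bboxes(img, boxes, landmarks).save('photo_annotated.jpg')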
models/mtcnn/mtcnn_pytorch/src/weights/onet.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:313141c3646bebb73cb8350a2d5fee4c7f044fb96304b46ccc21aeea8b818f83
3
+ size 2345483
models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03e19e5c473932ab38f5a6308fe6210624006994a687e858d1dcda53c66f18cb
3
+ size 41271
models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5660aad67688edc9e8a3dd4e47ed120932835e06a8a711a423252a6f2c747083
3
+ size 604651
models/psp.py ADDED
@@ -0,0 +1,147 @@
1
+ """
2
+ This file defines the core research contribution
3
+ """
4
+ import matplotlib
5
+ matplotlib.use('Agg')
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+ from models.encoders import psp_encoders
11
+ from models.stylegan2.model import Generator
12
+ from configs.paths_config import model_paths
13
+ import torch.nn.functional as F
14
+
15
+ def get_keys(d, name):
16
+ if 'state_dict' in d:
17
+ d = d['state_dict']
18
+ d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
19
+ return d_filt
20
+
21
+
22
+ class pSp(nn.Module):
23
+
24
+ def __init__(self, opts):
25
+ super(pSp, self).__init__()
26
+ self.set_opts(opts)
27
+ # compute number of style inputs based on the output resolution
28
+ self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
29
+ # Define architecture
30
+ self.encoder = self.set_encoder()
31
+ self.decoder = Generator(self.opts.output_size, 512, 8)
32
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
33
+ # Load weights if needed
34
+ self.load_weights()
35
+
36
+ def set_encoder(self):
37
+ if self.opts.encoder_type == 'GradualStyleEncoder':
38
+ encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
39
+ elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
40
+ encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
41
+ elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
42
+ encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
43
+ else:
44
+ raise Exception('{} is not a valid encoder type'.format(self.opts.encoder_type))
45
+ return encoder
46
+
47
+ def load_weights(self):
48
+ if self.opts.checkpoint_path is not None:
49
+ print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
50
+ ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
51
+ self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=False)
52
+ self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=False)
53
+ self.__load_latent_avg(ckpt)
54
+ else:
55
+ print('Loading encoders weights from irse50!')
56
+ encoder_ckpt = torch.load(model_paths['ir_se50'])
57
+ # if input to encoder is not an RGB image, do not load the input layer weights
58
+ if self.opts.label_nc != 0:
59
+ encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
60
+ self.encoder.load_state_dict(encoder_ckpt, strict=False)
61
+ print('Loading decoder weights from pretrained!')
62
+ ckpt = torch.load(self.opts.stylegan_weights)
63
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
64
+ if self.opts.learn_in_w:
65
+ self.__load_latent_avg(ckpt, repeat=1)
66
+ else:
67
+ self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)
68
+ # for video toonification, we load G0' model
69
+ if self.opts.toonify_weights is not None: ##### modified
70
+ ckpt = torch.load(self.opts.toonify_weights)
71
+ self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
72
+ self.opts.toonify_weights = None
73
+
74
+ # x1: image for first-layer feature f.
75
+ # x2: image for style latent code w+. If not specified, x2=x1.
76
+ # inject_latent: for sketch/mask-to-face translation, another latent code to fuse with w+
77
+ # latent_mask: fuse w+ and inject_latent with the mask (1~7 use w+ and 8~18 use inject_latent)
78
+ # use_feature: use f. Otherwise, use the original StyleGAN first-layer constant 4*4 feature
79
+ # first_layer_feature_ind: always=0, means the 1st layer of G accept f
80
+ # use_skip: use skip connection.
81
+ # zero_noise: use zero noises.
82
+ # editing_w: the editing vector v for video face editing
83
+ def forward(self, x1, x2=None, resize=True, latent_mask=None, randomize_noise=True,
84
+ inject_latent=None, return_latents=False, alpha=None, use_feature=True,
85
+ first_layer_feature_ind=0, use_skip=False, zero_noise=False, editing_w=None): ##### modified
86
+
87
+ feats = None # f and the skipped encoder features
88
+ codes, feats = self.encoder(x1, return_feat=True, return_full=use_skip) ##### modified
89
+ if x2 is not None: ##### modified
90
+ codes = self.encoder(x2) ##### modified
91
+ # normalize with respect to the center of an average face
92
+ if self.opts.start_from_latent_avg:
93
+ if self.opts.learn_in_w:
94
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
95
+ else:
96
+ codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
97
+
98
+ # E_W^{1:7}(T(x1)) concatenate E_W^{8:18}(w~)
99
+ if latent_mask is not None:
100
+ for i in latent_mask:
101
+ if inject_latent is not None:
102
+ if alpha is not None:
103
+ codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
104
+ else:
105
+ codes[:, i] = inject_latent[:, i]
106
+ else:
107
+ codes[:, i] = 0
108
+
109
+ first_layer_feats, skip_layer_feats, fusion = None, None, None ##### modified
110
+ if use_feature: ##### modified
111
+ first_layer_feats = feats[0:2] # use f
112
+ if use_skip: ##### modified
113
+ skip_layer_feats = feats[2:] # use skipped encoder feature
114
+ fusion = self.encoder.fusion # use fusion layer to fuse encoder feature and decoder feature.
115
+
116
+ images, result_latent = self.decoder([codes],
117
+ input_is_latent=True,
118
+ randomize_noise=randomize_noise,
119
+ return_latents=return_latents,
120
+ first_layer_feature=first_layer_feats,
121
+ first_layer_feature_ind=first_layer_feature_ind,
122
+ skip_layer_feature=skip_layer_feats,
123
+ fusion_block=fusion,
124
+ zero_noise=zero_noise,
125
+ editing_w=editing_w) ##### modified
126
+
127
+ if resize:
128
+ if self.opts.output_size == 1024: ##### modified
129
+ images = F.adaptive_avg_pool2d(images, (images.shape[2]//4, images.shape[3]//4)) ##### modified
130
+ else:
131
+ images = self.face_pool(images)
132
+
133
+ if return_latents:
134
+ return images, result_latent
135
+ else:
136
+ return images
137
+
138
+ def set_opts(self, opts):
139
+ self.opts = opts
140
+
141
+ def __load_latent_avg(self, ckpt, repeat=None):
142
+ if 'latent_avg' in ckpt:
143
+ self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
144
+ if repeat is not None:
145
+ self.latent_avg = self.latent_avg.repeat(repeat, 1)
146
+ else:
147
+ self.latent_avg = None
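A rough sketch of restoring pSp from a checkpoint. The full option set comes from the repo's training options; the snippet assumes, as is common for pSp checkpoints, that the training opts were stored in the checkpoint under 'opts', and the checkpoint path is a placeholder.

from argparse import Namespace
import torch
from models.psp import pSp

ckpt_path = 'pretrained_models/model.pt'                        # hypothetical checkpoint
opts = dict(torch.load(ckpt_path, map_location='cpu')['opts'])  # assumes opts were saved with the model
opts.update(checkpoint_path=ckpt_path, device='cuda')
net = pSp(Namespace(**opts)).to('cuda').eval()

x = torch.randn(1, 3, 256, 256, device='cuda')                  # stand-in for an aligned face in [-1, 1]
with torch.no_grad():
    out, latents = net(x, resize=True, return_latents=True, randomize_noise=False)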
models/stylegan2/__init__.py ADDED
File without changes
models/stylegan2/lpips/__init__.py ADDED
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import numpy as np
7
+ #from skimage.measure import compare_ssim
8
+ from skimage.metrics import structural_similarity as compare_ssim
9
+ import torch
10
+ from torch.autograd import Variable
11
+
12
+ from models.stylegan2.lpips import dist_model
13
+
14
+ class PerceptualLoss(torch.nn.Module):
15
+ def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
16
+ # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
17
+ super(PerceptualLoss, self).__init__()
18
+ print('Setting up Perceptual loss...')
19
+ self.use_gpu = use_gpu
20
+ self.spatial = spatial
21
+ self.gpu_ids = gpu_ids
22
+ self.model = dist_model.DistModel()
23
+ self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
24
+ print('...[%s] initialized'%self.model.name())
25
+ print('...Done')
26
+
27
+ def forward(self, pred, target, normalize=False):
28
+ """
29
+ Pred and target are Variables.
30
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
31
+ If normalize is False, assumes the images are already between [-1,+1]
32
+
33
+ Inputs pred and target are Nx3xHxW
34
+ Output pytorch Variable N long
35
+ """
36
+
37
+ if normalize:
38
+ target = 2 * target - 1
39
+ pred = 2 * pred - 1
40
+
41
+ return self.model.forward(target, pred)
42
+
43
+ def normalize_tensor(in_feat,eps=1e-10):
44
+ norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
45
+ return in_feat/(norm_factor+eps)
46
+
47
+ def l2(p0, p1, range=255.):
48
+ return .5*np.mean((p0 / range - p1 / range)**2)
49
+
50
+ def psnr(p0, p1, peak=255.):
51
+ return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
52
+
53
+ def dssim(p0, p1, range=255.):
54
+ return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
55
+
56
+ def rgb2lab(in_img,mean_cent=False):
57
+ from skimage import color
58
+ img_lab = color.rgb2lab(in_img)
59
+ if(mean_cent):
60
+ img_lab[:,:,0] = img_lab[:,:,0]-50
61
+ return img_lab
62
+
63
+ def tensor2np(tensor_obj):
64
+ # change dimension of a tensor object into a numpy array
65
+ return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
66
+
67
+ def np2tensor(np_obj):
68
+ # change dimension of a numpy array into a tensor
69
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
70
+
71
+ def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
72
+ # image tensor to lab tensor
73
+ from skimage import color
74
+
75
+ img = tensor2im(image_tensor)
76
+ img_lab = color.rgb2lab(img)
77
+ if(mc_only):
78
+ img_lab[:,:,0] = img_lab[:,:,0]-50
79
+ if(to_norm and not mc_only):
80
+ img_lab[:,:,0] = img_lab[:,:,0]-50
81
+ img_lab = img_lab/100.
82
+
83
+ return np2tensor(img_lab)
84
+
85
+ def tensorlab2tensor(lab_tensor,return_inbnd=False):
86
+ from skimage import color
87
+ import warnings
88
+ warnings.filterwarnings("ignore")
89
+
90
+ lab = tensor2np(lab_tensor)*100.
91
+ lab[:,:,0] = lab[:,:,0]+50
92
+
93
+ rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
94
+ if(return_inbnd):
95
+ # convert back to lab, see if we match
96
+ lab_back = color.rgb2lab(rgb_back.astype('uint8'))
97
+ mask = 1.*np.isclose(lab_back,lab,atol=2.)
98
+ mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
99
+ return (im2tensor(rgb_back),mask)
100
+ else:
101
+ return im2tensor(rgb_back)
102
+
103
+ def rgb2lab(input):
104
+ from skimage import color
105
+ return color.rgb2lab(input / 255.)
106
+
107
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
108
+ image_numpy = image_tensor[0].cpu().float().numpy()
109
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
110
+ return image_numpy.astype(imtype)
111
+
112
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
113
+ return torch.Tensor((image / factor - cent)
114
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
115
+
116
+ def tensor2vec(vector_tensor):
117
+ return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
118
+
119
+ def voc_ap(rec, prec, use_07_metric=False):
120
+ """ ap = voc_ap(rec, prec, [use_07_metric])
121
+ Compute VOC AP given precision and recall.
122
+ If use_07_metric is true, uses the
123
+ VOC 07 11 point method (default:False).
124
+ """
125
+ if use_07_metric:
126
+ # 11 point metric
127
+ ap = 0.
128
+ for t in np.arange(0., 1.1, 0.1):
129
+ if np.sum(rec >= t) == 0:
130
+ p = 0
131
+ else:
132
+ p = np.max(prec[rec >= t])
133
+ ap = ap + p / 11.
134
+ else:
135
+ # correct AP calculation
136
+ # first append sentinel values at the end
137
+ mrec = np.concatenate(([0.], rec, [1.]))
138
+ mpre = np.concatenate(([0.], prec, [0.]))
139
+
140
+ # compute the precision envelope
141
+ for i in range(mpre.size - 1, 0, -1):
142
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
143
+
144
+ # to calculate area under PR curve, look for points
145
+ # where X axis (recall) changes value
146
+ i = np.where(mrec[1:] != mrec[:-1])[0]
147
+
148
+ # and sum (\Delta recall) * prec
149
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
150
+ return ap
151
+
152
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
153
+ # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
154
+ image_numpy = image_tensor[0].cpu().float().numpy()
155
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
156
+ return image_numpy.astype(imtype)
157
+
158
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
159
+ # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
160
+ return torch.Tensor((image / factor - cent)
161
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
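Minimal sketch of the LPIPS wrapper defined above: inputs are image batches already scaled to [-1, 1] (or pass normalize=True for [0, 1] inputs); the bundled weights under models/stylegan2/lpips/weights/ are used by default.

import torch
from models.stylegan2 import lpips

use_gpu = torch.cuda.is_available()
percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=use_gpu)
a = torch.rand(4, 3, 256, 256) * 2 - 1
b = torch.rand(4, 3, 256, 256) * 2 - 1
if use_gpu:
    a, b = a.cuda(), b.cuda()
d = percept(a, b)            # one LPIPS distance per image in the batch
print(d.view(-1))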
models/stylegan2/lpips/base_model.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch.autograd import Variable
5
+ from pdb import set_trace as st
6
+ from IPython import embed
7
+
8
+ class BaseModel():
9
+ def __init__(self):
10
+ pass
11
+
12
+ def name(self):
13
+ return 'BaseModel'
14
+
15
+ def initialize(self, use_gpu=True, gpu_ids=[0]):
16
+ self.use_gpu = use_gpu
17
+ self.gpu_ids = gpu_ids
18
+
19
+ def forward(self):
20
+ pass
21
+
22
+ def get_image_paths(self):
23
+ pass
24
+
25
+ def optimize_parameters(self):
26
+ pass
27
+
28
+ def get_current_visuals(self):
29
+ return self.input
30
+
31
+ def get_current_errors(self):
32
+ return {}
33
+
34
+ def save(self, label):
35
+ pass
36
+
37
+ # helper saving function that can be used by subclasses
38
+ def save_network(self, network, path, network_label, epoch_label):
39
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
40
+ save_path = os.path.join(path, save_filename)
41
+ torch.save(network.state_dict(), save_path)
42
+
43
+ # helper loading function that can be used by subclasses
44
+ def load_network(self, network, network_label, epoch_label):
45
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
46
+ save_path = os.path.join(self.save_dir, save_filename)
47
+ print('Loading network from %s'%save_path)
48
+ network.load_state_dict(torch.load(save_path))
49
+
50
+ def update_learning_rate():
51
+ pass
52
+
53
+ def get_image_paths(self):
54
+ return self.image_paths
55
+
56
+ def save_done(self, flag=False):
57
+ np.save(os.path.join(self.save_dir, 'done_flag'),flag)
58
+ np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
models/stylegan2/lpips/dist_model.py ADDED
@@ -0,0 +1,284 @@
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ import os
9
+ from collections import OrderedDict
10
+ from torch.autograd import Variable
11
+ import itertools
12
+ from models.stylegan2.lpips.base_model import BaseModel
13
+ from scipy.ndimage import zoom
14
+ import fractions
15
+ import functools
16
+ import skimage.transform
17
+ from tqdm import tqdm
18
+
19
+ from IPython import embed
20
+
21
+ from models.stylegan2.lpips import networks_basic as networks
22
+ import models.stylegan2.lpips as util
23
+
24
+ class DistModel(BaseModel):
25
+ def name(self):
26
+ return self.model_name
27
+
28
+ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
29
+ use_gpu=True, printNet=False, spatial=False,
30
+ is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
31
+ '''
32
+ INPUTS
33
+ model - ['net-lin'] for linearly calibrated network
34
+ ['net'] for off-the-shelf network
35
+ ['L2'] for L2 distance in Lab colorspace
36
+ ['SSIM'] for ssim in RGB colorspace
37
+ net - ['squeeze','alex','vgg']
38
+ model_path - if None, will look in weights/[NET_NAME].pth
39
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
40
+ use_gpu - bool - whether or not to use a GPU
41
+ printNet - bool - whether or not to print network architecture out
42
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
43
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
44
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
45
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
46
+ is_train - bool - [True] for training mode
47
+ lr - float - initial learning rate
48
+ beta1 - float - initial momentum term for adam
49
+ version - 0.1 for latest, 0.0 was original (with a bug)
50
+ gpu_ids - int array - [0] by default, gpus to use
51
+ '''
52
+ BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
53
+
54
+ self.model = model
55
+ self.net = net
56
+ self.is_train = is_train
57
+ self.spatial = spatial
58
+ self.gpu_ids = gpu_ids
59
+ self.model_name = '%s [%s]'%(model,net)
60
+
61
+ if(self.model == 'net-lin'): # pretrained net + linear layer
62
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
63
+ use_dropout=True, spatial=spatial, version=version, lpips=True)
64
+ kw = {}
65
+ if not use_gpu:
66
+ kw['map_location'] = 'cpu'
67
+ if(model_path is None):
68
+ import inspect
69
+ model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
70
+
71
+ if(not is_train):
72
+ print('Loading model from: %s'%model_path)
73
+ self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
74
+
75
+ elif(self.model=='net'): # pretrained network
76
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
77
+ elif(self.model in ['L2','l2']):
78
+ self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
79
+ self.model_name = 'L2'
80
+ elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
81
+ self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
82
+ self.model_name = 'SSIM'
83
+ else:
84
+ raise ValueError("Model [%s] not recognized." % self.model)
85
+
86
+ self.parameters = list(self.net.parameters())
87
+
88
+ if self.is_train: # training mode
89
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
90
+ self.rankLoss = networks.BCERankingLoss()
91
+ self.parameters += list(self.rankLoss.net.parameters())
92
+ self.lr = lr
93
+ self.old_lr = lr
94
+ self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
95
+ else: # test mode
96
+ self.net.eval()
97
+
98
+ if(use_gpu):
99
+ self.net.to(gpu_ids[0])
100
+ self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
101
+ if(self.is_train):
102
+ self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
103
+
104
+ if(printNet):
105
+ print('---------- Networks initialized -------------')
106
+ networks.print_network(self.net)
107
+ print('-----------------------------------------------')
108
+
109
+ def forward(self, in0, in1, retPerLayer=False):
110
+ ''' Function computes the distance between image patches in0 and in1
111
+ INPUTS
112
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
113
+ OUTPUT
114
+ computed distances between in0 and in1
115
+ '''
116
+
117
+ return self.net.forward(in0, in1, retPerLayer=retPerLayer)
118
+
119
+ # ***** TRAINING FUNCTIONS *****
120
+ def optimize_parameters(self):
121
+ self.forward_train()
122
+ self.optimizer_net.zero_grad()
123
+ self.backward_train()
124
+ self.optimizer_net.step()
125
+ self.clamp_weights()
126
+
127
+ def clamp_weights(self):
128
+ for module in self.net.modules():
129
+ if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
130
+ module.weight.data = torch.clamp(module.weight.data,min=0)
131
+
132
+ def set_input(self, data):
133
+ self.input_ref = data['ref']
134
+ self.input_p0 = data['p0']
135
+ self.input_p1 = data['p1']
136
+ self.input_judge = data['judge']
137
+
138
+ if(self.use_gpu):
139
+ self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
140
+ self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
141
+ self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
142
+ self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
143
+
144
+ self.var_ref = Variable(self.input_ref,requires_grad=True)
145
+ self.var_p0 = Variable(self.input_p0,requires_grad=True)
146
+ self.var_p1 = Variable(self.input_p1,requires_grad=True)
147
+
148
+ def forward_train(self): # run forward pass
149
+ # print(self.net.module.scaling_layer.shift)
150
+ # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
151
+
152
+ self.d0 = self.forward(self.var_ref, self.var_p0)
153
+ self.d1 = self.forward(self.var_ref, self.var_p1)
154
+ self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
155
+
156
+ self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
157
+
158
+ self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
159
+
160
+ return self.loss_total
161
+
162
+ def backward_train(self):
163
+ torch.mean(self.loss_total).backward()
164
+
165
+ def compute_accuracy(self,d0,d1,judge):
166
+ ''' d0, d1 are Variables, judge is a Tensor '''
167
+ d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
168
+ judge_per = judge.cpu().numpy().flatten()
169
+ return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
170
+
171
+ def get_current_errors(self):
172
+ retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
173
+ ('acc_r', self.acc_r)])
174
+
175
+ for key in retDict.keys():
176
+ retDict[key] = np.mean(retDict[key])
177
+
178
+ return retDict
179
+
180
+ def get_current_visuals(self):
181
+ zoom_factor = 256/self.var_ref.data.size()[2]
182
+
183
+ ref_img = util.tensor2im(self.var_ref.data)
184
+ p0_img = util.tensor2im(self.var_p0.data)
185
+ p1_img = util.tensor2im(self.var_p1.data)
186
+
187
+ ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
188
+ p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
189
+ p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
190
+
191
+ return OrderedDict([('ref', ref_img_vis),
192
+ ('p0', p0_img_vis),
193
+ ('p1', p1_img_vis)])
194
+
195
+ def save(self, path, label):
196
+ if(self.use_gpu):
197
+ self.save_network(self.net.module, path, '', label)
198
+ else:
199
+ self.save_network(self.net, path, '', label)
200
+ self.save_network(self.rankLoss.net, path, 'rank', label)
201
+
202
+ def update_learning_rate(self,nepoch_decay):
203
+ lrd = self.lr / nepoch_decay
204
+ lr = self.old_lr - lrd
205
+
206
+ for param_group in self.optimizer_net.param_groups:
207
+ param_group['lr'] = lr
208
+
209
+ print('update lr decay: %f -> %f' % (self.old_lr, lr))
210
+ self.old_lr = lr
211
+
212
+ def score_2afc_dataset(data_loader, func, name=''):
213
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
214
+ distance function 'func' in dataset 'data_loader'
215
+ INPUTS
216
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
217
+ func - callable distance function - calling d=func(in0,in1) should take 2
218
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
219
+ OUTPUTS
220
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
221
+ [1] - dictionary with following elements
222
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
223
+ gts - N array in [0,1], preferred patch selected by human evaluators
224
+ (closer to "0" for left patch p0, "1" for right patch p1,
225
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
226
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
227
+ CONSTS
228
+ N - number of test triplets in data_loader
229
+ '''
230
+
231
+ d0s = []
232
+ d1s = []
233
+ gts = []
234
+
235
+ for data in tqdm(data_loader.load_data(), desc=name):
236
+ d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
237
+ d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
238
+ gts+=data['judge'].cpu().numpy().flatten().tolist()
239
+
240
+ d0s = np.array(d0s)
241
+ d1s = np.array(d1s)
242
+ gts = np.array(gts)
243
+ scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
244
+
245
+ return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
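A quick numerical illustration of the scoring rule above (a sketch with made-up values, not part of the uploaded file): if the metric reports d0 < d1 while 60% of raters preferred p1 (gts = 0.6), the triplet scores 0.4, i.e. the metric agrees with the 40% of raters who chose p0.

    import numpy as np
    d0s, d1s, gts = np.array([0.2]), np.array([0.5]), np.array([0.6])
    scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
    print(scores)  # [0.4]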
246
+
247
+ def score_jnd_dataset(data_loader, func, name=''):
248
+ ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
249
+ INPUTS
250
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
251
+ func - callable distance function - calling d=func(in0,in1) should take 2
252
+ pytorch tensors with shape Nx3xXxY, and return a pytorch tensor of length N
253
+ OUTPUTS
254
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
255
+ [1] - dictionary with following elements
256
+ ds - N array containing distances between two patches shown to human evaluator
257
+ sames - N array containing fraction of people who thought the two patches were identical
258
+ CONSTS
259
+ N - number of test triplets in data_loader
260
+ '''
261
+
262
+ ds = []
263
+ gts = []
264
+
265
+ for data in tqdm(data_loader.load_data(), desc=name):
266
+ ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
267
+ gts+=data['same'].cpu().numpy().flatten().tolist()
268
+
269
+ sames = np.array(gts)
270
+ ds = np.array(ds)
271
+
272
+ sorted_inds = np.argsort(ds)
273
+ ds_sorted = ds[sorted_inds]
274
+ sames_sorted = sames[sorted_inds]
275
+
276
+ TPs = np.cumsum(sames_sorted)
277
+ FPs = np.cumsum(1-sames_sorted)
278
+ FNs = np.sum(sames_sorted)-TPs
279
+
280
+ precs = TPs/(TPs+FPs)
281
+ recs = TPs/(TPs+FNs)
282
+ score = util.voc_ap(recs,precs)
283
+
284
+ return(score, dict(ds=ds,sames=sames))
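A toy walk-through of the precision-recall bookkeeping above (a sketch with made-up values, not part of the uploaded file):

    import numpy as np
    ds = np.array([0.9, 0.1, 0.4])       # metric distances for three pairs
    sames = np.array([0.0, 1.0, 0.5])    # fraction of people who judged each pair identical
    order = np.argsort(ds)               # a small distance should predict "same"
    sames_sorted = sames[order]          # [1.0, 0.5, 0.0]
    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1 - sames_sorted)
    FNs = np.sum(sames_sorted) - TPs
    precs, recs = TPs / (TPs + FPs), TPs / (TPs + FNs)   # passed to util.voc_ap(recs, precs)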
models/stylegan2/lpips/networks_basic.py ADDED
@@ -0,0 +1,187 @@
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.init as init
8
+ from torch.autograd import Variable
9
+ import numpy as np
10
+ from pdb import set_trace as st
11
+ from skimage import color
12
+ from IPython import embed
13
+ from models.stylegan2.lpips import pretrained_networks as pn
14
+
15
+ import models.stylegan2.lpips as util
16
+
17
+ def spatial_average(in_tens, keepdim=True):
18
+ return in_tens.mean([2,3],keepdim=keepdim)
19
+
20
+ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
21
+ in_H = in_tens.shape[2]
22
+ scale_factor = 1.*out_H/in_H
23
+
24
+ return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
25
+
26
+ # Learned perceptual metric
27
+ class PNetLin(nn.Module):
28
+ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
29
+ super(PNetLin, self).__init__()
30
+
31
+ self.pnet_type = pnet_type
32
+ self.pnet_tune = pnet_tune
33
+ self.pnet_rand = pnet_rand
34
+ self.spatial = spatial
35
+ self.lpips = lpips
36
+ self.version = version
37
+ self.scaling_layer = ScalingLayer()
38
+
39
+ if(self.pnet_type in ['vgg','vgg16']):
40
+ net_type = pn.vgg16
41
+ self.chns = [64,128,256,512,512]
42
+ elif(self.pnet_type=='alex'):
43
+ net_type = pn.alexnet
44
+ self.chns = [64,192,384,256,256]
45
+ elif(self.pnet_type=='squeeze'):
46
+ net_type = pn.squeezenet
47
+ self.chns = [64,128,256,384,384,512,512]
48
+ self.L = len(self.chns)
49
+
50
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
51
+
52
+ if(lpips):
53
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
54
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
55
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
56
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
57
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
58
+ self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
59
+ if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
60
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
61
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
62
+ self.lins+=[self.lin5,self.lin6]
63
+
64
+ def forward(self, in0, in1, retPerLayer=False):
65
+ # v0.0 - original release had a bug, where input was not scaled
66
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
67
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
68
+ feats0, feats1, diffs = {}, {}, {}
69
+
70
+ for kk in range(self.L):
71
+ feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
72
+ diffs[kk] = (feats0[kk]-feats1[kk])**2
73
+
74
+ if(self.lpips):
75
+ if(self.spatial):
76
+ res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
77
+ else:
78
+ res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
79
+ else:
80
+ if(self.spatial):
81
+ res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
82
+ else:
83
+ res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
84
+
85
+ val = res[0]
86
+ for l in range(1,self.L):
87
+ val += res[l]
88
+
89
+ if(retPerLayer):
90
+ return (val, res)
91
+ else:
92
+ return val
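A minimal usage sketch for this class (an editorial illustration, not part of the uploaded file; in practice the 1x1 linear layers are loaded from the checkpoints under lpips/weights/ rather than left at their random initialization, and the pretrained VGG features come from torchvision):

    import torch
    net = PNetLin(pnet_type='vgg', use_dropout=True, lpips=True)
    x0 = torch.rand(1, 3, 64, 64) * 2 - 1
    x1 = torch.rand(1, 3, 64, 64) * 2 - 1
    dist = net(x0, x1)   # shape (1, 1, 1, 1): one scalar distance per image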
93
+
94
+ class ScalingLayer(nn.Module):
95
+ def __init__(self):
96
+ super(ScalingLayer, self).__init__()
97
+ self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
98
+ self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
99
+
100
+ def forward(self, inp):
101
+ return (inp - self.shift) / self.scale
102
+
103
+
104
+ class NetLinLayer(nn.Module):
105
+ ''' A single linear layer which does a 1x1 conv '''
106
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
107
+ super(NetLinLayer, self).__init__()
108
+
109
+ layers = [nn.Dropout(),] if(use_dropout) else []
110
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
111
+ self.model = nn.Sequential(*layers)
112
+
113
+
114
+ class Dist2LogitLayer(nn.Module):
115
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
116
+ def __init__(self, chn_mid=32, use_sigmoid=True):
117
+ super(Dist2LogitLayer, self).__init__()
118
+
119
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
120
+ layers += [nn.LeakyReLU(0.2,True),]
121
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
122
+ layers += [nn.LeakyReLU(0.2,True),]
123
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
124
+ if(use_sigmoid):
125
+ layers += [nn.Sigmoid(),]
126
+ self.model = nn.Sequential(*layers)
127
+
128
+ def forward(self,d0,d1,eps=0.1):
129
+ return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
130
+
131
+ class BCERankingLoss(nn.Module):
132
+ def __init__(self, chn_mid=32):
133
+ super(BCERankingLoss, self).__init__()
134
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
135
+ # self.parameters = list(self.net.parameters())
136
+ self.loss = torch.nn.BCELoss()
137
+
138
+ def forward(self, d0, d1, judge):
139
+ per = (judge+1.)/2.
140
+ self.logit = self.net.forward(d0,d1)
141
+ return self.loss(self.logit, per)
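A small sketch of how the ranking loss above is driven (an editorial illustration, not part of the uploaded file); `judge` uses the [-1, 1] convention produced by the trainer's `var_judge*2.-1.`, and the toy shapes are assumptions:

    import torch
    rank_loss = BCERankingLoss()
    d0 = torch.rand(4, 1, 1, 1)       # distance(ref, p0)
    d1 = torch.rand(4, 1, 1, 1)       # distance(ref, p1)
    judge = torch.ones(4, 1, 1, 1)    # +1: raters preferred p1
    loss = rank_loss(d0, d1, judge)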
142
+
143
+ # L2, DSSIM metrics
144
+ class FakeNet(nn.Module):
145
+ def __init__(self, use_gpu=True, colorspace='Lab'):
146
+ super(FakeNet, self).__init__()
147
+ self.use_gpu = use_gpu
148
+ self.colorspace=colorspace
149
+
150
+ class L2(FakeNet):
151
+
152
+ def forward(self, in0, in1, retPerLayer=None):
153
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
154
+
155
+ if(self.colorspace=='RGB'):
156
+ (N,C,X,Y) = in0.size()
157
+ value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
158
+ return value
159
+ elif(self.colorspace=='Lab'):
160
+ value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
161
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
162
+ ret_var = Variable( torch.Tensor((value,) ) )
163
+ if(self.use_gpu):
164
+ ret_var = ret_var.cuda()
165
+ return ret_var
166
+
167
+ class DSSIM(FakeNet):
168
+
169
+ def forward(self, in0, in1, retPerLayer=None):
170
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
171
+
172
+ if(self.colorspace=='RGB'):
173
+ value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
174
+ elif(self.colorspace=='Lab'):
175
+ value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
176
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
177
+ ret_var = Variable( torch.Tensor((value,) ) )
178
+ if(self.use_gpu):
179
+ ret_var = ret_var.cuda()
180
+ return ret_var
181
+
182
+ def print_network(net):
183
+ num_params = 0
184
+ for param in net.parameters():
185
+ num_params += param.numel()
186
+ print('Network',net)
187
+ print('Total number of parameters: %d' % num_params)
models/stylegan2/lpips/pretrained_networks.py ADDED
@@ -0,0 +1,181 @@
1
+ from collections import namedtuple
2
+ import torch
3
+ from torchvision import models as tv
4
+ from IPython import embed
5
+
6
+ class squeezenet(torch.nn.Module):
7
+ def __init__(self, requires_grad=False, pretrained=True):
8
+ super(squeezenet, self).__init__()
9
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10
+ self.slice1 = torch.nn.Sequential()
11
+ self.slice2 = torch.nn.Sequential()
12
+ self.slice3 = torch.nn.Sequential()
13
+ self.slice4 = torch.nn.Sequential()
14
+ self.slice5 = torch.nn.Sequential()
15
+ self.slice6 = torch.nn.Sequential()
16
+ self.slice7 = torch.nn.Sequential()
17
+ self.N_slices = 7
18
+ for x in range(2):
19
+ self.slice1.add_module(str(x), pretrained_features[x])
20
+ for x in range(2,5):
21
+ self.slice2.add_module(str(x), pretrained_features[x])
22
+ for x in range(5, 8):
23
+ self.slice3.add_module(str(x), pretrained_features[x])
24
+ for x in range(8, 10):
25
+ self.slice4.add_module(str(x), pretrained_features[x])
26
+ for x in range(10, 11):
27
+ self.slice5.add_module(str(x), pretrained_features[x])
28
+ for x in range(11, 12):
29
+ self.slice6.add_module(str(x), pretrained_features[x])
30
+ for x in range(12, 13):
31
+ self.slice7.add_module(str(x), pretrained_features[x])
32
+ if not requires_grad:
33
+ for param in self.parameters():
34
+ param.requires_grad = False
35
+
36
+ def forward(self, X):
37
+ h = self.slice1(X)
38
+ h_relu1 = h
39
+ h = self.slice2(h)
40
+ h_relu2 = h
41
+ h = self.slice3(h)
42
+ h_relu3 = h
43
+ h = self.slice4(h)
44
+ h_relu4 = h
45
+ h = self.slice5(h)
46
+ h_relu5 = h
47
+ h = self.slice6(h)
48
+ h_relu6 = h
49
+ h = self.slice7(h)
50
+ h_relu7 = h
51
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52
+ out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53
+
54
+ return out
55
+
56
+
57
+ class alexnet(torch.nn.Module):
58
+ def __init__(self, requires_grad=False, pretrained=True):
59
+ super(alexnet, self).__init__()
60
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61
+ self.slice1 = torch.nn.Sequential()
62
+ self.slice2 = torch.nn.Sequential()
63
+ self.slice3 = torch.nn.Sequential()
64
+ self.slice4 = torch.nn.Sequential()
65
+ self.slice5 = torch.nn.Sequential()
66
+ self.N_slices = 5
67
+ for x in range(2):
68
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69
+ for x in range(2, 5):
70
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71
+ for x in range(5, 8):
72
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73
+ for x in range(8, 10):
74
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75
+ for x in range(10, 12):
76
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77
+ if not requires_grad:
78
+ for param in self.parameters():
79
+ param.requires_grad = False
80
+
81
+ def forward(self, X):
82
+ h = self.slice1(X)
83
+ h_relu1 = h
84
+ h = self.slice2(h)
85
+ h_relu2 = h
86
+ h = self.slice3(h)
87
+ h_relu3 = h
88
+ h = self.slice4(h)
89
+ h_relu4 = h
90
+ h = self.slice5(h)
91
+ h_relu5 = h
92
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94
+
95
+ return out
96
+
97
+ class vgg16(torch.nn.Module):
98
+ def __init__(self, requires_grad=False, pretrained=True):
99
+ super(vgg16, self).__init__()
100
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101
+ self.slice1 = torch.nn.Sequential()
102
+ self.slice2 = torch.nn.Sequential()
103
+ self.slice3 = torch.nn.Sequential()
104
+ self.slice4 = torch.nn.Sequential()
105
+ self.slice5 = torch.nn.Sequential()
106
+ self.N_slices = 5
107
+ for x in range(4):
108
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
109
+ for x in range(4, 9):
110
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(9, 16):
112
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(16, 23):
114
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(23, 30):
116
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
117
+ if not requires_grad:
118
+ for param in self.parameters():
119
+ param.requires_grad = False
120
+
121
+ def forward(self, X):
122
+ h = self.slice1(X)
123
+ h_relu1_2 = h
124
+ h = self.slice2(h)
125
+ h_relu2_2 = h
126
+ h = self.slice3(h)
127
+ h_relu3_3 = h
128
+ h = self.slice4(h)
129
+ h_relu4_3 = h
130
+ h = self.slice5(h)
131
+ h_relu5_3 = h
132
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134
+
135
+ return out
136
+
137
+
138
+
139
+ class resnet(torch.nn.Module):
140
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
141
+ super(resnet, self).__init__()
142
+ if(num==18):
143
+ self.net = tv.resnet18(pretrained=pretrained)
144
+ elif(num==34):
145
+ self.net = tv.resnet34(pretrained=pretrained)
146
+ elif(num==50):
147
+ self.net = tv.resnet50(pretrained=pretrained)
148
+ elif(num==101):
149
+ self.net = tv.resnet101(pretrained=pretrained)
150
+ elif(num==152):
151
+ self.net = tv.resnet152(pretrained=pretrained)
152
+ self.N_slices = 5
153
+
154
+ self.conv1 = self.net.conv1
155
+ self.bn1 = self.net.bn1
156
+ self.relu = self.net.relu
157
+ self.maxpool = self.net.maxpool
158
+ self.layer1 = self.net.layer1
159
+ self.layer2 = self.net.layer2
160
+ self.layer3 = self.net.layer3
161
+ self.layer4 = self.net.layer4
162
+
163
+ def forward(self, X):
164
+ h = self.conv1(X)
165
+ h = self.bn1(h)
166
+ h = self.relu(h)
167
+ h_relu1 = h
168
+ h = self.maxpool(h)
169
+ h = self.layer1(h)
170
+ h_conv2 = h
171
+ h = self.layer2(h)
172
+ h_conv3 = h
173
+ h = self.layer3(h)
174
+ h_conv4 = h
175
+ h = self.layer4(h)
176
+ h_conv5 = h
177
+
178
+ outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180
+
181
+ return out
models/stylegan2/lpips/weights/v0.0/alex.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
3
+ size 5455
models/stylegan2/lpips/weights/v0.0/squeeze.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
3
+ size 10057
models/stylegan2/lpips/weights/v0.0/vgg.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
3
+ size 6735
models/stylegan2/lpips/weights/v0.1/alex.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
3
+ size 6009
models/stylegan2/lpips/weights/v0.1/squeeze.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
3
+ size 10811
models/stylegan2/lpips/weights/v0.1/vgg.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
models/stylegan2/model.py ADDED
@@ -0,0 +1,768 @@
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+
8
+ from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
9
+
10
+
11
+ class PixelNorm(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, input):
16
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
17
+
18
+
19
+ def make_kernel(k):
20
+ k = torch.tensor(k, dtype=torch.float32)
21
+
22
+ if k.ndim == 1:
23
+ k = k[None, :] * k[:, None]
24
+
25
+ k /= k.sum()
26
+
27
+ return k
28
+
29
+
30
+ class Upsample(nn.Module):
31
+ def __init__(self, kernel, factor=2):
32
+ super().__init__()
33
+
34
+ self.factor = factor
35
+ kernel = make_kernel(kernel) * (factor ** 2)
36
+ self.register_buffer('kernel', kernel)
37
+
38
+ p = kernel.shape[0] - factor
39
+
40
+ pad0 = (p + 1) // 2 + factor - 1
41
+ pad1 = p // 2
42
+
43
+ self.pad = (pad0, pad1)
44
+
45
+ def forward(self, input):
46
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
47
+
48
+ return out
49
+
50
+
51
+ class Downsample(nn.Module):
52
+ def __init__(self, kernel, factor=2):
53
+ super().__init__()
54
+
55
+ self.factor = factor
56
+ kernel = make_kernel(kernel)
57
+ self.register_buffer('kernel', kernel)
58
+
59
+ p = kernel.shape[0] - factor
60
+
61
+ pad0 = (p + 1) // 2
62
+ pad1 = p // 2
63
+
64
+ self.pad = (pad0, pad1)
65
+
66
+ def forward(self, input):
67
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
68
+
69
+ return out
70
+
71
+
72
+ class Blur(nn.Module):
73
+ def __init__(self, kernel, pad, upsample_factor=1):
74
+ super().__init__()
75
+
76
+ kernel = make_kernel(kernel)
77
+
78
+ if upsample_factor > 1:
79
+ kernel = kernel * (upsample_factor ** 2)
80
+
81
+ self.register_buffer('kernel', kernel)
82
+
83
+ self.pad = pad
84
+
85
+ def forward(self, input):
86
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
87
+
88
+ return out
89
+
90
+
91
+ class EqualConv2d(nn.Module):
92
+ def __init__(
93
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dilation=1 ## modified
94
+ ):
95
+ super().__init__()
96
+
97
+ self.weight = nn.Parameter(
98
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
99
+ )
100
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
101
+
102
+ self.stride = stride
103
+ self.padding = padding
104
+ self.dilation = dilation ## modified
105
+
106
+ if bias:
107
+ self.bias = nn.Parameter(torch.zeros(out_channel))
108
+
109
+ else:
110
+ self.bias = None
111
+
112
+ def forward(self, input):
113
+ out = F.conv2d(
114
+ input,
115
+ self.weight * self.scale,
116
+ bias=self.bias,
117
+ stride=self.stride,
118
+ padding=self.padding,
119
+ dilation=self.dilation, ## modified
120
+ )
121
+
122
+ return out
123
+
124
+ def __repr__(self):
125
+ return (
126
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
127
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding}, dilation={self.dilation})" ## modified
128
+ )
129
+
130
+
131
+ class EqualLinear(nn.Module):
132
+ def __init__(
133
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
134
+ ):
135
+ super().__init__()
136
+
137
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
138
+
139
+ if bias:
140
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
141
+
142
+ else:
143
+ self.bias = None
144
+
145
+ self.activation = activation
146
+
147
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
148
+ self.lr_mul = lr_mul
149
+
150
+ def forward(self, input):
151
+ if self.activation:
152
+ out = F.linear(input, self.weight * self.scale)
153
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
154
+
155
+ else:
156
+ out = F.linear(
157
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
158
+ )
159
+
160
+ return out
161
+
162
+ def __repr__(self):
163
+ return (
164
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
165
+ )
166
+
167
+
168
+ class ScaledLeakyReLU(nn.Module):
169
+ def __init__(self, negative_slope=0.2):
170
+ super().__init__()
171
+
172
+ self.negative_slope = negative_slope
173
+
174
+ def forward(self, input):
175
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
176
+
177
+ return out * math.sqrt(2)
178
+
179
+
180
+ class ModulatedConv2d(nn.Module):
181
+ def __init__(
182
+ self,
183
+ in_channel,
184
+ out_channel,
185
+ kernel_size,
186
+ style_dim,
187
+ demodulate=True,
188
+ upsample=False,
189
+ downsample=False,
190
+ blur_kernel=[1, 3, 3, 1],
191
+ dilation=1, ##### modified
192
+ ):
193
+ super().__init__()
194
+
195
+ self.eps = 1e-8
196
+ self.kernel_size = kernel_size
197
+ self.in_channel = in_channel
198
+ self.out_channel = out_channel
199
+ self.upsample = upsample
200
+ self.downsample = downsample
201
+ self.dilation = dilation ##### modified
202
+
203
+ if upsample:
204
+ factor = 2
205
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
206
+ pad0 = (p + 1) // 2 + factor - 1
207
+ pad1 = p // 2 + 1
208
+
209
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
210
+
211
+ # to simulate transconv + blur
212
+ # we use a dilated transposed conv with the blur kernel as its weight, followed by a dilated transposed conv with the modulated weight
213
+ if dilation > 1: ##### modified
214
+ blur_weight = torch.randn(1, 1, 3, 3) * 0 + 1
215
+ blur_weight[:,:,0,1] = 2
216
+ blur_weight[:,:,1,0] = 2
217
+ blur_weight[:,:,1,2] = 2
218
+ blur_weight[:,:,2,1] = 2
219
+ blur_weight[:,:,1,1] = 4
220
+ blur_weight = blur_weight / 16.0
221
+ self.register_buffer("blur_weight", blur_weight)
222
+
223
+ if downsample:
224
+ factor = 2
225
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
226
+ pad0 = (p + 1) // 2
227
+ pad1 = p // 2
228
+
229
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
230
+
231
+ fan_in = in_channel * kernel_size ** 2
232
+ self.scale = 1 / math.sqrt(fan_in)
233
+ self.padding = kernel_size // 2 + dilation - 1 ##### modified
234
+
235
+ self.weight = nn.Parameter(
236
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
237
+ )
238
+
239
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
240
+
241
+ self.demodulate = demodulate
242
+
243
+ def __repr__(self):
244
+ return (
245
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
246
+ f'upsample={self.upsample}, downsample={self.downsample})'
247
+ )
248
+
249
+ def forward(self, input, style):
250
+ batch, in_channel, height, width = input.shape
251
+
252
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
253
+ weight = self.scale * self.weight * style
254
+
255
+ if self.demodulate:
256
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
257
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
258
+
259
+ weight = weight.view(
260
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
261
+ )
262
+
263
+ if self.upsample:
264
+ input = input.view(1, batch * in_channel, height, width)
265
+ weight = weight.view(
266
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
267
+ )
268
+ weight = weight.transpose(1, 2).reshape(
269
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
270
+ )
271
+
272
+ if self.dilation > 1: ##### modified
273
+ # to simulate out = self.blur(out)
274
+ out = F.conv_transpose2d(
275
+ input, self.blur_weight.repeat(batch*in_channel,1,1,1), padding=0, groups=batch*in_channel, dilation=self.dilation//2)
276
+ # to simulate the standard strided transposed conv of the dilation == 1 branch below
277
+ out = F.conv_transpose2d(
278
+ out, weight, padding=self.dilation, groups=batch, dilation=self.dilation//2)
279
+ _, _, height, width = out.shape
280
+ out = out.view(batch, self.out_channel, height, width)
281
+ return out
282
+
283
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
284
+ _, _, height, width = out.shape
285
+ out = out.view(batch, self.out_channel, height, width)
286
+ out = self.blur(out)
287
+
288
+ elif self.downsample:
289
+ input = self.blur(input)
290
+ _, _, height, width = input.shape
291
+ input = input.view(1, batch * in_channel, height, width)
292
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
293
+ _, _, height, width = out.shape
294
+ out = out.view(batch, self.out_channel, height, width)
295
+
296
+ else:
297
+ input = input.view(1, batch * in_channel, height, width)
298
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch, dilation=self.dilation) ##### modified
299
+ _, _, height, width = out.shape
300
+ out = out.view(batch, self.out_channel, height, width)
301
+
302
+ return out
303
+
304
+
305
+ class NoiseInjection(nn.Module):
306
+ def __init__(self):
307
+ super().__init__()
308
+
309
+ self.weight = nn.Parameter(torch.zeros(1))
310
+
311
+ def forward(self, image, noise=None):
312
+ if noise is None:
313
+ batch, _, height, width = image.shape
314
+ noise = image.new_empty(batch, 1, height, width).normal_()
315
+ else: ##### modified, to make the noise resolution match the feature resolution
316
+ batch, _, height, width = image.shape
317
+ _, _, height1, width1 = noise.shape
318
+ if height != height1 or width != width1:
319
+ noise = F.adaptive_avg_pool2d(noise, (height, width))
320
+
321
+ return image + self.weight * noise
322
+
323
+
324
+ class ConstantInput(nn.Module):
325
+ def __init__(self, channel, size=4):
326
+ super().__init__()
327
+
328
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
329
+
330
+ def forward(self, input):
331
+ batch = input.shape[0]
332
+ out = self.input.repeat(batch, 1, 1, 1)
333
+
334
+ return out
335
+
336
+
337
+ class StyledConv(nn.Module):
338
+ def __init__(
339
+ self,
340
+ in_channel,
341
+ out_channel,
342
+ kernel_size,
343
+ style_dim,
344
+ upsample=False,
345
+ blur_kernel=[1, 3, 3, 1],
346
+ demodulate=True,
347
+ dilation=1, ##### modified
348
+ ):
349
+ super().__init__()
350
+
351
+ self.conv = ModulatedConv2d(
352
+ in_channel,
353
+ out_channel,
354
+ kernel_size,
355
+ style_dim,
356
+ upsample=upsample,
357
+ blur_kernel=blur_kernel,
358
+ demodulate=demodulate,
359
+ dilation=dilation, ##### modified
360
+ )
361
+
362
+ self.noise = NoiseInjection()
363
+ self.activate = FusedLeakyReLU(out_channel)
364
+
365
+ def forward(self, input, style, noise=None):
366
+ out = self.conv(input, style)
367
+ out = self.noise(out, noise=noise)
368
+ out = self.activate(out)
369
+
370
+ return out
371
+
372
+
373
+ class ToRGB(nn.Module):
374
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1], dilation=1): ##### modified
375
+ super().__init__()
376
+
377
+ if upsample:
378
+ self.upsample = Upsample(blur_kernel)
379
+
380
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
381
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
382
+
383
+ self.dilation = dilation ##### modified
384
+ if dilation > 1: ##### modified
385
+ blur_weight = torch.randn(1, 1, 3, 3) * 0 + 1
386
+ blur_weight[:,:,0,1] = 2
387
+ blur_weight[:,:,1,0] = 2
388
+ blur_weight[:,:,1,2] = 2
389
+ blur_weight[:,:,2,1] = 2
390
+ blur_weight[:,:,1,1] = 4
391
+ blur_weight = blur_weight / 16.0
392
+ self.register_buffer("blur_weight", blur_weight)
393
+
394
+ def forward(self, input, style, skip=None):
395
+ out = self.conv(input, style)
396
+ out = out + self.bias
397
+
398
+ if skip is not None:
399
+ if self.dilation == 1:
400
+ skip = self.upsample(skip)
401
+ else: ##### modified, to simulate skip = self.upsample(skip)
402
+ batch, in_channel, _, _ = skip.shape
403
+ skip = F.conv2d(skip, self.blur_weight.repeat(in_channel,1,1,1),
404
+ padding=self.dilation//2, groups=in_channel, dilation=self.dilation//2)
405
+
406
+ out = out + skip
407
+
408
+ return out
409
+
410
+
411
+ class Generator(nn.Module):
412
+ def __init__(
413
+ self,
414
+ size,
415
+ style_dim,
416
+ n_mlp,
417
+ channel_multiplier=2,
418
+ blur_kernel=[1, 3, 3, 1],
419
+ lr_mlp=0.01,
420
+ ):
421
+ super().__init__()
422
+
423
+ self.size = size
424
+
425
+ self.style_dim = style_dim
426
+
427
+ layers = [PixelNorm()]
428
+
429
+ for i in range(n_mlp):
430
+ layers.append(
431
+ EqualLinear(
432
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
433
+ )
434
+ )
435
+
436
+ self.style = nn.Sequential(*layers)
437
+
438
+ self.channels = {
439
+ 4: 512,
440
+ 8: 512,
441
+ 16: 512,
442
+ 32: 512,
443
+ 64: 256 * channel_multiplier,
444
+ 128: 128 * channel_multiplier,
445
+ 256: 64 * channel_multiplier,
446
+ 512: 32 * channel_multiplier,
447
+ 1024: 16 * channel_multiplier,
448
+ }
449
+
450
+ self.input = ConstantInput(self.channels[4])
451
+ self.conv1 = StyledConv(
452
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel, dilation=8 ##### modified
453
+ )
454
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
455
+
456
+ self.log_size = int(math.log(size, 2))
457
+ self.num_layers = (self.log_size - 2) * 2 + 1
458
+
459
+ self.convs = nn.ModuleList()
460
+ self.upsamples = nn.ModuleList()
461
+ self.to_rgbs = nn.ModuleList()
462
+ self.noises = nn.Module()
463
+
464
+ in_channel = self.channels[4]
465
+
466
+ for layer_idx in range(self.num_layers):
467
+ res = (layer_idx + 5) // 2
468
+ shape = [1, 1, 2 ** res, 2 ** res]
469
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
470
+
471
+ for i in range(3, self.log_size + 1):
472
+ out_channel = self.channels[2 ** i]
473
+
474
+ self.convs.append(
475
+ StyledConv(
476
+ in_channel,
477
+ out_channel,
478
+ 3,
479
+ style_dim,
480
+ upsample=True,
481
+ blur_kernel=blur_kernel,
482
+ dilation=max(1, 32 // (2**(i-1))) ##### modified
483
+ )
484
+ )
485
+
486
+ self.convs.append(
487
+ StyledConv(
488
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel, dilation=max(1, 32 // (2**i)) ##### modified
489
+ )
490
+ )
491
+
492
+ self.to_rgbs.append(ToRGB(out_channel, style_dim, dilation=max(1, 32 // (2**(i-1))))) ##### modified
493
+
494
+ in_channel = out_channel
495
+
496
+ self.n_latent = self.log_size * 2 - 2
497
+
498
+ def make_noise(self):
499
+ device = self.input.input.device
500
+
501
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
502
+
503
+ for i in range(3, self.log_size + 1):
504
+ for _ in range(2):
505
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
506
+
507
+ return noises
508
+
509
+ def mean_latent(self, n_latent):
510
+ latent_in = torch.randn(
511
+ n_latent, self.style_dim, device=self.input.input.device
512
+ )
513
+ latent = self.style(latent_in).mean(0, keepdim=True)
514
+
515
+ return latent
516
+
517
+ def get_latent(self, input):
518
+ return self.style(input)
519
+
520
+ # styles is the latent code w+
521
+ # first_layer_feature is the first-layer input feature f
522
+ # first_layer_feature_ind indicates which layer of G accepts f (should always be 0, the first layer)
523
+ # skip_layer_feature is the encoder features sent by skip connection
524
+ # fusion_block is the network to fuse the encoder feature and decoder feature
525
+ # zero_noise is to force the noise to be zero (to avoid flickers for videos)
526
+ # editing_w is the editing vector v used in video face editing
527
+ def forward(
528
+ self,
529
+ styles,
530
+ return_latents=False,
531
+ return_features=False,
532
+ inject_index=None,
533
+ truncation=1,
534
+ truncation_latent=None,
535
+ input_is_latent=False,
536
+ noise=None,
537
+ randomize_noise=True,
538
+ first_layer_feature = None, ##### modified
539
+ first_layer_feature_ind = 0, ##### modified
540
+ skip_layer_feature = None, ##### modified
541
+ fusion_block = None, ##### modified
542
+ zero_noise = False, ##### modified
543
+ editing_w = None, ##### modified
544
+ ):
545
+ if not input_is_latent:
546
+ styles = [self.style(s) for s in styles]
547
+
548
+ if zero_noise:
549
+ noise = [
550
+ getattr(self.noises, f'noise_{i}') * 0.0 for i in range(self.num_layers)
551
+ ]
552
+ elif noise is None:
553
+ if randomize_noise:
554
+ noise = [None] * self.num_layers
555
+ else:
556
+ noise = [
557
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
558
+ ]
559
+
560
+ if truncation < 1:
561
+ style_t = []
562
+
563
+ for style in styles:
564
+ style_t.append(
565
+ truncation_latent + truncation * (style - truncation_latent)
566
+ )
567
+
568
+ styles = style_t
569
+
570
+ if len(styles) < 2:
571
+ inject_index = self.n_latent
572
+
573
+ if styles[0].ndim < 3:
574
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
575
+ else:
576
+ latent = styles[0]
577
+
578
+ else:
579
+ if inject_index is None:
580
+ inject_index = random.randint(1, self.n_latent - 1)
581
+
582
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
583
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
584
+
585
+ latent = torch.cat([latent, latent2], 1)
586
+
587
+ # w+ + v for video face editing
588
+ if editing_w is not None: ##### modified
589
+ latent = latent + editing_w
590
+
591
+ # the original StyleGAN
592
+ if first_layer_feature is None: ##### modified
593
+ out = self.input(latent)
594
+ out = F.adaptive_avg_pool2d(out, 32) ##### modified
595
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
596
+ skip = self.to_rgb1(out, latent[:, 1])
597
+ # the default StyleGANEX, replacing the first layer of G
598
+ elif first_layer_feature_ind == 0: ##### modified
599
+ out = first_layer_feature[0] ##### modified
600
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
601
+ skip = self.to_rgb1(out, latent[:, 1])
602
+ # maybe we can also use the second layer of G to accept f?
603
+ else: ##### modified
604
+ out = first_layer_feature[0] ##### modified
605
+ skip = first_layer_feature[1] ##### modified
606
+
607
+ i = 1
608
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
609
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
610
+ ):
611
+ # these layers accept the skipped encoder features; use the fusion block to fuse the encoder and decoder features
612
+ if skip_layer_feature and fusion_block and i//2 < len(skip_layer_feature) and i//2 < len(fusion_block):
613
+ if editing_w is None:
614
+ out, skip = fusion_block[i//2](skip_layer_feature[i//2], out, skip)
615
+ else:
616
+ out, skip = fusion_block[i//2](skip_layer_feature[i//2], out, skip, editing_w[:,i])
617
+ out = conv1(out, latent[:, i], noise=noise1)
618
+ out = conv2(out, latent[:, i + 1], noise=noise2)
619
+ skip = to_rgb(out, latent[:, i + 2], skip)
620
+
621
+ i += 2
622
+
623
+ image = skip
624
+
625
+ if return_latents:
626
+ return image, latent
627
+ elif return_features:
628
+ return image, out
629
+ else:
630
+ return image, None
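A usage sketch for the generator above (an editorial illustration, not part of the uploaded file; the sizes, batch shapes and the 32x32 feature resolution are assumptions chosen to match the comments before forward()):

    import torch
    g = Generator(size=256, style_dim=512, n_mlp=8)

    # plain StyleGAN-style sampling from z
    z = torch.randn(2, 512)
    img, _ = g([z], zero_noise=True)

    # StyleGANEX-style call: a w+ code plus a first-layer feature f replacing the constant input
    w_plus = torch.randn(1, g.n_latent, 512)
    f = [torch.randn(1, 512, 32, 32)]
    img, _ = g([w_plus], input_is_latent=True, first_layer_feature=f, zero_noise=True)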
631
+
632
+
633
+ class ConvLayer(nn.Sequential):
634
+ def __init__(
635
+ self,
636
+ in_channel,
637
+ out_channel,
638
+ kernel_size,
639
+ downsample=False,
640
+ blur_kernel=[1, 3, 3, 1],
641
+ bias=True,
642
+ activate=True,
643
+ dilation=1, ## modified
644
+ ):
645
+ layers = []
646
+
647
+ if downsample:
648
+ factor = 2
649
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
650
+ pad0 = (p + 1) // 2
651
+ pad1 = p // 2
652
+
653
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
654
+
655
+ stride = 2
656
+ self.padding = 0
657
+
658
+ else:
659
+ stride = 1
660
+ self.padding = kernel_size // 2 + dilation-1 ## modified
661
+
662
+ layers.append(
663
+ EqualConv2d(
664
+ in_channel,
665
+ out_channel,
666
+ kernel_size,
667
+ padding=self.padding,
668
+ stride=stride,
669
+ bias=bias and not activate,
670
+ dilation=dilation, ## modified
671
+ )
672
+ )
673
+
674
+ if activate:
675
+ if bias:
676
+ layers.append(FusedLeakyReLU(out_channel))
677
+
678
+ else:
679
+ layers.append(ScaledLeakyReLU(0.2))
680
+
681
+ super().__init__(*layers)
682
+
683
+
684
+ class ResBlock(nn.Module):
685
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
686
+ super().__init__()
687
+
688
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
689
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
690
+
691
+ self.skip = ConvLayer(
692
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
693
+ )
694
+
695
+ def forward(self, input):
696
+ out = self.conv1(input)
697
+ out = self.conv2(out)
698
+
699
+ skip = self.skip(input)
700
+ out = (out + skip) / math.sqrt(2)
701
+
702
+ return out
703
+
704
+
705
+ class Discriminator(nn.Module):
706
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], img_channel=3):
707
+ super().__init__()
708
+
709
+ channels = {
710
+ 4: 512,
711
+ 8: 512,
712
+ 16: 512,
713
+ 32: 512,
714
+ 64: 256 * channel_multiplier,
715
+ 128: 128 * channel_multiplier,
716
+ 256: 64 * channel_multiplier,
717
+ 512: 32 * channel_multiplier,
718
+ 1024: 16 * channel_multiplier,
719
+ }
720
+
721
+ convs = [ConvLayer(img_channel, channels[size], 1)]
722
+
723
+ log_size = int(math.log(size, 2))
724
+
725
+ in_channel = channels[size]
726
+
727
+ for i in range(log_size, 2, -1):
728
+ out_channel = channels[2 ** (i - 1)]
729
+
730
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
731
+
732
+ in_channel = out_channel
733
+
734
+ self.convs = nn.Sequential(*convs)
735
+
736
+ self.stddev_group = 4
737
+ self.stddev_feat = 1
738
+
739
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
740
+ self.final_linear = nn.Sequential(
741
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
742
+ EqualLinear(channels[4], 1),
743
+ )
744
+
745
+ self.size = size ##### modified
746
+
747
+ def forward(self, input):
748
+ # for inputs larger than the target size, we randomly crop a patch of the target size.
749
+ _, _, h, w = input.shape ##### modified
750
+ i, j = torch.randint(0, h+1-self.size, size=(1,)).item(), torch.randint(0, w+1-self.size, size=(1,)).item() ##### modified
751
+ out = self.convs(input[:,:,i:i+self.size,j:j+self.size]) ##### modified
752
+
753
+ batch, channel, height, width = out.shape
754
+ group = min(batch, self.stddev_group)
755
+ stddev = out.view(
756
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
757
+ )
758
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
759
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
760
+ stddev = stddev.repeat(group, 1, height, width)
761
+ out = torch.cat([out, stddev], 1)
762
+
763
+ out = self.final_conv(out)
764
+
765
+ out = out.view(batch, -1)
766
+ out = self.final_linear(out)
767
+
768
+ return out
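A usage sketch for the cropping behaviour described above (an editorial illustration, not part of the uploaded file; the only assumption is that the input is at least `size` pixels on each side):

    import torch
    d = Discriminator(size=256)
    x = torch.randn(2, 3, 320, 288)   # larger than 256: a random 256x256 patch is cropped
    logits = d(x)                     # shape (2, 1)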
models/stylegan2/op/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
models/stylegan2/op/conv2d_gradfix.py ADDED
@@ -0,0 +1,227 @@
1
+ import contextlib
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import autograd
6
+ from torch.nn import functional as F
7
+
8
+ enabled = True
9
+ weight_gradients_disabled = False
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def no_weight_gradients():
14
+ global weight_gradients_disabled
15
+
16
+ old = weight_gradients_disabled
17
+ weight_gradients_disabled = True
18
+ yield
19
+ weight_gradients_disabled = old
20
+
21
+
22
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
+ if could_use_op(input):
24
+ return conv2d_gradfix(
25
+ transpose=False,
26
+ weight_shape=weight.shape,
27
+ stride=stride,
28
+ padding=padding,
29
+ output_padding=0,
30
+ dilation=dilation,
31
+ groups=groups,
32
+ ).apply(input, weight, bias)
33
+
34
+ return F.conv2d(
35
+ input=input,
36
+ weight=weight,
37
+ bias=bias,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ )
43
+
44
+
45
+ def conv_transpose2d(
46
+ input,
47
+ weight,
48
+ bias=None,
49
+ stride=1,
50
+ padding=0,
51
+ output_padding=0,
52
+ groups=1,
53
+ dilation=1,
54
+ ):
55
+ if could_use_op(input):
56
+ return conv2d_gradfix(
57
+ transpose=True,
58
+ weight_shape=weight.shape,
59
+ stride=stride,
60
+ padding=padding,
61
+ output_padding=output_padding,
62
+ groups=groups,
63
+ dilation=dilation,
64
+ ).apply(input, weight, bias)
65
+
66
+ return F.conv_transpose2d(
67
+ input=input,
68
+ weight=weight,
69
+ bias=bias,
70
+ stride=stride,
71
+ padding=padding,
72
+ output_padding=output_padding,
73
+ dilation=dilation,
74
+ groups=groups,
75
+ )
76
+
77
+
78
+ def could_use_op(input):
79
+ if (not enabled) or (not torch.backends.cudnn.enabled):
80
+ return False
81
+
82
+ if input.device.type != "cuda":
83
+ return False
84
+
85
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
+ return True
87
+
88
+ warnings.warn(
89
+ f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
+ )
91
+
92
+ return False
93
+
94
+
95
+ def ensure_tuple(xs, ndim):
96
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
+
98
+ return xs
99
+
100
+
101
+ conv2d_gradfix_cache = dict()
102
+
103
+
104
+ def conv2d_gradfix(
105
+ transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
+ ):
107
+ ndim = 2
108
+ weight_shape = tuple(weight_shape)
109
+ stride = ensure_tuple(stride, ndim)
110
+ padding = ensure_tuple(padding, ndim)
111
+ output_padding = ensure_tuple(output_padding, ndim)
112
+ dilation = ensure_tuple(dilation, ndim)
113
+
114
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
+ if key in conv2d_gradfix_cache:
116
+ return conv2d_gradfix_cache[key]
117
+
118
+ common_kwargs = dict(
119
+ stride=stride, padding=padding, dilation=dilation, groups=groups
120
+ )
121
+
122
+ def calc_output_padding(input_shape, output_shape):
123
+ if transpose:
124
+ return [0, 0]
125
+
126
+ return [
127
+ input_shape[i + 2]
128
+ - (output_shape[i + 2] - 1) * stride[i]
129
+ - (1 - 2 * padding[i])
130
+ - dilation[i] * (weight_shape[i + 2] - 1)
131
+ for i in range(ndim)
132
+ ]
133
+
134
+ class Conv2d(autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, input, weight, bias):
137
+ if not transpose:
138
+ out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
+
140
+ else:
141
+ out = F.conv_transpose2d(
142
+ input=input,
143
+ weight=weight,
144
+ bias=bias,
145
+ output_padding=output_padding,
146
+ **common_kwargs,
147
+ )
148
+
149
+ ctx.save_for_backward(input, weight)
150
+
151
+ return out
152
+
153
+ @staticmethod
154
+ def backward(ctx, grad_output):
155
+ input, weight = ctx.saved_tensors
156
+ grad_input, grad_weight, grad_bias = None, None, None
157
+
158
+ if ctx.needs_input_grad[0]:
159
+ p = calc_output_padding(
160
+ input_shape=input.shape, output_shape=grad_output.shape
161
+ )
162
+ grad_input = conv2d_gradfix(
163
+ transpose=(not transpose),
164
+ weight_shape=weight_shape,
165
+ output_padding=p,
166
+ **common_kwargs,
167
+ ).apply(grad_output, weight, None)
168
+
169
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
+
172
+ if ctx.needs_input_grad[2]:
173
+ grad_bias = grad_output.sum((0, 2, 3))
174
+
175
+ return grad_input, grad_weight, grad_bias
176
+
177
+ class Conv2dGradWeight(autograd.Function):
178
+ @staticmethod
179
+ def forward(ctx, grad_output, input):
180
+ op = torch._C._jit_get_operation(
181
+ "aten::cudnn_convolution_backward_weight"
182
+ if not transpose
183
+ else "aten::cudnn_convolution_transpose_backward_weight"
184
+ )
185
+ flags = [
186
+ torch.backends.cudnn.benchmark,
187
+ torch.backends.cudnn.deterministic,
188
+ torch.backends.cudnn.allow_tf32,
189
+ ]
190
+ grad_weight = op(
191
+ weight_shape,
192
+ grad_output,
193
+ input,
194
+ padding,
195
+ stride,
196
+ dilation,
197
+ groups,
198
+ *flags,
199
+ )
200
+ ctx.save_for_backward(grad_output, input)
201
+
202
+ return grad_weight
203
+
204
+ @staticmethod
205
+ def backward(ctx, grad_grad_weight):
206
+ grad_output, input = ctx.saved_tensors
207
+ grad_grad_output, grad_grad_input = None, None
208
+
209
+ if ctx.needs_input_grad[0]:
210
+ grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
+
212
+ if ctx.needs_input_grad[1]:
213
+ p = calc_output_padding(
214
+ input_shape=input.shape, output_shape=grad_output.shape
215
+ )
216
+ grad_grad_input = conv2d_gradfix(
217
+ transpose=(not transpose),
218
+ weight_shape=weight_shape,
219
+ output_padding=p,
220
+ **common_kwargs,
221
+ ).apply(grad_output, grad_grad_weight, None)
222
+
223
+ return grad_grad_output, grad_grad_input
224
+
225
+ conv2d_gradfix_cache[key] = Conv2d
226
+
227
+ return Conv2d
models/stylegan2/op/fused_act.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class FusedLeakyReLU(nn.Module):
7
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
8
+ super().__init__()
9
+
10
+ if bias:
11
+ self.bias = nn.Parameter(torch.zeros(channel))
12
+
13
+ else:
14
+ self.bias = None
15
+
16
+ self.negative_slope = negative_slope
17
+ self.scale = scale
18
+
19
+ def forward(self, inputs):
20
+ return fused_leaky_relu(inputs, self.bias, self.negative_slope, self.scale)
21
+
22
+
23
+ def fused_leaky_relu(inputs, bias=None, negative_slope=0.2, scale=2 ** 0.5):
24
+ if bias is not None:
25
+ rest_dim = [1] * (inputs.ndim - bias.ndim - 1)
26
+ return (
27
+ F.leaky_relu(
28
+ inputs + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
29
+ )
30
+ * scale
31
+ )
32
+
33
+ else:
34
+ return F.leaky_relu(inputs, negative_slope=negative_slope) * scale
models/stylegan2/op/readme.md ADDED
@@ -0,0 +1,12 @@
1
+ Code from [rosinality-stylegan2-pytorch-cpu](https://github.com/senior-sigan/rosinality-stylegan2-pytorch-cpu)
2
+
3
+ Scripts to convert rosinality/stylegan2-pytorch to a CPU-compatible format
4
+
5
+ If you would like to use the CPU for testing, or have a problem with the C++ extensions (fused and upfirdn2d), please make the following changes:
6
+
7
+ Change `model.stylegan.op` to `model.stylegan.op_cpu`
8
+ https://github.com/williamyang1991/VToonify/blob/01b383efc00007f9b069585db41a7d31a77a8806/util.py#L14
9
+
10
+ https://github.com/williamyang1991/VToonify/blob/01b383efc00007f9b069585db41a7d31a77a8806/model/simple_augment.py#L12
11
+
12
+ https://github.com/williamyang1991/VToonify/blob/01b383efc00007f9b069585db41a7d31a77a8806/model/stylegan/model.py#L11
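Note that in this upload `models/stylegan2/op` already contains the pure-PyTorch (CPU-compatible) `fused_act` and `upfirdn2d`, while the original CUDA-extension versions are kept under `models/stylegan2/op_ori`; a sketch of the corresponding import (an editorial illustration, not part of the uploaded file):

    from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d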
models/stylegan2/op/upfirdn2d.py ADDED
@@ -0,0 +1,61 @@
1
+ from collections import abc
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def upfirdn2d(inputs, kernel, up=1, down=1, pad=(0, 0)):
8
+ if not isinstance(up, abc.Iterable):
9
+ up = (up, up)
10
+
11
+ if not isinstance(down, abc.Iterable):
12
+ down = (down, down)
13
+
14
+ if len(pad) == 2:
15
+ pad = (pad[0], pad[1], pad[0], pad[1])
16
+
17
+ return upfirdn2d_native(inputs, kernel, *up, *down, *pad)
18
+
19
+
20
+ def upfirdn2d_native(
21
+ inputs, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
22
+ ):
23
+ _, channel, in_h, in_w = inputs.shape
24
+ inputs = inputs.reshape(-1, in_h, in_w, 1)
25
+
26
+ _, in_h, in_w, minor = inputs.shape
27
+ kernel_h, kernel_w = kernel.shape
28
+
29
+ out = inputs.view(-1, in_h, 1, in_w, 1, minor)
30
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
31
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
32
+
33
+ out = F.pad(
34
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
35
+ )
36
+ out = out[
37
+ :,
38
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
39
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
40
+ :,
41
+ ]
42
+
43
+ out = out.permute(0, 3, 1, 2)
44
+ out = out.reshape(
45
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
46
+ )
47
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
48
+ out = F.conv2d(out, w)
49
+ out = out.reshape(
50
+ -1,
51
+ minor,
52
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
53
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
54
+ )
55
+ out = out.permute(0, 2, 3, 1)
56
+ out = out[:, ::down_y, ::down_x, :]
57
+
58
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
59
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
60
+
61
+ return out.view(-1, channel, out_h, out_w)
models/stylegan2/op_ori/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
models/stylegan2/op_ori/fused_act.py ADDED
@@ -0,0 +1,85 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.autograd import Function
6
+ from torch.utils.cpp_extension import load
7
+
8
+ module_path = os.path.dirname(__file__)
9
+ fused = load(
10
+ 'fused',
11
+ sources=[
12
+ os.path.join(module_path, 'fused_bias_act.cpp'),
13
+ os.path.join(module_path, 'fused_bias_act_kernel.cu'),
14
+ ],
15
+ )
16
+
17
+
18
+ class FusedLeakyReLUFunctionBackward(Function):
19
+ @staticmethod
20
+ def forward(ctx, grad_output, out, negative_slope, scale):
21
+ ctx.save_for_backward(out)
22
+ ctx.negative_slope = negative_slope
23
+ ctx.scale = scale
24
+
25
+ empty = grad_output.new_empty(0)
26
+
27
+ grad_input = fused.fused_bias_act(
28
+ grad_output, empty, out, 3, 1, negative_slope, scale
29
+ )
30
+
31
+ dim = [0]
32
+
33
+ if grad_input.ndim > 2:
34
+ dim += list(range(2, grad_input.ndim))
35
+
36
+ grad_bias = grad_input.sum(dim).detach()
37
+
38
+ return grad_input, grad_bias
39
+
40
+ @staticmethod
41
+ def backward(ctx, gradgrad_input, gradgrad_bias):
42
+ out, = ctx.saved_tensors
43
+ gradgrad_out = fused.fused_bias_act(
44
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
45
+ )
46
+
47
+ return gradgrad_out, None, None, None
48
+
49
+
50
+ class FusedLeakyReLUFunction(Function):
51
+ @staticmethod
52
+ def forward(ctx, input, bias, negative_slope, scale):
53
+ empty = input.new_empty(0)
54
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
55
+ ctx.save_for_backward(out)
56
+ ctx.negative_slope = negative_slope
57
+ ctx.scale = scale
58
+
59
+ return out
60
+
61
+ @staticmethod
62
+ def backward(ctx, grad_output):
63
+ out, = ctx.saved_tensors
64
+
65
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
66
+ grad_output, out, ctx.negative_slope, ctx.scale
67
+ )
68
+
69
+ return grad_input, grad_bias, None, None
70
+
71
+
72
+ class FusedLeakyReLU(nn.Module):
73
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
74
+ super().__init__()
75
+
76
+ self.bias = nn.Parameter(torch.zeros(channel))
77
+ self.negative_slope = negative_slope
78
+ self.scale = scale
79
+
80
+ def forward(self, input):
81
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
82
+
83
+
84
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
85
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
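A minimal usage sketch for the module above, assuming a CUDA-capable environment: importing it JIT-compiles fused_bias_act.cpp / fused_bias_act_kernel.cu via torch.utils.cpp_extension.load, so a working CUDA toolchain is required. Tensor sizes are illustrative only.

    import torch
    from models.stylegan2.op_ori.fused_act import FusedLeakyReLU, fused_leaky_relu

    x = torch.randn(4, 512, 32, 32, device="cuda")

    # Module form: owns a learnable per-channel bias (broadcast over dim 1).
    act = FusedLeakyReLU(512).to("cuda")
    y = act(x)                              # bias add + leaky ReLU + sqrt(2) gain, fused

    # Functional form with an explicit bias tensor.
    bias = torch.zeros(512, device="cuda")
    y2 = fused_leaky_relu(x, bias)          # negative_slope=0.2, scale=2 ** 0.5 by default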
models/stylegan2/op_ori/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5
+ int act, int grad, float alpha, float scale);
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10
+
11
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12
+ int act, int grad, float alpha, float scale) {
13
+ CHECK_CUDA(input);
14
+ CHECK_CUDA(bias);
15
+
16
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17
+ }
18
+
19
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21
+ }
models/stylegan2/op_ori/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22
+
23
+ scalar_t zero = 0.0;
24
+
25
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26
+ scalar_t x = p_x[xi];
27
+
28
+ if (use_bias) {
29
+ x += p_b[(xi / step_b) % size_b];
30
+ }
31
+
32
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
33
+
34
+ scalar_t y;
35
+
36
+ switch (act * 10 + grad) {
37
+ default:
38
+ case 10: y = x; break;
39
+ case 11: y = x; break;
40
+ case 12: y = 0.0; break;
41
+
42
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
43
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
44
+ case 32: y = 0.0; break;
45
+ }
46
+
47
+ out[xi] = y * scale;
48
+ }
49
+ }
50
+
51
+
52
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53
+ int act, int grad, float alpha, float scale) {
54
+ int curDevice = -1;
55
+ cudaGetDevice(&curDevice);
56
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57
+
58
+ auto x = input.contiguous();
59
+ auto b = bias.contiguous();
60
+ auto ref = refer.contiguous();
61
+
62
+ int use_bias = b.numel() ? 1 : 0;
63
+ int use_ref = ref.numel() ? 1 : 0;
64
+
65
+ int size_x = x.numel();
66
+ int size_b = b.numel();
67
+ int step_b = 1;
68
+
69
+ for (int i = 1 + 1; i < x.dim(); i++) {
70
+ step_b *= x.size(i);
71
+ }
72
+
73
+ int loop_x = 4;
74
+ int block_size = 4 * 32;
75
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76
+
77
+ auto y = torch::empty_like(x);
78
+
79
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81
+ y.data_ptr<scalar_t>(),
82
+ x.data_ptr<scalar_t>(),
83
+ b.data_ptr<scalar_t>(),
84
+ ref.data_ptr<scalar_t>(),
85
+ act,
86
+ grad,
87
+ alpha,
88
+ scale,
89
+ loop_x,
90
+ size_x,
91
+ step_b,
92
+ size_b,
93
+ use_bias,
94
+ use_ref
95
+ );
96
+ });
97
+
98
+ return y;
99
+ }
models/stylegan2/op_ori/upfirdn2d.cpp ADDED
@@ -0,0 +1,23 @@
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5
+ int up_x, int up_y, int down_x, int down_y,
6
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7
+
8
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11
+
12
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13
+ int up_x, int up_y, int down_x, int down_y,
14
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15
+ CHECK_CUDA(input);
16
+ CHECK_CUDA(kernel);
17
+
18
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19
+ }
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23
+ }
models/stylegan2/op_ori/upfirdn2d.py ADDED
@@ -0,0 +1,184 @@
1
+ import os
2
+
3
+ import torch
+ from torch.nn import functional as F  # used by upfirdn2d_native below
4
+ from torch.autograd import Function
5
+ from torch.utils.cpp_extension import load
6
+
7
+ module_path = os.path.dirname(__file__)
8
+ upfirdn2d_op = load(
9
+ 'upfirdn2d',
10
+ sources=[
11
+ os.path.join(module_path, 'upfirdn2d.cpp'),
12
+ os.path.join(module_path, 'upfirdn2d_kernel.cu'),
13
+ ],
14
+ )
15
+
16
+
17
+ class UpFirDn2dBackward(Function):
18
+ @staticmethod
19
+ def forward(
20
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
21
+ ):
22
+ up_x, up_y = up
23
+ down_x, down_y = down
24
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
25
+
26
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
27
+
28
+ grad_input = upfirdn2d_op.upfirdn2d(
29
+ grad_output,
30
+ grad_kernel,
31
+ down_x,
32
+ down_y,
33
+ up_x,
34
+ up_y,
35
+ g_pad_x0,
36
+ g_pad_x1,
37
+ g_pad_y0,
38
+ g_pad_y1,
39
+ )
40
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
41
+
42
+ ctx.save_for_backward(kernel)
43
+
44
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
45
+
46
+ ctx.up_x = up_x
47
+ ctx.up_y = up_y
48
+ ctx.down_x = down_x
49
+ ctx.down_y = down_y
50
+ ctx.pad_x0 = pad_x0
51
+ ctx.pad_x1 = pad_x1
52
+ ctx.pad_y0 = pad_y0
53
+ ctx.pad_y1 = pad_y1
54
+ ctx.in_size = in_size
55
+ ctx.out_size = out_size
56
+
57
+ return grad_input
58
+
59
+ @staticmethod
60
+ def backward(ctx, gradgrad_input):
61
+ kernel, = ctx.saved_tensors
62
+
63
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
64
+
65
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
66
+ gradgrad_input,
67
+ kernel,
68
+ ctx.up_x,
69
+ ctx.up_y,
70
+ ctx.down_x,
71
+ ctx.down_y,
72
+ ctx.pad_x0,
73
+ ctx.pad_x1,
74
+ ctx.pad_y0,
75
+ ctx.pad_y1,
76
+ )
77
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
78
+ gradgrad_out = gradgrad_out.view(
79
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
80
+ )
81
+
82
+ return gradgrad_out, None, None, None, None, None, None, None, None
83
+
84
+
85
+ class UpFirDn2d(Function):
86
+ @staticmethod
87
+ def forward(ctx, input, kernel, up, down, pad):
88
+ up_x, up_y = up
89
+ down_x, down_y = down
90
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
91
+
92
+ kernel_h, kernel_w = kernel.shape
93
+ batch, channel, in_h, in_w = input.shape
94
+ ctx.in_size = input.shape
95
+
96
+ input = input.reshape(-1, in_h, in_w, 1)
97
+
98
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
99
+
100
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
101
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
102
+ ctx.out_size = (out_h, out_w)
103
+
104
+ ctx.up = (up_x, up_y)
105
+ ctx.down = (down_x, down_y)
106
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
107
+
108
+ g_pad_x0 = kernel_w - pad_x0 - 1
109
+ g_pad_y0 = kernel_h - pad_y0 - 1
110
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
111
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
112
+
113
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
114
+
115
+ out = upfirdn2d_op.upfirdn2d(
116
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
117
+ )
118
+ # out = out.view(major, out_h, out_w, minor)
119
+ out = out.view(-1, channel, out_h, out_w)
120
+
121
+ return out
122
+
123
+ @staticmethod
124
+ def backward(ctx, grad_output):
125
+ kernel, grad_kernel = ctx.saved_tensors
126
+
127
+ grad_input = UpFirDn2dBackward.apply(
128
+ grad_output,
129
+ kernel,
130
+ grad_kernel,
131
+ ctx.up,
132
+ ctx.down,
133
+ ctx.pad,
134
+ ctx.g_pad,
135
+ ctx.in_size,
136
+ ctx.out_size,
137
+ )
138
+
139
+ return grad_input, None, None, None, None
140
+
141
+
142
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
143
+ out = UpFirDn2d.apply(
144
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
145
+ )
146
+
147
+ return out
148
+
149
+
150
+ def upfirdn2d_native(
151
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
152
+ ):
153
+ _, in_h, in_w, minor = input.shape
154
+ kernel_h, kernel_w = kernel.shape
155
+
156
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
157
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
158
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
159
+
160
+ out = F.pad(
161
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
162
+ )
163
+ out = out[
164
+ :,
165
+ max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
166
+ max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
167
+ :,
168
+ ]
169
+
170
+ out = out.permute(0, 3, 1, 2)
171
+ out = out.reshape(
172
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
173
+ )
174
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
175
+ out = F.conv2d(out, w)
176
+ out = out.reshape(
177
+ -1,
178
+ minor,
179
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
180
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
181
+ )
182
+ out = out.permute(0, 2, 3, 1)
183
+
184
+ return out[:, ::down_y, ::down_x, :]
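Unlike the wrapper in models/stylegan2/op, this version accepts only scalar up/down factors and a 2-element pad, and it also JIT-compiles its CUDA kernels on import. Below is a minimal blur-and-downsample sketch under those assumptions; the kernel and shapes are illustrative, not values from this repository.

    import torch
    from models.stylegan2.op_ori.upfirdn2d import upfirdn2d   # CUDA-backed autograd Function

    k = torch.tensor([1.0, 3.0, 3.0, 1.0], device="cuda")
    kernel = k[None, :] * k[:, None]
    kernel = kernel / kernel.sum()        # unit gain; no up-factor compensation when downsampling

    x = torch.randn(2, 3, 256, 256, device="cuda")
    p = kernel.shape[0] - 2               # 2 for a 4-tap kernel
    pad = ((p + 1) // 2, p // 2)          # (1, 1), applied to both axes by the wrapper
    y = upfirdn2d(x, kernel, up=1, down=2, pad=pad)
    print(y.shape)                        # torch.Size([2, 3, 128, 128])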
models/stylegan2/op_ori/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,272 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19
+ int c = a / b;
20
+
21
+ if (c * b > a) {
22
+ c--;
23
+ }
24
+
25
+ return c;
26
+ }
27
+
28
+
29
+ struct UpFirDn2DKernelParams {
30
+ int up_x;
31
+ int up_y;
32
+ int down_x;
33
+ int down_y;
34
+ int pad_x0;
35
+ int pad_x1;
36
+ int pad_y0;
37
+ int pad_y1;
38
+
39
+ int major_dim;
40
+ int in_h;
41
+ int in_w;
42
+ int minor_dim;
43
+ int kernel_h;
44
+ int kernel_w;
45
+ int out_h;
46
+ int out_w;
47
+ int loop_major;
48
+ int loop_x;
49
+ };
50
+
51
+
52
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53
+ __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56
+
57
+ __shared__ volatile float sk[kernel_h][kernel_w];
58
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
59
+
60
+ int minor_idx = blockIdx.x;
61
+ int tile_out_y = minor_idx / p.minor_dim;
62
+ minor_idx -= tile_out_y * p.minor_dim;
63
+ tile_out_y *= tile_out_h;
64
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65
+ int major_idx_base = blockIdx.z * p.loop_major;
66
+
67
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68
+ return;
69
+ }
70
+
71
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72
+ int ky = tap_idx / kernel_w;
73
+ int kx = tap_idx - ky * kernel_w;
74
+ scalar_t v = 0.0;
75
+
76
+ if (kx < p.kernel_w & ky < p.kernel_h) {
77
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78
+ }
79
+
80
+ sk[ky][kx] = v;
81
+ }
82
+
83
+ for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84
+ for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87
+ int tile_in_x = floor_div(tile_mid_x, up_x);
88
+ int tile_in_y = floor_div(tile_mid_y, up_y);
89
+
90
+ __syncthreads();
91
+
92
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93
+ int rel_in_y = in_idx / tile_in_w;
94
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
95
+ int in_x = rel_in_x + tile_in_x;
96
+ int in_y = rel_in_y + tile_in_y;
97
+
98
+ scalar_t v = 0.0;
99
+
100
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102
+ }
103
+
104
+ sx[rel_in_y][rel_in_x] = v;
105
+ }
106
+
107
+ __syncthreads();
108
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109
+ int rel_out_y = out_idx / tile_out_w;
110
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
111
+ int out_x = rel_out_x + tile_out_x;
112
+ int out_y = rel_out_y + tile_out_y;
113
+
114
+ int mid_x = tile_mid_x + rel_out_x * down_x;
115
+ int mid_y = tile_mid_y + rel_out_y * down_y;
116
+ int in_x = floor_div(mid_x, up_x);
117
+ int in_y = floor_div(mid_y, up_y);
118
+ int rel_in_x = in_x - tile_in_x;
119
+ int rel_in_y = in_y - tile_in_y;
120
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122
+
123
+ scalar_t v = 0.0;
124
+
125
+ #pragma unroll
126
+ for (int y = 0; y < kernel_h / up_y; y++)
127
+ #pragma unroll
128
+ for (int x = 0; x < kernel_w / up_x; x++)
129
+ v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130
+
131
+ if (out_x < p.out_w & out_y < p.out_h) {
132
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133
+ }
134
+ }
135
+ }
136
+ }
137
+ }
138
+
139
+
140
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141
+ int up_x, int up_y, int down_x, int down_y,
142
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143
+ int curDevice = -1;
144
+ cudaGetDevice(&curDevice);
145
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146
+
147
+ UpFirDn2DKernelParams p;
148
+
149
+ auto x = input.contiguous();
150
+ auto k = kernel.contiguous();
151
+
152
+ p.major_dim = x.size(0);
153
+ p.in_h = x.size(1);
154
+ p.in_w = x.size(2);
155
+ p.minor_dim = x.size(3);
156
+ p.kernel_h = k.size(0);
157
+ p.kernel_w = k.size(1);
158
+ p.up_x = up_x;
159
+ p.up_y = up_y;
160
+ p.down_x = down_x;
161
+ p.down_y = down_y;
162
+ p.pad_x0 = pad_x0;
163
+ p.pad_x1 = pad_x1;
164
+ p.pad_y0 = pad_y0;
165
+ p.pad_y1 = pad_y1;
166
+
167
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169
+
170
+ auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171
+
172
+ int mode = -1;
173
+
174
+ int tile_out_h;
175
+ int tile_out_w;
176
+
177
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178
+ mode = 1;
179
+ tile_out_h = 16;
180
+ tile_out_w = 64;
181
+ }
182
+
183
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184
+ mode = 2;
185
+ tile_out_h = 16;
186
+ tile_out_w = 64;
187
+ }
188
+
189
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190
+ mode = 3;
191
+ tile_out_h = 16;
192
+ tile_out_w = 64;
193
+ }
194
+
195
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196
+ mode = 4;
197
+ tile_out_h = 16;
198
+ tile_out_w = 64;
199
+ }
200
+
201
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202
+ mode = 5;
203
+ tile_out_h = 8;
204
+ tile_out_w = 32;
205
+ }
206
+
207
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208
+ mode = 6;
209
+ tile_out_h = 8;
210
+ tile_out_w = 32;
211
+ }
212
+
213
+ dim3 block_size;
214
+ dim3 grid_size;
215
+
216
+ if (tile_out_h > 0 && tile_out_w > 0) {
217
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
218
+ p.loop_x = 1;
219
+ block_size = dim3(32 * 8, 1, 1);
220
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222
+ (p.major_dim - 1) / p.loop_major + 1);
223
+ }
224
+
225
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226
+ switch (mode) {
227
+ case 1:
228
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230
+ );
231
+
232
+ break;
233
+
234
+ case 2:
235
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237
+ );
238
+
239
+ break;
240
+
241
+ case 3:
242
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244
+ );
245
+
246
+ break;
247
+
248
+ case 4:
249
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251
+ );
252
+
253
+ break;
254
+
255
+ case 5:
256
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258
+ );
259
+
260
+ break;
261
+
262
+ case 6:
263
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
264
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265
+ );
266
+
267
+ break;
268
+ }
269
+ });
270
+
271
+ return out;
272
+ }
models/stylegan2/simple_augment.py ADDED
@@ -0,0 +1,478 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import autograd
5
+ from torch.nn import functional as F
6
+ import numpy as np
7
+
8
+ from torch import distributed as dist
9
+ #from distributed import reduce_sum
10
+ from models.stylegan2.op2 import upfirdn2d
11
+
12
+ def reduce_sum(tensor):
13
+ if not dist.is_available():
14
+ return tensor
15
+
16
+ if not dist.is_initialized():
17
+ return tensor
18
+
19
+ tensor = tensor.clone()
20
+ dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
21
+
22
+ return tensor
23
+
24
+
25
+ class AdaptiveAugment:
26
+ def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
27
+ self.ada_aug_target = ada_aug_target
28
+ self.ada_aug_len = ada_aug_len
29
+ self.update_every = update_every
30
+
31
+ self.ada_update = 0
32
+ self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
33
+ self.r_t_stat = 0
34
+ self.ada_aug_p = 0
35
+
36
+ @torch.no_grad()
37
+ def tune(self, real_pred):
38
+ self.ada_aug_buf += torch.tensor(
39
+ (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
40
+ device=real_pred.device,
41
+ )
42
+ self.ada_update += 1
43
+
44
+ if self.ada_update % self.update_every == 0:
45
+ self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
46
+ pred_signs, n_pred = self.ada_aug_buf.tolist()
47
+
48
+ self.r_t_stat = pred_signs / n_pred
49
+
50
+ if self.r_t_stat > self.ada_aug_target:
51
+ sign = 1
52
+
53
+ else:
54
+ sign = -1
55
+
56
+ self.ada_aug_p += sign * n_pred / self.ada_aug_len
57
+ self.ada_aug_p = min(1, max(0, self.ada_aug_p))
58
+ self.ada_aug_buf.mul_(0)
59
+ self.ada_update = 0
60
+
61
+ return self.ada_aug_p
62
+
63
+
64
+ SYM6 = (
65
+ 0.015404109327027373,
66
+ 0.0034907120842174702,
67
+ -0.11799011114819057,
68
+ -0.048311742585633,
69
+ 0.4910559419267466,
70
+ 0.787641141030194,
71
+ 0.3379294217276218,
72
+ -0.07263752278646252,
73
+ -0.021060292512300564,
74
+ 0.04472490177066578,
75
+ 0.0017677118642428036,
76
+ -0.007800708325034148,
77
+ )
78
+
79
+
80
+ def translate_mat(t_x, t_y, device="cpu"):
81
+ batch = t_x.shape[0]
82
+
83
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
84
+ translate = torch.stack((t_x, t_y), 1)
85
+ mat[:, :2, 2] = translate
86
+
87
+ return mat
88
+
89
+
90
+ def rotate_mat(theta, device="cpu"):
91
+ batch = theta.shape[0]
92
+
93
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
94
+ sin_t = torch.sin(theta)
95
+ cos_t = torch.cos(theta)
96
+ rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
97
+ mat[:, :2, :2] = rot
98
+
99
+ return mat
100
+
101
+
102
+ def scale_mat(s_x, s_y, device="cpu"):
103
+ batch = s_x.shape[0]
104
+
105
+ mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
106
+ mat[:, 0, 0] = s_x
107
+ mat[:, 1, 1] = s_y
108
+
109
+ return mat
110
+
111
+
112
+ def translate3d_mat(t_x, t_y, t_z):
113
+ batch = t_x.shape[0]
114
+
115
+ mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
116
+ translate = torch.stack((t_x, t_y, t_z), 1)
117
+ mat[:, :3, 3] = translate
118
+
119
+ return mat
120
+
121
+
122
+ def rotate3d_mat(axis, theta):
123
+ batch = theta.shape[0]
124
+
125
+ u_x, u_y, u_z = axis
126
+
127
+ eye = torch.eye(3).unsqueeze(0)
128
+ cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
129
+ outer = torch.tensor(axis)
130
+ outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
131
+
132
+ sin_t = torch.sin(theta).view(-1, 1, 1)
133
+ cos_t = torch.cos(theta).view(-1, 1, 1)
134
+
135
+ rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
136
+
137
+ eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
138
+ eye_4[:, :3, :3] = rot
139
+
140
+ return eye_4
141
+
142
+
143
+ def scale3d_mat(s_x, s_y, s_z):
144
+ batch = s_x.shape[0]
145
+
146
+ mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
147
+ mat[:, 0, 0] = s_x
148
+ mat[:, 1, 1] = s_y
149
+ mat[:, 2, 2] = s_z
150
+
151
+ return mat
152
+
153
+
154
+ def luma_flip_mat(axis, i):
155
+ batch = i.shape[0]
156
+
157
+ eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
158
+ axis = torch.tensor(axis + (0,))
159
+ flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
160
+
161
+ return eye - flip
162
+
163
+
164
+ def saturation_mat(axis, i):
165
+ batch = i.shape[0]
166
+
167
+ eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
168
+ axis = torch.tensor(axis + (0,))
169
+ axis = torch.ger(axis, axis)
170
+ saturate = axis + (eye - axis) * i.view(-1, 1, 1)
171
+
172
+ return saturate
173
+
174
+
175
+ def lognormal_sample(size, mean=0, std=1, device="cpu"):
176
+ return torch.empty(size, device=device).log_normal_(mean=mean, std=std)
177
+
178
+
179
+ def category_sample(size, categories, device="cpu"):
180
+ category = torch.tensor(categories, device=device)
181
+ sample = torch.randint(high=len(categories), size=(size,), device=device)
182
+
183
+ return category[sample]
184
+
185
+
186
+ def uniform_sample(size, low, high, device="cpu"):
187
+ return torch.empty(size, device=device).uniform_(low, high)
188
+
189
+
190
+ def normal_sample(size, mean=0, std=1, device="cpu"):
191
+ return torch.empty(size, device=device).normal_(mean, std)
192
+
193
+
194
+ def bernoulli_sample(size, p, device="cpu"):
195
+ return torch.empty(size, device=device).bernoulli_(p)
196
+
197
+
198
+ def random_mat_apply(p, transform, prev, eye, device="cpu"):
199
+ size = transform.shape[0]
200
+ select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
201
+ select_transform = select * transform + (1 - select) * eye
202
+
203
+ return select_transform @ prev
204
+
205
+
206
+ def sample_affine(p, size, height, width, device="cpu"):
207
+ G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
208
+ eye = G
209
+
210
+ # flip
211
+ #param = category_sample(size, (0, 1))
212
+ #Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
213
+ #G = random_mat_apply(p, Gc, G, eye, device=device)
214
+ # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
215
+
216
+ # 90 rotate
217
+ #param = category_sample(size, (0, 3))
218
+ #Gc = rotate_mat(-math.pi / 2 * param, device=device)
219
+ #G = random_mat_apply(p, Gc, G, eye, device=device)
220
+ # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
221
+
222
+ # integer translate
223
+ param = uniform_sample(size, -0.125, 0.125)
224
+ param_height = torch.round(param * height) / height
225
+ param_width = torch.round(param * width) / width
226
+ Gc = translate_mat(param_width, param_height, device=device)
227
+ G = random_mat_apply(p, Gc, G, eye, device=device)
228
+ # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
229
+
230
+ # isotropic scale
231
+ param = lognormal_sample(size, std=0.1 * math.log(2))
232
+ Gc = scale_mat(param, param, device=device)
233
+ G = random_mat_apply(p, Gc, G, eye, device=device)
234
+ # print('isotropic scale', G, scale_mat(param, param), sep='\n')
235
+
236
+ p_rot = 1 - math.sqrt(1 - p)
237
+
238
+ # pre-rotate
239
+ param = uniform_sample(size, -math.pi * 0.25, math.pi * 0.25)
240
+ Gc = rotate_mat(-param, device=device)
241
+ G = random_mat_apply(p_rot, Gc, G, eye, device=device)
242
+ # print('pre-rotate', G, rotate_mat(-param), sep='\n')
243
+
244
+ # anisotropic scale
245
+ param = lognormal_sample(size, std=0.1 * math.log(2))
246
+ Gc = scale_mat(param, 1 / param, device=device)
247
+ G = random_mat_apply(p, Gc, G, eye, device=device)
248
+ # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
249
+
250
+ # post-rotate
251
+ param = uniform_sample(size, -math.pi * 0.25, math.pi * 0.25)
252
+ Gc = rotate_mat(-param, device=device)
253
+ G = random_mat_apply(p_rot, Gc, G, eye, device=device)
254
+ # print('post-rotate', G, rotate_mat(-param), sep='\n')
255
+
256
+ # fractional translate
257
+ param = normal_sample(size, std=0.125)
258
+ Gc = translate_mat(param, param, device=device)
259
+ G = random_mat_apply(p, Gc, G, eye, device=device)
260
+ # print('fractional translate', G, translate_mat(param, param), sep='\n')
261
+
262
+ return G
263
+
264
+
265
+ def sample_color(p, size):
266
+ C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
267
+ eye = C
268
+ axis_val = 1 / math.sqrt(3)
269
+ axis = (axis_val, axis_val, axis_val)
270
+
271
+ # brightness
272
+ param = normal_sample(size, std=0.2)
273
+ Cc = translate3d_mat(param, param, param)
274
+ C = random_mat_apply(p, Cc, C, eye)
275
+
276
+ # contrast
277
+ param = lognormal_sample(size, std=0.5 * math.log(2))
278
+ Cc = scale3d_mat(param, param, param)
279
+ C = random_mat_apply(p, Cc, C, eye)
280
+
281
+ # luma flip
282
+ param = category_sample(size, (0, 1))
283
+ Cc = luma_flip_mat(axis, param)
284
+ C = random_mat_apply(p, Cc, C, eye)
285
+
286
+ # hue rotation
287
+ param = uniform_sample(size, -math.pi, math.pi)
288
+ Cc = rotate3d_mat(axis, param)
289
+ C = random_mat_apply(p, Cc, C, eye)
290
+
291
+ # saturation
292
+ param = lognormal_sample(size, std=1 * math.log(2))
293
+ Cc = saturation_mat(axis, param)
294
+ C = random_mat_apply(p, Cc, C, eye)
295
+
296
+ return C
297
+
298
+
299
+ def make_grid(shape, x0, x1, y0, y1, device):
300
+ n, c, h, w = shape
301
+ grid = torch.empty(n, h, w, 3, device=device)
302
+ grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
303
+ grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
304
+ grid[:, :, :, 2] = 1
305
+
306
+ return grid
307
+
308
+
309
+ def affine_grid(grid, mat):
310
+ n, h, w, _ = grid.shape
311
+ return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
312
+
313
+
314
+ def get_padding(G, height, width, kernel_size):
315
+ device = G.device
316
+
317
+ cx = (width - 1) / 2
318
+ cy = (height - 1) / 2
319
+ cp = torch.tensor(
320
+ [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
321
+ )
322
+ cp = G @ cp.T
323
+
324
+ pad_k = kernel_size // 4
325
+
326
+ pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
327
+ pad = torch.cat((-pad, pad)).max(1).values
328
+ pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
329
+ pad = pad.max(torch.tensor([0, 0] * 2, device=device))
330
+ pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))
331
+
332
+ pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)
333
+
334
+ return pad_x1, pad_x2, pad_y1, pad_y2
335
+
336
+
337
+ def try_sample_affine_and_pad(img, p, kernel_size, G=None):
338
+ batch, _, height, width = img.shape
339
+
340
+ G_try = G
341
+
342
+ if G is None:
343
+ G_try = torch.inverse(sample_affine(p, batch, height, width))
344
+
345
+ pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
346
+
347
+ img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")
348
+
349
+ return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
350
+
351
+
352
+ class GridSampleForward(autograd.Function):
353
+ @staticmethod
354
+ def forward(ctx, input, grid):
355
+ out = F.grid_sample(
356
+ input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
357
+ )
358
+ ctx.save_for_backward(input, grid)
359
+
360
+ return out
361
+
362
+ @staticmethod
363
+ def backward(ctx, grad_output):
364
+ input, grid = ctx.saved_tensors
365
+ grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
366
+
367
+ return grad_input, grad_grid
368
+
369
+
370
+ class GridSampleBackward(autograd.Function):
371
+ @staticmethod
372
+ def forward(ctx, grad_output, input, grid):
373
+ op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
374
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
375
+ ctx.save_for_backward(grid)
376
+
377
+ return grad_input, grad_grid
378
+
379
+ @staticmethod
380
+ def backward(ctx, grad_grad_input, grad_grad_grid):
381
+ grid, = ctx.saved_tensors
382
+ grad_grad_output = None
383
+
384
+ if ctx.needs_input_grad[0]:
385
+ grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
386
+
387
+ return grad_grad_output, None, None
388
+
389
+
390
+ grid_sample = GridSampleForward.apply
391
+
392
+
393
+ def scale_mat_single(s_x, s_y):
394
+ return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)
395
+
396
+
397
+ def translate_mat_single(t_x, t_y):
398
+ return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)
399
+
400
+
401
+ def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
402
+ kernel = antialiasing_kernel
403
+ len_k = len(kernel)
404
+
405
+ kernel = torch.as_tensor(kernel).to(img)
406
+ # kernel = torch.ger(kernel, kernel).to(img)
407
+ kernel_flip = torch.flip(kernel, (0,))
408
+
409
+ img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
410
+ img, p, len_k, G
411
+ )
412
+
413
+ G_inv = (
414
+ translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
415
+ @ G
416
+ )
417
+ up_pad = (
418
+ (len_k + 2 - 1) // 2,
419
+ (len_k - 2) // 2,
420
+ (len_k + 2 - 1) // 2,
421
+ (len_k - 2) // 2,
422
+ )
423
+ img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
424
+ img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
425
+ G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
426
+ G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
427
+ batch_size, channel, height, width = img.shape
428
+ pad_k = len_k // 4
429
+ shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
430
+ G_inv = (
431
+ scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
432
+ @ G_inv
433
+ @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
434
+ )
435
+ grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
436
+ img_affine = grid_sample(img_2x, grid)
437
+ d_p = -pad_k * 2
438
+ down_pad = (
439
+ d_p + (len_k - 2 + 1) // 2,
440
+ d_p + (len_k - 2) // 2,
441
+ d_p + (len_k - 2 + 1) // 2,
442
+ d_p + (len_k - 2) // 2,
443
+ )
444
+ img_down = upfirdn2d(
445
+ img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
446
+ )
447
+ img_down = upfirdn2d(
448
+ img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
449
+ )
450
+
451
+ return img_down, G
452
+
453
+
454
+ def apply_color(img, mat):
455
+ batch = img.shape[0]
456
+ img = img.permute(0, 2, 3, 1)
457
+ mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
458
+ mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
459
+ img = img @ mat_mul + mat_add
460
+ img = img.permute(0, 3, 1, 2)
461
+
462
+ return img
463
+
464
+
465
+ def random_apply_color(img, p, C=None):
466
+ if C is None:
467
+ C = sample_color(p, img.shape[0])
468
+
469
+ img = apply_color(img, C.to(img))
470
+
471
+ return img, C
472
+
473
+
474
+ def augment(img, p, transform_matrix=(None, None)):
475
+ img, G = random_apply_affine(img, p, transform_matrix[0])
476
+ img, C = random_apply_color(img, p, transform_matrix[1])
477
+
478
+ return img, (G, C)
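A minimal sketch of how augment and AdaptiveAugment are typically wired into an ADA-style discriminator step. The stand-in discriminator, batch shapes, and hyperparameters are placeholders rather than this repository's training configuration, and the device requirement depends on which upfirdn2d backend models.stylegan2.op2 resolves to.

    import torch
    from torch import nn
    from models.stylegan2.simple_augment import augment, AdaptiveAugment

    device = "cuda"
    disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1)).to(device)  # stand-in D

    ada = AdaptiveAugment(ada_aug_target=0.6, ada_aug_len=500 * 1000,
                          update_every=8, device=device)
    ada_aug_p = 0.0

    for step in range(16):                                  # toy loop, illustrative only
        real_img = torch.randn(4, 3, 64, 64, device=device)

        # Each transform in the pipeline fires with probability ada_aug_p; the
        # returned (G, C) matrices can be passed back in to replay the exact same
        # augmentation on another batch (e.g. the fakes).
        real_aug, (G, C) = augment(real_img, ada_aug_p)

        real_pred = disc(real_aug)
        ada_aug_p = ada.tune(real_pred)                     # adjust p from sign(real_pred)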