PKUWilliamYang committed
Commit ac4ce84
Parent(s): d5073e2

Upload 50 files

- models/__init__.py +0 -0
- models/bisenet/LICENSE +21 -0
- models/bisenet/README.md +68 -0
- models/bisenet/model.py +283 -0
- models/bisenet/resnet.py +109 -0
- models/encoders/__init__.py +0 -0
- models/encoders/helpers.py +119 -0
- models/encoders/model_irse.py +84 -0
- models/encoders/psp_encoders.py +357 -0
- models/mtcnn/__init__.py +0 -0
- models/mtcnn/mtcnn.py +156 -0
- models/mtcnn/mtcnn_pytorch/__init__.py +0 -0
- models/mtcnn/mtcnn_pytorch/src/__init__.py +2 -0
- models/mtcnn/mtcnn_pytorch/src/align_trans.py +304 -0
- models/mtcnn/mtcnn_pytorch/src/box_utils.py +238 -0
- models/mtcnn/mtcnn_pytorch/src/detector.py +126 -0
- models/mtcnn/mtcnn_pytorch/src/first_stage.py +101 -0
- models/mtcnn/mtcnn_pytorch/src/get_nets.py +171 -0
- models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py +350 -0
- models/mtcnn/mtcnn_pytorch/src/visualization_utils.py +31 -0
- models/mtcnn/mtcnn_pytorch/src/weights/onet.npy +3 -0
- models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy +3 -0
- models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy +3 -0
- models/psp.py +147 -0
- models/stylegan2/__init__.py +0 -0
- models/stylegan2/lpips/__init__.py +161 -0
- models/stylegan2/lpips/base_model.py +58 -0
- models/stylegan2/lpips/dist_model.py +284 -0
- models/stylegan2/lpips/networks_basic.py +187 -0
- models/stylegan2/lpips/pretrained_networks.py +181 -0
- models/stylegan2/lpips/weights/v0.0/alex.pth +3 -0
- models/stylegan2/lpips/weights/v0.0/squeeze.pth +3 -0
- models/stylegan2/lpips/weights/v0.0/vgg.pth +3 -0
- models/stylegan2/lpips/weights/v0.1/alex.pth +3 -0
- models/stylegan2/lpips/weights/v0.1/squeeze.pth +3 -0
- models/stylegan2/lpips/weights/v0.1/vgg.pth +3 -0
- models/stylegan2/model.py +768 -0
- models/stylegan2/op/__init__.py +2 -0
- models/stylegan2/op/conv2d_gradfix.py +227 -0
- models/stylegan2/op/fused_act.py +34 -0
- models/stylegan2/op/readme.md +12 -0
- models/stylegan2/op/upfirdn2d.py +61 -0
- models/stylegan2/op_ori/__init__.py +2 -0
- models/stylegan2/op_ori/fused_act.py +85 -0
- models/stylegan2/op_ori/fused_bias_act.cpp +21 -0
- models/stylegan2/op_ori/fused_bias_act_kernel.cu +99 -0
- models/stylegan2/op_ori/upfirdn2d.cpp +23 -0
- models/stylegan2/op_ori/upfirdn2d.py +184 -0
- models/stylegan2/op_ori/upfirdn2d_kernel.cu +272 -0
- models/stylegan2/simple_augment.py +478 -0
models/__init__.py
ADDED
File without changes
models/bisenet/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 zll

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
models/bisenet/README.md
ADDED
@@ -0,0 +1,68 @@
# face-parsing.PyTorch

<p align="center">
	<a href="https://github.com/zllrunning/face-parsing.PyTorch">
	<img class="page-image" src="https://github.com/zllrunning/face-parsing.PyTorch/blob/master/6.jpg" >
	</a>
</p>

### Contents
- [Training](#training)
- [Demo](#Demo)
- [References](#references)

## Training

1. Prepare training data:
    -- download [CelebAMask-HQ dataset](https://github.com/switchablenorms/CelebAMask-HQ)

    -- change file path in the `prepropess_data.py` and run
```Shell
python prepropess_data.py
```

2. Train the model using CelebAMask-HQ dataset:
Just run the train script:
```
$ CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
```

If you do not wish to train the model, you can download [our pre-trained model](https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812) and save it in `res/cp`.


## Demo
1. Evaluate the trained model using:
```Shell
# evaluate using GPU
python test.py
```

## Face makeup using parsing maps
[**face-makeup.PyTorch**](https://github.com/zllrunning/face-makeup.PyTorch)
<table>

<tr>
    <th> </th>
    <th>Hair</th>
    <th>Lip</th>
</tr>

<!-- Line 1: Original Input -->
<tr>
    <td><em>Original Input</em></td>
    <td><img src="makeup/116_ori.png" height="256" width="256" alt="Original Input"></td>
    <td><img src="makeup/116_lip_ori.png" height="256" width="256" alt="Original Input"></td>
</tr>

<!-- Line 3: Color -->
<tr>
    <td>Color</td>
    <td><img src="makeup/116_1.png" height="256" width="256" alt="Color"></td>
    <td><img src="makeup/116_3.png" height="256" width="256" alt="Color"></td>
</tr>

</table>


## References
- [BiSeNet](https://github.com/CoinCheung/BiSeNet)
models/bisenet/model.py
ADDED
@@ -0,0 +1,283 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from models.bisenet.resnet import Resnet18
# from modules.bn import InPlaceABNSync as BatchNorm2d


class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan,
                out_chan,
                kernel_size = ks,
                stride = stride,
                padding = padding,
                bias = False)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

class BiSeNetOutput(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class AttentionRefinementModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        feat = self.conv(x)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)


class ContextPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        H0, W0 = x.size()[2:]
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


### This is not used, since I replace this with the resnet feature with the same size
class SpatialPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan,
                out_chan//4,
                kernel_size = 1,
                stride = 1,
                padding = 0,
                bias = False)
        self.conv2 = nn.Conv2d(out_chan//4,
                out_chan,
                kernel_size = 1,
                stride = 1,
                padding = 0,
                bias = False)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.init_weight()

    def forward(self, fsp, fcp):
        fcat = torch.cat([fsp, fcp], dim=1)
        feat = self.convblk(fcat)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv1(atten)
        atten = self.relu(atten)
        atten = self.conv2(atten)
        atten = self.sigmoid(atten)
        feat_atten = torch.mul(feat, atten)
        feat_out = feat_atten + feat
        return feat_out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class BiSeNet(nn.Module):
    def __init__(self, n_classes, *args, **kwargs):
        super(BiSeNet, self).__init__()
        self.cp = ContextPath()
        ## here self.sp is deleted
        self.ffm = FeatureFusionModule(256, 256)
        self.conv_out = BiSeNetOutput(256, 256, n_classes)
        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
        self.init_weight()

    def forward(self, x):
        H, W = x.size()[2:]
        feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
        feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
        feat_fuse = self.ffm(feat_sp, feat_cp8)

        feat_out = self.conv_out(feat_fuse)
        feat_out16 = self.conv_out16(feat_cp8)
        feat_out32 = self.conv_out32(feat_cp16)

        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
        return feat_out, feat_out16, feat_out32

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
        for name, child in self.named_children():
            child_wd_params, child_nowd_params = child.get_params()
            if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
                lr_mul_wd_params += child_wd_params
                lr_mul_nowd_params += child_nowd_params
            else:
                wd_params += child_wd_params
                nowd_params += child_nowd_params
        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params


if __name__ == "__main__":
    net = BiSeNet(19)
    net.cuda()
    net.eval()
    in_ten = torch.randn(16, 3, 640, 480).cuda()
    out, out16, out32 = net(in_ten)
    print(out.shape)

    net.get_params()
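A minimal inference sketch for the BiSeNet parsing network above, in the spirit of the demo section of the bundled README. It is not one of this commit's files; the checkpoint path, the 512x512 input size, and the ImageNet normalization constants are assumptions.

import torch
from PIL import Image
from torchvision import transforms

from models.bisenet.model import BiSeNet

ckpt_path = 'res/cp/parsing.pth'  # hypothetical location of the downloaded pre-trained weights

net = BiSeNet(n_classes=19)  # 19 classes, matching the __main__ block above
net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
net.eval()

preprocess = transforms.Compose([
    transforms.Resize((512, 512)),            # assumed working resolution
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

with torch.no_grad():
    img = preprocess(Image.open('face.jpg').convert('RGB')).unsqueeze(0)
    out, out16, out32 = net(img)              # main head plus two auxiliary heads
    parsing = out.argmax(dim=1)               # [1, 512, 512] map of per-pixel class indices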
models/bisenet/resnet.py
ADDED
@@ -0,0 +1,109 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as modelzoo

# from modules.bn import InPlaceABNSync as BatchNorm2d

resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
                )

    def forward(self, x):
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for i in range(bnum-1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)


class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
        self.init_weight()

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)  # 1/8
        feat16 = self.layer3(feat8)  # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32

    def init_weight(self):
        state_dict = modelzoo.load_url(resnet18_url)
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'fc' in k: continue
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


if __name__ == "__main__":
    net = Resnet18()
    x = torch.randn(16, 3, 224, 224)
    out = net(x)
    print(out[0].size())
    print(out[1].size())
    print(out[2].size())
    net.get_params()
models/encoders/__init__.py
ADDED
File without changes
models/encoders/helpers.py
ADDED
@@ -0,0 +1,119 @@
from collections import namedtuple
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module

"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """ A named tuple describing a ResNet block. """


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    else:
        raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
    return blocks


class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth)
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
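For orientation only (not a file from this commit): get_blocks returns four stages of Bottleneck configs that Backbone and GradualStyleEncoder later expand into bottleneck_IR / bottleneck_IR_SE units.

from models.encoders.helpers import get_blocks

# ResNet-50-style layout: 3 + 4 + 14 + 3 units with depths 64/128/256/512
for stage in get_blocks(50):
    # each stage is a list of Bottleneck(in_channel, depth, stride) namedtuples
    print(len(stage), stage[0])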
models/encoders/model_irse.py
ADDED
@@ -0,0 +1,84 @@
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm

"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""


class Backbone(Module):
    def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], "input_size should be 112 or 224"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        if input_size == 112:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 7 * 7, 512),
                                           BatchNorm1d(512, affine=affine))
        else:
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 14 * 14, 512),
                                           BatchNorm1d(512, affine=affine))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


def IR_50(input_size):
    """Constructs a ir-50 model."""
    model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_101(input_size):
    """Constructs a ir-101 model."""
    model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_152(input_size):
    """Constructs a ir-152 model."""
    model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_SE_50(input_size):
    """Constructs a ir_se-50 model."""
    model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_101(input_size):
    """Constructs a ir_se-101 model."""
    model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_152(input_size):
    """Constructs a ir_se-152 model."""
    model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
    return model
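A short usage sketch for the ArcFace-style backbones above (not part of the commit): a 112x112 aligned face crop maps to an L2-normalized 512-d identity embedding. The dummy tensor stands in for a real crop.

import torch
from models.encoders.model_irse import IR_SE_50

net = IR_SE_50(input_size=112).eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))  # dummy input standing in for an aligned face crop
print(emb.shape)   # torch.Size([1, 512])
print(emb.norm())  # ~1.0, because the forward pass ends with l2_norm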
models/encoders/psp_encoders.py
ADDED
@@ -0,0 +1,357 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear, Conv2d, BatchNorm2d, PReLU, Sequential, Module

from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE
from models.stylegan2.model import EqualLinear


class GradualStyleBlock(Module):
    def __init__(self, in_c, out_c, spatial, max_pooling=False):
        super(GradualStyleBlock, self).__init__()
        self.out_c = out_c
        self.spatial = spatial
        self.max_pooling = max_pooling
        num_pools = int(np.log2(spatial))
        modules = []
        modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
                    nn.LeakyReLU()]
        for i in range(num_pools - 1):
            modules += [
                Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
                nn.LeakyReLU()
            ]
        self.convs = nn.Sequential(*modules)
        self.linear = EqualLinear(out_c, out_c, lr_mul=1)

    def forward(self, x):
        x = self.convs(x)
        # To make E accept more general H*W images, we add global average pooling to
        # resize all features to 1*1*512 before mapping to latent codes
        if self.max_pooling:
            x = F.adaptive_max_pool2d(x, 1) ##### modified
        else:
            x = F.adaptive_avg_pool2d(x, 1) ##### modified
        x = x.view(-1, self.out_c)
        x = self.linear(x)
        return x

class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, fin, style_dim=512):
        super().__init__()

        self.norm = nn.InstanceNorm2d(fin, affine=False)
        self.style = nn.Linear(style_dim, fin * 2)

        self.style.bias.data[:fin] = 1
        self.style.bias.data[fin:] = 0

    def forward(self, input, style):
        style = self.style(style).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, 1)
        out = self.norm(input)
        out = gamma * out + beta
        return out


class FusionLayer(Module):  ##### modified
    def __init__(self, inchannel, outchannel, use_skip_torgb=True, use_att=0):
        super(FusionLayer, self).__init__()

        self.transform = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=1, padding=1),
                                       nn.LeakyReLU())
        self.fusion_out = nn.Conv2d(outchannel*2, outchannel, kernel_size=3, stride=1, padding=1)
        self.fusion_out.weight.data *= 0.01
        self.fusion_out.weight[:,0:outchannel,1,1].data += torch.eye(outchannel)

        self.use_skip_torgb = use_skip_torgb
        if use_skip_torgb:
            self.fusion_skip = nn.Conv2d(3+outchannel, 3, kernel_size=3, stride=1, padding=1)
            self.fusion_skip.weight.data *= 0.01
            self.fusion_skip.weight[:,0:3,1,1].data += torch.eye(3)

        self.use_att = use_att
        if use_att:
            modules = []
            modules.append(nn.Linear(512, outchannel))
            for _ in range(use_att):
                modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
                modules.append(nn.Linear(outchannel, outchannel))
            modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
            self.linear = Sequential(*modules)
            self.norm = AdaptiveInstanceNorm(outchannel*2, outchannel)
            self.conv = nn.Conv2d(outchannel*2, 1, 3, 1, 1, bias=True)

    def forward(self, feat, out, skip, editing_w=None):
        x = self.transform(feat)
        # similar to VToonify, use editing vector as condition
        # fuse encoder feature and decoder feature with a predicted attention mask m_E
        # if self.use_att = False, just fuse them with a simple conv layer
        if self.use_att and editing_w is not None:
            label = self.linear(editing_w)
            m_E = (F.relu(self.conv(self.norm(torch.cat([out, abs(out-x)], dim=1), label)))).tanh()
            x = x * m_E
        out = self.fusion_out(torch.cat((out, x), dim=1))
        if self.use_skip_torgb:
            skip = self.fusion_skip(torch.cat((skip, x), dim=1))
        return out, skip


class ResnetBlock(nn.Module):
    def __init__(self, dim):
        super(ResnetBlock, self).__init__()

        self.conv_block = nn.Sequential(Conv2d(dim, dim, 3, 1, 1),
                                        nn.LeakyReLU(),
                                        Conv2d(dim, dim, 3, 1, 1))
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        out = x + self.conv_block(x)
        return self.relu(out)

# trainable light-weight translation network T
# for sketch/mask-to-face translation,
# we add a trainable T to map y to an intermediate domain where E can more easily extract features.
class ResnetGenerator(nn.Module):
    def __init__(self, in_channel=19, res_num=2):
        super(ResnetGenerator, self).__init__()

        modules = []
        modules.append(Conv2d(in_channel, 16, 3, 2, 1))
        modules.append(nn.LeakyReLU())
        modules.append(Conv2d(16, 16, 3, 2, 1))
        modules.append(nn.LeakyReLU())
        for _ in range(res_num):
            modules.append(ResnetBlock(16))
        for _ in range(2):
            modules.append(nn.ConvTranspose2d(16, 16, 3, 2, 1, output_padding=1))
            modules.append(nn.LeakyReLU())
        modules.append(Conv2d(16, 64, 3, 1, 1, bias=False))
        modules.append(BatchNorm2d(64))
        modules.append(PReLU(64))
        self.model = Sequential(*modules)

    def forward(self, input):
        return self.model(input)

class GradualStyleEncoder(Module):
    def __init__(self, num_layers, mode='ir', opts=None):
        super(GradualStyleEncoder, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE

        # for sketch/mask-to-face translation, add a new network T
        if opts.input_nc != 3:
            self.input_label_layer = ResnetGenerator(opts.input_nc, opts.res_num)

        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

        self.styles = nn.ModuleList()
        self.style_count = opts.n_styles
        self.coarse_ind = 3
        self.middle_ind = 7
        for i in range(self.style_count):
            if i < self.coarse_ind:
                style = GradualStyleBlock(512, 512, 16, 'max_pooling' in opts and opts.max_pooling)
            elif i < self.middle_ind:
                style = GradualStyleBlock(512, 512, 32, 'max_pooling' in opts and opts.max_pooling)
            else:
                style = GradualStyleBlock(512, 512, 64, 'max_pooling' in opts and opts.max_pooling)
            self.styles.append(style)
        self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)

        # we concatenate pSp features in the middle layers and
        # add a convolution layer to map the concatenated features to the first-layer input feature f of G.
        self.featlayer = nn.Conv2d(768, 512, kernel_size=1, stride=1, padding=0) ##### modified
        self.skiplayer = nn.Conv2d(768, 3, kernel_size=1, stride=1, padding=0) ##### modified

        # skip connection
        if 'use_skip' in opts and opts.use_skip: ##### modified
            self.fusion = nn.ModuleList()
            channels = [[256,512], [256,512], [256,512], [256,512], [128,512], [64,256], [64,128]]
            # opts.skip_max_layer: how many layers are skipped to the decoder
            for inc, outc in channels[:max(1, min(7, opts.skip_max_layer))]: # from 4 to 256
                self.fusion.append(FusionLayer(inc, outc, opts.use_skip_torgb, opts.use_att))

    def _upsample_add(self, x, y):
        '''Upsample and add two feature maps.
        Args:
          x: (Variable) top feature map to be upsampled.
          y: (Variable) lateral feature map.
        Returns:
          (Variable) added feature map.
        Note in PyTorch, when input size is odd, the upsampled feature map
        with `F.upsample(..., scale_factor=2, mode='nearest')`
        maybe not equal to the lateral feature map size.
        e.g.
        original input size: [N,_,15,15] ->
        conv2d feature map size: [N,_,8,8] ->
        upsampled feature map size: [N,_,16,16]
        So we choose bilinear upsample which supports arbitrary output sizes.
        '''
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y

    # return_feat: return f
    # return_full: return f and the skipped encoder features
    # return [out, feats]
    # out is the style latent code w+
    # feats[0] is f for the 1st conv layer, feats[1] is f for the 1st torgb layer
    # feats[2-8] is the skipped encoder features
    def forward(self, x, return_feat=False, return_full=False): ##### modified
        if x.shape[1] != 3:
            x = self.input_label_layer(x)
        else:
            x = self.input_layer(x)
        c256 = x ##### modified

        latents = []
        modulelist = list(self.body._modules.values())
        for i, l in enumerate(modulelist):
            x = l(x)
            if i == 2: ##### modified
                c128 = x
            elif i == 6:
                c1 = x
            elif i == 10: ##### modified
                c21 = x ##### modified
            elif i == 15: ##### modified
                c22 = x ##### modified
            elif i == 20:
                c2 = x
            elif i == 23:
                c3 = x

        for j in range(self.coarse_ind):
            latents.append(self.styles[j](c3))

        p2 = self._upsample_add(c3, self.latlayer1(c2))
        for j in range(self.coarse_ind, self.middle_ind):
            latents.append(self.styles[j](p2))

        p1 = self._upsample_add(p2, self.latlayer2(c1))
        for j in range(self.middle_ind, self.style_count):
            latents.append(self.styles[j](p1))

        out = torch.stack(latents, dim=1)

        if not return_feat:
            return out

        feats = [self.featlayer(torch.cat((c21, c22, c2), dim=1)), self.skiplayer(torch.cat((c21, c22, c2), dim=1))]

        if return_full: ##### modified
            feats += [c2, c2, c22, c21, c1, c128, c256]

        return out, feats


    # only compute the first-layer feature f
    # E_F in the paper
    def get_feat(self, x): ##### modified
        # for sketch/mask-to-face translation
        # use a trainable light-weight translation network T
        if x.shape[1] != 3:
            x = self.input_label_layer(x)
        else:
            x = self.input_layer(x)

        latents = []
        modulelist = list(self.body._modules.values())
        for i, l in enumerate(modulelist):
            x = l(x)
            if i == 10: ##### modified
                c21 = x ##### modified
            elif i == 15: ##### modified
                c22 = x ##### modified
            elif i == 20:
                c2 = x
                break
        return self.featlayer(torch.cat((c21, c22, c2), dim=1))

class BackboneEncoderUsingLastLayerIntoW(Module):
    def __init__(self, num_layers, mode='ir', opts=None):
        super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
        print('Using BackboneEncoderUsingLastLayerIntoW')
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = EqualLinear(512, 512, lr_mul=1)
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_pool(x)
        x = x.view(-1, 512)
        x = self.linear(x)
        return x


class BackboneEncoderUsingLastLayerIntoWPlus(Module):
    def __init__(self, num_layers, mode='ir', opts=None):
        super(BackboneEncoderUsingLastLayerIntoWPlus, self).__init__()
        print('Using BackboneEncoderUsingLastLayerIntoWPlus')
        assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.n_styles = opts.n_styles
        self.input_layer = Sequential(Conv2d(opts.input_nc, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        self.output_layer_2 = Sequential(BatchNorm2d(512),
                                         torch.nn.AdaptiveAvgPool2d((7, 7)),
                                         Flatten(),
                                         Linear(512 * 7 * 7, 512))
        self.linear = EqualLinear(512, 512 * self.n_styles, lr_mul=1)
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(unit_module(bottleneck.in_channel,
                                           bottleneck.depth,
                                           bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer_2(x)
        x = self.linear(x)
        x = x.view(-1, self.n_styles, 512)
        return x
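A hedged construction sketch for GradualStyleEncoder (not from the commit): the encoder is driven by an options object, so the Namespace below only illustrates the attributes its __init__ touches; the real values (n_styles in particular) come from the repo's option files, and 18 styles is an assumption.

from argparse import Namespace

import torch

from models.encoders.psp_encoders import GradualStyleEncoder

# Hypothetical options: 3-channel input (no translation network T), no skip fusion, average pooling.
opts = Namespace(n_styles=18, input_nc=3, max_pooling=False, use_skip=False)

encoder = GradualStyleEncoder(50, 'ir', opts).eval()
with torch.no_grad():
    w_plus = encoder(torch.randn(1, 3, 256, 256))  # only the W+ codes when return_feat=False
print(w_plus.shape)  # torch.Size([1, 18, 512])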
models/mtcnn/__init__.py
ADDED
File without changes
models/mtcnn/mtcnn.py
ADDED
@@ -0,0 +1,156 @@
import numpy as np
import torch
from PIL import Image
from models.mtcnn.mtcnn_pytorch.src.get_nets import PNet, RNet, ONet
from models.mtcnn.mtcnn_pytorch.src.box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
from models.mtcnn.mtcnn_pytorch.src.first_stage import run_first_stage
from models.mtcnn.mtcnn_pytorch.src.align_trans import get_reference_facial_points, warp_and_crop_face

device = 'cuda:0'


class MTCNN():
    def __init__(self):
        print(device)
        self.pnet = PNet().to(device)
        self.rnet = RNet().to(device)
        self.onet = ONet().to(device)
        self.pnet.eval()
        self.rnet.eval()
        self.onet.eval()
        self.refrence = get_reference_facial_points(default_square=True)

    def align(self, img):
        _, landmarks = self.detect_faces(img)
        if len(landmarks) == 0:
            return None, None
        facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)]
        warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
        return Image.fromarray(warped_face), tfm

    def align_multi(self, img, limit=None, min_face_size=30.0):
        boxes, landmarks = self.detect_faces(img, min_face_size)
        if limit:
            boxes = boxes[:limit]
            landmarks = landmarks[:limit]
        faces = []
        tfms = []
        for landmark in landmarks:
            facial5points = [[landmark[j], landmark[j + 5]] for j in range(5)]
            warped_face, tfm = warp_and_crop_face(np.array(img), facial5points, self.refrence, crop_size=(112, 112))
            faces.append(Image.fromarray(warped_face))
            tfms.append(tfm)
        return boxes, faces, tfms

    def detect_faces(self, image, min_face_size=20.0,
                     thresholds=[0.15, 0.25, 0.35],
                     nms_thresholds=[0.7, 0.7, 0.7]):
        """
        Arguments:
            image: an instance of PIL.Image.
            min_face_size: a float number.
            thresholds: a list of length 3.
            nms_thresholds: a list of length 3.

        Returns:
            two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
            bounding boxes and facial landmarks.
        """

        # BUILD AN IMAGE PYRAMID
        width, height = image.size
        min_length = min(height, width)

        min_detection_size = 12
        factor = 0.707  # sqrt(0.5)

        # scales for scaling the image
        scales = []

        # scales the image so that
        # minimum size that we can detect equals to
        # minimum face size that we want to detect
        m = min_detection_size / min_face_size
        min_length *= m

        factor_count = 0
        while min_length > min_detection_size:
            scales.append(m * factor ** factor_count)
            min_length *= factor
            factor_count += 1

        # STAGE 1

        # it will be returned
        bounding_boxes = []

        with torch.no_grad():
            # run P-Net on different scales
            for s in scales:
                boxes = run_first_stage(image, self.pnet, scale=s, threshold=thresholds[0])
                bounding_boxes.append(boxes)

            # collect boxes (and offsets, and scores) from different scales
            bounding_boxes = [i for i in bounding_boxes if i is not None]
            bounding_boxes = np.vstack(bounding_boxes)

            keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
            bounding_boxes = bounding_boxes[keep]

            # use offsets predicted by pnet to transform bounding boxes
            bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
            # shape [n_boxes, 5]

            bounding_boxes = convert_to_square(bounding_boxes)
            bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

            # STAGE 2

            img_boxes = get_image_boxes(bounding_boxes, image, size=24)
            img_boxes = torch.FloatTensor(img_boxes).to(device)

            output = self.rnet(img_boxes)
            offsets = output[0].cpu().data.numpy()  # shape [n_boxes, 4]
            probs = output[1].cpu().data.numpy()  # shape [n_boxes, 2]

            keep = np.where(probs[:, 1] > thresholds[1])[0]
            bounding_boxes = bounding_boxes[keep]
            bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
            offsets = offsets[keep]

            keep = nms(bounding_boxes, nms_thresholds[1])
            bounding_boxes = bounding_boxes[keep]
            bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
            bounding_boxes = convert_to_square(bounding_boxes)
            bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])

            # STAGE 3

            img_boxes = get_image_boxes(bounding_boxes, image, size=48)
            if len(img_boxes) == 0:
                return [], []
            img_boxes = torch.FloatTensor(img_boxes).to(device)
            output = self.onet(img_boxes)
            landmarks = output[0].cpu().data.numpy()  # shape [n_boxes, 10]
            offsets = output[1].cpu().data.numpy()  # shape [n_boxes, 4]
            probs = output[2].cpu().data.numpy()  # shape [n_boxes, 2]

            keep = np.where(probs[:, 1] > thresholds[2])[0]
            bounding_boxes = bounding_boxes[keep]
            bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
            offsets = offsets[keep]
            landmarks = landmarks[keep]

            # compute landmark points
            width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
            height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
            xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
            landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
            landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]

            bounding_boxes = calibrate_box(bounding_boxes, offsets)
            keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
            bounding_boxes = bounding_boxes[keep]
            landmarks = landmarks[keep]

        return bounding_boxes, landmarks
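A minimal alignment sketch for the MTCNN wrapper above (not one of the commit's files). The image paths are placeholders, and a CUDA device is required because `device` is hard-coded to 'cuda:0' in mtcnn.py.

from PIL import Image

from models.mtcnn.mtcnn import MTCNN

mtcnn = MTCNN()  # builds P-Net/R-Net/O-Net on cuda:0 and loads the reference landmarks
img = Image.open('photo.jpg').convert('RGB')

face, tfm = mtcnn.align(img)  # 112x112 aligned crop + similarity transform, or (None, None)
if face is not None:
    face.save('aligned.jpg')

boxes, faces, tfms = mtcnn.align_multi(img, limit=4)  # up to 4 detected faces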
models/mtcnn/mtcnn_pytorch/__init__.py
ADDED
File without changes
models/mtcnn/mtcnn_pytorch/src/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .visualization_utils import show_bboxes
from .detector import detect_faces
models/mtcnn/mtcnn_pytorch/src/align_trans.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Mon Apr 24 15:43:29 2017
|
4 |
+
@author: zhaoy
|
5 |
+
"""
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
|
9 |
+
# from scipy.linalg import lstsq
|
10 |
+
# from scipy.ndimage import geometric_transform # , map_coordinates
|
11 |
+
|
12 |
+
from models.mtcnn.mtcnn_pytorch.src.matlab_cp2tform import get_similarity_transform_for_cv2
|
13 |
+
|
14 |
+
# reference facial points, a list of coordinates (x,y)
|
15 |
+
REFERENCE_FACIAL_POINTS = [
|
16 |
+
[30.29459953, 51.69630051],
|
17 |
+
[65.53179932, 51.50139999],
|
18 |
+
[48.02519989, 71.73660278],
|
19 |
+
[33.54930115, 92.3655014],
|
20 |
+
[62.72990036, 92.20410156]
|
21 |
+
]
|
22 |
+
|
23 |
+
DEFAULT_CROP_SIZE = (96, 112)
|
24 |
+
|
25 |
+
|
26 |
+
class FaceWarpException(Exception):
|
27 |
+
def __str__(self):
|
28 |
+
return 'In File {}:{}'.format(
|
29 |
+
__file__, super().__str__())
|
30 |
+
|
31 |
+
|
32 |
+
def get_reference_facial_points(output_size=None,
|
33 |
+
inner_padding_factor=0.0,
|
34 |
+
outer_padding=(0, 0),
|
35 |
+
default_square=False):
|
36 |
+
"""
|
37 |
+
Function:
|
38 |
+
----------
|
39 |
+
get reference 5 key points according to crop settings:
|
40 |
+
0. Set default crop_size:
|
41 |
+
if default_square:
|
42 |
+
crop_size = (112, 112)
|
43 |
+
else:
|
44 |
+
crop_size = (96, 112)
|
45 |
+
1. Pad the crop_size by inner_padding_factor on each side;
|
46 |
+
2. Resize crop_size into (output_size - outer_padding*2),
|
47 |
+
pad into output_size with outer_padding;
|
48 |
+
3. Output reference_5point;
|
49 |
+
Parameters:
|
50 |
+
----------
|
51 |
+
@output_size: (w, h) or None
|
52 |
+
size of aligned face image
|
53 |
+
@inner_padding_factor: (w_factor, h_factor)
|
54 |
+
padding factor for inner (w, h)
|
55 |
+
@outer_padding: (w_pad, h_pad)
|
56 |
+
padding size (w_pad, h_pad) added around the resized inner region
|
57 |
+
@default_square: True or False
|
58 |
+
if True:
|
59 |
+
default crop_size = (112, 112)
|
60 |
+
else:
|
61 |
+
default crop_size = (96, 112);
|
62 |
+
!!! make sure, if output_size is not None:
|
63 |
+
(output_size - outer_padding)
|
64 |
+
= some_scale * (default crop_size * (1.0 + inner_padding_factor))
|
65 |
+
Returns:
|
66 |
+
----------
|
67 |
+
@reference_5point: 5x2 np.array
|
68 |
+
each row is a pair of transformed coordinates (x, y)
|
69 |
+
"""
|
70 |
+
# print('\n===> get_reference_facial_points():')
|
71 |
+
|
72 |
+
# print('---> Params:')
|
73 |
+
# print(' output_size: ', output_size)
|
74 |
+
# print(' inner_padding_factor: ', inner_padding_factor)
|
75 |
+
# print(' outer_padding:', outer_padding)
|
76 |
+
# print(' default_square: ', default_square)
|
77 |
+
|
78 |
+
tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
|
79 |
+
tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
|
80 |
+
|
81 |
+
# 0) make the inner region a square
|
82 |
+
if default_square:
|
83 |
+
size_diff = max(tmp_crop_size) - tmp_crop_size
|
84 |
+
tmp_5pts += size_diff / 2
|
85 |
+
tmp_crop_size += size_diff
|
86 |
+
|
87 |
+
# print('---> default:')
|
88 |
+
# print(' crop_size = ', tmp_crop_size)
|
89 |
+
# print(' reference_5pts = ', tmp_5pts)
|
90 |
+
|
91 |
+
if (output_size and
|
92 |
+
output_size[0] == tmp_crop_size[0] and
|
93 |
+
output_size[1] == tmp_crop_size[1]):
|
94 |
+
# print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
|
95 |
+
return tmp_5pts
|
96 |
+
|
97 |
+
if (inner_padding_factor == 0 and
|
98 |
+
outer_padding == (0, 0)):
|
99 |
+
if output_size is None:
|
100 |
+
# print('No paddings to do: return default reference points')
|
101 |
+
return tmp_5pts
|
102 |
+
else:
|
103 |
+
raise FaceWarpException(
|
104 |
+
'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
|
105 |
+
|
106 |
+
# check output size
|
107 |
+
if not (0 <= inner_padding_factor <= 1.0):
|
108 |
+
raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
|
109 |
+
|
110 |
+
if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
|
111 |
+
and output_size is None):
|
112 |
+
output_size = (tmp_crop_size * \
|
113 |
+
(1 + inner_padding_factor * 2)).astype(np.int32)
|
114 |
+
output_size += np.array(outer_padding)
|
115 |
+
# print(' deduced from paddings, output_size = ', output_size)
|
116 |
+
|
117 |
+
if not (outer_padding[0] < output_size[0]
|
118 |
+
and outer_padding[1] < output_size[1]):
|
119 |
+
raise FaceWarpException('Not (outer_padding[0] < output_size[0] '
|
120 |
+
'and outer_padding[1] < output_size[1])')
|
121 |
+
|
122 |
+
# 1) pad the inner region according to inner_padding_factor
|
123 |
+
# print('---> STEP1: pad the inner region according inner_padding_factor')
|
124 |
+
if inner_padding_factor > 0:
|
125 |
+
size_diff = tmp_crop_size * inner_padding_factor * 2
|
126 |
+
tmp_5pts += size_diff / 2
|
127 |
+
tmp_crop_size += np.round(size_diff).astype(np.int32)
|
128 |
+
|
129 |
+
# print(' crop_size = ', tmp_crop_size)
|
130 |
+
# print(' reference_5pts = ', tmp_5pts)
|
131 |
+
|
132 |
+
# 2) resize the padded inner region
|
133 |
+
# print('---> STEP2: resize the padded inner region')
|
134 |
+
size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
|
135 |
+
# print(' crop_size = ', tmp_crop_size)
|
136 |
+
# print(' size_bf_outer_pad = ', size_bf_outer_pad)
|
137 |
+
|
138 |
+
if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
|
139 |
+
raise FaceWarpException('Must have (output_size - outer_padding)'
|
140 |
+
' = some_scale * (crop_size * (1.0 + inner_padding_factor))')
|
141 |
+
|
142 |
+
scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
|
143 |
+
# print(' resize scale_factor = ', scale_factor)
|
144 |
+
tmp_5pts = tmp_5pts * scale_factor
|
145 |
+
# size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
|
146 |
+
# tmp_5pts = tmp_5pts + size_diff / 2
|
147 |
+
tmp_crop_size = size_bf_outer_pad
|
148 |
+
# print(' crop_size = ', tmp_crop_size)
|
149 |
+
# print(' reference_5pts = ', tmp_5pts)
|
150 |
+
|
151 |
+
# 3) add outer_padding to make output_size
|
152 |
+
reference_5point = tmp_5pts + np.array(outer_padding)
|
153 |
+
tmp_crop_size = output_size
|
154 |
+
# print('---> STEP3: add outer_padding to make output_size')
|
155 |
+
# print(' crop_size = ', tmp_crop_size)
|
156 |
+
# print(' reference_5pts = ', tmp_5pts)
|
157 |
+
|
158 |
+
# print('===> end get_reference_facial_points\n')
|
159 |
+
|
160 |
+
return reference_5point
|
161 |
+
|
162 |
+
|
163 |
+
def get_affine_transform_matrix(src_pts, dst_pts):
|
164 |
+
"""
|
165 |
+
Function:
|
166 |
+
----------
|
167 |
+
get affine transform matrix 'tfm' from src_pts to dst_pts
|
168 |
+
Parameters:
|
169 |
+
----------
|
170 |
+
@src_pts: Kx2 np.array
|
171 |
+
source points matrix, each row is a pair of coordinates (x, y)
|
172 |
+
@dst_pts: Kx2 np.array
|
173 |
+
destination points matrix, each row is a pair of coordinates (x, y)
|
174 |
+
Returns:
|
175 |
+
----------
|
176 |
+
@tfm: 2x3 np.array
|
177 |
+
transform matrix from src_pts to dst_pts
|
178 |
+
"""
|
179 |
+
|
180 |
+
tfm = np.float32([[1, 0, 0], [0, 1, 0]])
|
181 |
+
n_pts = src_pts.shape[0]
|
182 |
+
ones = np.ones((n_pts, 1), src_pts.dtype)
|
183 |
+
src_pts_ = np.hstack([src_pts, ones])
|
184 |
+
dst_pts_ = np.hstack([dst_pts, ones])
|
185 |
+
|
186 |
+
# #print(('src_pts_:\n' + str(src_pts_))
|
187 |
+
# #print(('dst_pts_:\n' + str(dst_pts_))
|
188 |
+
|
189 |
+
A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_, rcond=None)
|
190 |
+
|
191 |
+
# #print(('np.linalg.lstsq return A: \n' + str(A))
|
192 |
+
# #print(('np.linalg.lstsq return res: \n' + str(res))
|
193 |
+
# #print(('np.linalg.lstsq return rank: \n' + str(rank))
|
194 |
+
# #print(('np.linalg.lstsq return s: \n' + str(s))
|
195 |
+
|
196 |
+
if rank == 3:
|
197 |
+
tfm = np.float32([
|
198 |
+
[A[0, 0], A[1, 0], A[2, 0]],
|
199 |
+
[A[0, 1], A[1, 1], A[2, 1]]
|
200 |
+
])
|
201 |
+
elif rank == 2:
|
202 |
+
tfm = np.float32([
|
203 |
+
[A[0, 0], A[1, 0], 0],
|
204 |
+
[A[0, 1], A[1, 1], 0]
|
205 |
+
])
|
206 |
+
|
207 |
+
return tfm
|
208 |
+
|
209 |
+
|
210 |
+
def warp_and_crop_face(src_img,
|
211 |
+
facial_pts,
|
212 |
+
reference_pts=None,
|
213 |
+
crop_size=(96, 112),
|
214 |
+
align_type='similarity'):
|
215 |
+
"""
|
216 |
+
Function:
|
217 |
+
----------
|
218 |
+
apply affine transform 'trans' to uv
|
219 |
+
Parameters:
|
220 |
+
----------
|
221 |
+
@src_img: 3x3 np.array
|
222 |
+
input image
|
223 |
+
@facial_pts: could be
|
224 |
+
1) a list of K coordinates (x,y)
|
225 |
+
or
|
226 |
+
2) Kx2 or 2xK np.array
|
227 |
+
each row or col is a pair of coordinates (x, y)
|
228 |
+
@reference_pts: could be
|
229 |
+
1) a list of K coordinates (x,y)
|
230 |
+
or
|
231 |
+
2) Kx2 or 2xK np.array
|
232 |
+
each row or col is a pair of coordinates (x, y)
|
233 |
+
or
|
234 |
+
3) None
|
235 |
+
if None, use default reference facial points
|
236 |
+
@crop_size: (w, h)
|
237 |
+
output face image size
|
238 |
+
@align_type: transform type, could be one of
|
239 |
+
1) 'similarity': use similarity transform
|
240 |
+
2) 'cv2_affine': use the first 3 points to do affine transform,
|
241 |
+
by calling cv2.getAffineTransform()
|
242 |
+
3) 'affine': use all points to do affine transform
|
243 |
+
Returns:
|
244 |
+
----------
|
245 |
+
@face_img: output face image with size (w, h) = @crop_size
|
246 |
+
"""
|
247 |
+
|
248 |
+
if reference_pts is None:
|
249 |
+
if crop_size[0] == 96 and crop_size[1] == 112:
|
250 |
+
reference_pts = REFERENCE_FACIAL_POINTS
|
251 |
+
else:
|
252 |
+
default_square = False
|
253 |
+
inner_padding_factor = 0
|
254 |
+
outer_padding = (0, 0)
|
255 |
+
output_size = crop_size
|
256 |
+
|
257 |
+
reference_pts = get_reference_facial_points(output_size,
|
258 |
+
inner_padding_factor,
|
259 |
+
outer_padding,
|
260 |
+
default_square)
|
261 |
+
|
262 |
+
ref_pts = np.float32(reference_pts)
|
263 |
+
ref_pts_shp = ref_pts.shape
|
264 |
+
if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
|
265 |
+
raise FaceWarpException(
|
266 |
+
'reference_pts.shape must be (K,2) or (2,K) and K>2')
|
267 |
+
|
268 |
+
if ref_pts_shp[0] == 2:
|
269 |
+
ref_pts = ref_pts.T
|
270 |
+
|
271 |
+
src_pts = np.float32(facial_pts)
|
272 |
+
src_pts_shp = src_pts.shape
|
273 |
+
if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
|
274 |
+
raise FaceWarpException(
|
275 |
+
'facial_pts.shape must be (K,2) or (2,K) and K>2')
|
276 |
+
|
277 |
+
if src_pts_shp[0] == 2:
|
278 |
+
src_pts = src_pts.T
|
279 |
+
|
280 |
+
# #print('--->src_pts:\n', src_pts
|
281 |
+
# #print('--->ref_pts\n', ref_pts
|
282 |
+
|
283 |
+
if src_pts.shape != ref_pts.shape:
|
284 |
+
raise FaceWarpException(
|
285 |
+
'facial_pts and reference_pts must have the same shape')
|
286 |
+
|
287 |
+
if align_type == 'cv2_affine':
|
288 |
+
tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
|
289 |
+
# #print(('cv2.getAffineTransform() returns tfm=\n' + str(tfm))
|
290 |
+
elif align_type == 'affine':
|
291 |
+
tfm = get_affine_transform_matrix(src_pts, ref_pts)
|
292 |
+
# #print(('get_affine_transform_matrix() returns tfm=\n' + str(tfm))
|
293 |
+
else:
|
294 |
+
tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
|
295 |
+
# #print(('get_similarity_transform_for_cv2() returns tfm=\n' + str(tfm))
|
296 |
+
|
297 |
+
# #print('--->Transform matrix: '
|
298 |
+
# #print(('type(tfm):' + str(type(tfm)))
|
299 |
+
# #print(('tfm.dtype:' + str(tfm.dtype))
|
300 |
+
# #print( tfm
|
301 |
+
|
302 |
+
face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
|
303 |
+
|
304 |
+
return face_img, tfm
|
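A minimal usage sketch for `warp_and_crop_face` above with a 112x112 crop. The image path and the five landmark coordinates are placeholders; note that for a square crop the reference points have to be requested with `default_square=True`, otherwise `get_reference_facial_points` raises a FaceWarpException for a (112, 112) output with zero padding:

    import cv2
    import numpy as np

    img = cv2.imread('face.jpg')  # placeholder path
    facial5points = np.float32([[188, 228], [251, 228], [220, 264],
                                [194, 297], [245, 296]])  # eyes, nose, mouth corners
    ref_pts = get_reference_facial_points(output_size=(112, 112), default_square=True)
    aligned, tfm = warp_and_crop_face(img, facial5points, reference_pts=ref_pts,
                                      crop_size=(112, 112), align_type='similarity')
    cv2.imwrite('face_aligned.jpg', aligned)  # 112x112 aligned crop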
models/mtcnn/mtcnn_pytorch/src/box_utils.py
ADDED
@@ -0,0 +1,238 @@
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
|
5 |
+
def nms(boxes, overlap_threshold=0.5, mode='union'):
|
6 |
+
"""Non-maximum suppression.
|
7 |
+
|
8 |
+
Arguments:
|
9 |
+
boxes: a float numpy array of shape [n, 5],
|
10 |
+
where each row is (xmin, ymin, xmax, ymax, score).
|
11 |
+
overlap_threshold: a float number.
|
12 |
+
mode: 'union' or 'min'.
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
list with indices of the selected boxes
|
16 |
+
"""
|
17 |
+
|
18 |
+
# if there are no boxes, return the empty list
|
19 |
+
if len(boxes) == 0:
|
20 |
+
return []
|
21 |
+
|
22 |
+
# list of picked indices
|
23 |
+
pick = []
|
24 |
+
|
25 |
+
# grab the coordinates of the bounding boxes
|
26 |
+
x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)]
|
27 |
+
|
28 |
+
area = (x2 - x1 + 1.0) * (y2 - y1 + 1.0)
|
29 |
+
ids = np.argsort(score) # in increasing order
|
30 |
+
|
31 |
+
while len(ids) > 0:
|
32 |
+
|
33 |
+
# grab index of the largest value
|
34 |
+
last = len(ids) - 1
|
35 |
+
i = ids[last]
|
36 |
+
pick.append(i)
|
37 |
+
|
38 |
+
# compute intersections
|
39 |
+
# of the box with the largest score
|
40 |
+
# with the rest of boxes
|
41 |
+
|
42 |
+
# left top corner of intersection boxes
|
43 |
+
ix1 = np.maximum(x1[i], x1[ids[:last]])
|
44 |
+
iy1 = np.maximum(y1[i], y1[ids[:last]])
|
45 |
+
|
46 |
+
# right bottom corner of intersection boxes
|
47 |
+
ix2 = np.minimum(x2[i], x2[ids[:last]])
|
48 |
+
iy2 = np.minimum(y2[i], y2[ids[:last]])
|
49 |
+
|
50 |
+
# width and height of intersection boxes
|
51 |
+
w = np.maximum(0.0, ix2 - ix1 + 1.0)
|
52 |
+
h = np.maximum(0.0, iy2 - iy1 + 1.0)
|
53 |
+
|
54 |
+
# intersections' areas
|
55 |
+
inter = w * h
|
56 |
+
if mode == 'min':
|
57 |
+
overlap = inter / np.minimum(area[i], area[ids[:last]])
|
58 |
+
elif mode == 'union':
|
59 |
+
# intersection over union (IoU)
|
60 |
+
overlap = inter / (area[i] + area[ids[:last]] - inter)
|
61 |
+
|
62 |
+
# delete all boxes where overlap is too big
|
63 |
+
ids = np.delete(
|
64 |
+
ids,
|
65 |
+
np.concatenate([[last], np.where(overlap > overlap_threshold)[0]])
|
66 |
+
)
|
67 |
+
|
68 |
+
return pick
|
69 |
+
|
70 |
+
|
71 |
+
def convert_to_square(bboxes):
|
72 |
+
"""Convert bounding boxes to a square form.
|
73 |
+
|
74 |
+
Arguments:
|
75 |
+
bboxes: a float numpy array of shape [n, 5].
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
a float numpy array of shape [n, 5],
|
79 |
+
squared bounding boxes.
|
80 |
+
"""
|
81 |
+
|
82 |
+
square_bboxes = np.zeros_like(bboxes)
|
83 |
+
x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
|
84 |
+
h = y2 - y1 + 1.0
|
85 |
+
w = x2 - x1 + 1.0
|
86 |
+
max_side = np.maximum(h, w)
|
87 |
+
square_bboxes[:, 0] = x1 + w * 0.5 - max_side * 0.5
|
88 |
+
square_bboxes[:, 1] = y1 + h * 0.5 - max_side * 0.5
|
89 |
+
square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0
|
90 |
+
square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0
|
91 |
+
return square_bboxes
|
92 |
+
|
93 |
+
|
94 |
+
def calibrate_box(bboxes, offsets):
|
95 |
+
"""Transform bounding boxes to be more like true bounding boxes.
|
96 |
+
'offsets' is one of the outputs of the nets.
|
97 |
+
|
98 |
+
Arguments:
|
99 |
+
bboxes: a float numpy array of shape [n, 5].
|
100 |
+
offsets: a float numpy array of shape [n, 4].
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
a float numpy array of shape [n, 5].
|
104 |
+
"""
|
105 |
+
x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
|
106 |
+
w = x2 - x1 + 1.0
|
107 |
+
h = y2 - y1 + 1.0
|
108 |
+
w = np.expand_dims(w, 1)
|
109 |
+
h = np.expand_dims(h, 1)
|
110 |
+
|
111 |
+
# this is what's happening here:
|
112 |
+
# tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)]
|
113 |
+
# x1_true = x1 + tx1*w
|
114 |
+
# y1_true = y1 + ty1*h
|
115 |
+
# x2_true = x2 + tx2*w
|
116 |
+
# y2_true = y2 + ty2*h
|
117 |
+
# below is just more compact form of this
|
118 |
+
|
119 |
+
# are offsets always such that
|
120 |
+
# x1 < x2 and y1 < y2 ?
|
121 |
+
|
122 |
+
translation = np.hstack([w, h, w, h]) * offsets
|
123 |
+
bboxes[:, 0:4] = bboxes[:, 0:4] + translation
|
124 |
+
return bboxes
|
125 |
+
|
126 |
+
|
127 |
+
def get_image_boxes(bounding_boxes, img, size=24):
|
128 |
+
"""Cut out boxes from the image.
|
129 |
+
|
130 |
+
Arguments:
|
131 |
+
bounding_boxes: a float numpy array of shape [n, 5].
|
132 |
+
img: an instance of PIL.Image.
|
133 |
+
size: an integer, size of cutouts.
|
134 |
+
|
135 |
+
Returns:
|
136 |
+
a float numpy array of shape [n, 3, size, size].
|
137 |
+
"""
|
138 |
+
|
139 |
+
num_boxes = len(bounding_boxes)
|
140 |
+
width, height = img.size
|
141 |
+
|
142 |
+
[dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height)
|
143 |
+
img_boxes = np.zeros((num_boxes, 3, size, size), 'float32')
|
144 |
+
|
145 |
+
for i in range(num_boxes):
|
146 |
+
img_box = np.zeros((h[i], w[i], 3), 'uint8')
|
147 |
+
|
148 |
+
img_array = np.asarray(img, 'uint8')
|
149 |
+
img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] = \
|
150 |
+
img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :]
|
151 |
+
|
152 |
+
# resize
|
153 |
+
img_box = Image.fromarray(img_box)
|
154 |
+
img_box = img_box.resize((size, size), Image.BILINEAR)
|
155 |
+
img_box = np.asarray(img_box, 'float32')
|
156 |
+
|
157 |
+
img_boxes[i, :, :, :] = _preprocess(img_box)
|
158 |
+
|
159 |
+
return img_boxes
|
160 |
+
|
161 |
+
|
162 |
+
def correct_bboxes(bboxes, width, height):
|
163 |
+
"""Crop boxes that are too big and get coordinates
|
164 |
+
with respect to cutouts.
|
165 |
+
|
166 |
+
Arguments:
|
167 |
+
bboxes: a float numpy array of shape [n, 5],
|
168 |
+
where each row is (xmin, ymin, xmax, ymax, score).
|
169 |
+
width: a float number.
|
170 |
+
height: a float number.
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
dy, dx, edy, edx: int numpy arrays of shape [n],
|
174 |
+
coordinates of the boxes with respect to the cutouts.
|
175 |
+
y, x, ey, ex: int numpy arrays of shape [n],
|
176 |
+
corrected ymin, xmin, ymax, xmax.
|
177 |
+
h, w: int numpy arrays of shape [n],
|
178 |
+
just heights and widths of boxes.
|
179 |
+
|
180 |
+
in the following order:
|
181 |
+
[dy, edy, dx, edx, y, ey, x, ex, w, h].
|
182 |
+
"""
|
183 |
+
|
184 |
+
x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)]
|
185 |
+
w, h = x2 - x1 + 1.0, y2 - y1 + 1.0
|
186 |
+
num_boxes = bboxes.shape[0]
|
187 |
+
|
188 |
+
# 'e' stands for end
|
189 |
+
# (x, y) -> (ex, ey)
|
190 |
+
x, y, ex, ey = x1, y1, x2, y2
|
191 |
+
|
192 |
+
# we need to cut out a box from the image.
|
193 |
+
# (x, y, ex, ey) are corrected coordinates of the box
|
194 |
+
# in the image.
|
195 |
+
# (dx, dy, edx, edy) are coordinates of the box in the cutout
|
196 |
+
# from the image.
|
197 |
+
dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,))
|
198 |
+
edx, edy = w.copy() - 1.0, h.copy() - 1.0
|
199 |
+
|
200 |
+
# if box's bottom right corner is too far right
|
201 |
+
ind = np.where(ex > width - 1.0)[0]
|
202 |
+
edx[ind] = w[ind] + width - 2.0 - ex[ind]
|
203 |
+
ex[ind] = width - 1.0
|
204 |
+
|
205 |
+
# if box's bottom right corner is too low
|
206 |
+
ind = np.where(ey > height - 1.0)[0]
|
207 |
+
edy[ind] = h[ind] + height - 2.0 - ey[ind]
|
208 |
+
ey[ind] = height - 1.0
|
209 |
+
|
210 |
+
# if box's top left corner is too far left
|
211 |
+
ind = np.where(x < 0.0)[0]
|
212 |
+
dx[ind] = 0.0 - x[ind]
|
213 |
+
x[ind] = 0.0
|
214 |
+
|
215 |
+
# if box's top left corner is too high
|
216 |
+
ind = np.where(y < 0.0)[0]
|
217 |
+
dy[ind] = 0.0 - y[ind]
|
218 |
+
y[ind] = 0.0
|
219 |
+
|
220 |
+
return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h]
|
221 |
+
return_list = [i.astype('int32') for i in return_list]
|
222 |
+
|
223 |
+
return return_list
|
224 |
+
|
225 |
+
|
226 |
+
def _preprocess(img):
|
227 |
+
"""Preprocessing step before feeding the network.
|
228 |
+
|
229 |
+
Arguments:
|
230 |
+
img: a float numpy array of shape [h, w, c].
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
a float numpy array of shape [1, c, h, w].
|
234 |
+
"""
|
235 |
+
img = img.transpose((2, 0, 1))
|
236 |
+
img = np.expand_dims(img, 0)
|
237 |
+
img = (img - 127.5) * 0.0078125
|
238 |
+
return img
|
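A small self-contained check of `nms` and `calibrate_box` above; the box coordinates and scores are made up:

    import numpy as np

    boxes = np.array([[ 10.,  10.,  60.,  60., 0.95],
                      [ 12.,  12.,  62.,  62., 0.90],   # heavy overlap with the first box
                      [100., 100., 150., 150., 0.80]], dtype=np.float32)
    keep = nms(boxes, overlap_threshold=0.5)
    print(keep)                                          # -> [0, 2]: the duplicate is suppressed
    offsets = np.zeros((len(keep), 4), dtype=np.float32)
    print(calibrate_box(boxes[keep], offsets)[:, :4])    # unchanged with zero offsets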
models/mtcnn/mtcnn_pytorch/src/detector.py
ADDED
@@ -0,0 +1,126 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch.autograd import Variable
|
4 |
+
from .get_nets import PNet, RNet, ONet
|
5 |
+
from .box_utils import nms, calibrate_box, get_image_boxes, convert_to_square
|
6 |
+
from .first_stage import run_first_stage
|
7 |
+
|
8 |
+
|
9 |
+
def detect_faces(image, min_face_size=20.0,
|
10 |
+
thresholds=[0.6, 0.7, 0.8],
|
11 |
+
nms_thresholds=[0.7, 0.7, 0.7]):
|
12 |
+
"""
|
13 |
+
Arguments:
|
14 |
+
image: an instance of PIL.Image.
|
15 |
+
min_face_size: a float number.
|
16 |
+
thresholds: a list of length 3.
|
17 |
+
nms_thresholds: a list of length 3.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10],
|
21 |
+
bounding boxes and facial landmarks.
|
22 |
+
"""
|
23 |
+
|
24 |
+
# LOAD MODELS
|
25 |
+
pnet = PNet()
|
26 |
+
rnet = RNet()
|
27 |
+
onet = ONet()
|
28 |
+
onet.eval()
|
29 |
+
|
30 |
+
# BUILD AN IMAGE PYRAMID
|
31 |
+
width, height = image.size
|
32 |
+
min_length = min(height, width)
|
33 |
+
|
34 |
+
min_detection_size = 12
|
35 |
+
factor = 0.707 # sqrt(0.5)
|
36 |
+
|
37 |
+
# scales for scaling the image
|
38 |
+
scales = []
|
39 |
+
|
40 |
+
# scales the image so that
|
41 |
+
# minimum size that we can detect equals
|
42 |
+
# minimum face size that we want to detect
|
43 |
+
m = min_detection_size / min_face_size
|
44 |
+
min_length *= m
|
45 |
+
|
46 |
+
factor_count = 0
|
47 |
+
while min_length > min_detection_size:
|
48 |
+
scales.append(m * factor ** factor_count)
|
49 |
+
min_length *= factor
|
50 |
+
factor_count += 1
|
51 |
+
|
52 |
+
# STAGE 1
|
53 |
+
|
54 |
+
# it will be returned
|
55 |
+
bounding_boxes = []
|
56 |
+
|
57 |
+
with torch.no_grad():
|
58 |
+
# run P-Net on different scales
|
59 |
+
for s in scales:
|
60 |
+
boxes = run_first_stage(image, pnet, scale=s, threshold=thresholds[0])
|
61 |
+
bounding_boxes.append(boxes)
|
62 |
+
|
63 |
+
# collect boxes (and offsets, and scores) from different scales
|
64 |
+
bounding_boxes = [i for i in bounding_boxes if i is not None]
|
65 |
+
bounding_boxes = np.vstack(bounding_boxes)
|
66 |
+
|
67 |
+
keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0])
|
68 |
+
bounding_boxes = bounding_boxes[keep]
|
69 |
+
|
70 |
+
# use offsets predicted by pnet to transform bounding boxes
|
71 |
+
bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:])
|
72 |
+
# shape [n_boxes, 5]
|
73 |
+
|
74 |
+
bounding_boxes = convert_to_square(bounding_boxes)
|
75 |
+
bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
|
76 |
+
|
77 |
+
# STAGE 2
|
78 |
+
|
79 |
+
img_boxes = get_image_boxes(bounding_boxes, image, size=24)
|
80 |
+
img_boxes = torch.FloatTensor(img_boxes)
|
81 |
+
|
82 |
+
output = rnet(img_boxes)
|
83 |
+
offsets = output[0].data.numpy() # shape [n_boxes, 4]
|
84 |
+
probs = output[1].data.numpy() # shape [n_boxes, 2]
|
85 |
+
|
86 |
+
keep = np.where(probs[:, 1] > thresholds[1])[0]
|
87 |
+
bounding_boxes = bounding_boxes[keep]
|
88 |
+
bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
|
89 |
+
offsets = offsets[keep]
|
90 |
+
|
91 |
+
keep = nms(bounding_boxes, nms_thresholds[1])
|
92 |
+
bounding_boxes = bounding_boxes[keep]
|
93 |
+
bounding_boxes = calibrate_box(bounding_boxes, offsets[keep])
|
94 |
+
bounding_boxes = convert_to_square(bounding_boxes)
|
95 |
+
bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4])
|
96 |
+
|
97 |
+
# STAGE 3
|
98 |
+
|
99 |
+
img_boxes = get_image_boxes(bounding_boxes, image, size=48)
|
100 |
+
if len(img_boxes) == 0:
|
101 |
+
return [], []
|
102 |
+
img_boxes = torch.FloatTensor(img_boxes)
|
103 |
+
output = onet(img_boxes)
|
104 |
+
landmarks = output[0].data.numpy() # shape [n_boxes, 10]
|
105 |
+
offsets = output[1].data.numpy() # shape [n_boxes, 4]
|
106 |
+
probs = output[2].data.numpy() # shape [n_boxes, 2]
|
107 |
+
|
108 |
+
keep = np.where(probs[:, 1] > thresholds[2])[0]
|
109 |
+
bounding_boxes = bounding_boxes[keep]
|
110 |
+
bounding_boxes[:, 4] = probs[keep, 1].reshape((-1,))
|
111 |
+
offsets = offsets[keep]
|
112 |
+
landmarks = landmarks[keep]
|
113 |
+
|
114 |
+
# compute landmark points
|
115 |
+
width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0
|
116 |
+
height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0
|
117 |
+
xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1]
|
118 |
+
landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1) * landmarks[:, 0:5]
|
119 |
+
landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1) * landmarks[:, 5:10]
|
120 |
+
|
121 |
+
bounding_boxes = calibrate_box(bounding_boxes, offsets)
|
122 |
+
keep = nms(bounding_boxes, nms_thresholds[2], mode='min')
|
123 |
+
bounding_boxes = bounding_boxes[keep]
|
124 |
+
landmarks = landmarks[keep]
|
125 |
+
|
126 |
+
return bounding_boxes, landmarks
|
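A usage sketch for `detect_faces` above; the image path is a placeholder. Note that the three networks are created here without explicit device placement while `run_first_stage` in first_stage.py moves the scaled image to its module-level `device`, so on a CUDA machine the nets and the input may need to be moved to the same device before this runs end to end:

    from PIL import Image

    image = Image.open('group.jpg').convert('RGB')   # placeholder path
    boxes, landmarks = detect_faces(image, min_face_size=20.0,
                                    thresholds=[0.6, 0.7, 0.8],
                                    nms_thresholds=[0.7, 0.7, 0.7])
    print(len(boxes), 'faces found')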
models/mtcnn/mtcnn_pytorch/src/first_stage.py
ADDED
@@ -0,0 +1,101 @@
1 |
+
import torch
|
2 |
+
from torch.autograd import Variable
|
3 |
+
import math
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
from .box_utils import nms, _preprocess
|
7 |
+
|
8 |
+
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
9 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
10 |
+
|
11 |
+
|
12 |
+
def run_first_stage(image, net, scale, threshold):
|
13 |
+
"""Run P-Net, generate bounding boxes, and do NMS.
|
14 |
+
|
15 |
+
Arguments:
|
16 |
+
image: an instance of PIL.Image.
|
17 |
+
net: an instance of pytorch's nn.Module, P-Net.
|
18 |
+
scale: a float number,
|
19 |
+
scale width and height of the image by this number.
|
20 |
+
threshold: a float number,
|
21 |
+
threshold on the probability of a face when generating
|
22 |
+
bounding boxes from predictions of the net.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
a float numpy array of shape [n_boxes, 9],
|
26 |
+
bounding boxes with scores and offsets (4 + 1 + 4).
|
27 |
+
"""
|
28 |
+
|
29 |
+
# scale the image and convert it to a float array
|
30 |
+
width, height = image.size
|
31 |
+
sw, sh = math.ceil(width * scale), math.ceil(height * scale)
|
32 |
+
img = image.resize((sw, sh), Image.BILINEAR)
|
33 |
+
img = np.asarray(img, 'float32')
|
34 |
+
|
35 |
+
img = torch.FloatTensor(_preprocess(img)).to(device)
|
36 |
+
with torch.no_grad():
|
37 |
+
output = net(img)
|
38 |
+
probs = output[1].cpu().data.numpy()[0, 1, :, :]
|
39 |
+
offsets = output[0].cpu().data.numpy()
|
40 |
+
# probs: probability of a face at each sliding window
|
41 |
+
# offsets: transformations to true bounding boxes
|
42 |
+
|
43 |
+
boxes = _generate_bboxes(probs, offsets, scale, threshold)
|
44 |
+
if len(boxes) == 0:
|
45 |
+
return None
|
46 |
+
|
47 |
+
keep = nms(boxes[:, 0:5], overlap_threshold=0.5)
|
48 |
+
return boxes[keep]
|
49 |
+
|
50 |
+
|
51 |
+
def _generate_bboxes(probs, offsets, scale, threshold):
|
52 |
+
"""Generate bounding boxes at places
|
53 |
+
where there is probably a face.
|
54 |
+
|
55 |
+
Arguments:
|
56 |
+
probs: a float numpy array of shape [n, m].
|
57 |
+
offsets: a float numpy array of shape [1, 4, n, m].
|
58 |
+
scale: a float number,
|
59 |
+
width and height of the image were scaled by this number.
|
60 |
+
threshold: a float number.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
a float numpy array of shape [n_boxes, 9]
|
64 |
+
"""
|
65 |
+
|
66 |
+
# applying P-Net is equivalent, in some sense, to
|
67 |
+
# moving 12x12 window with stride 2
|
68 |
+
stride = 2
|
69 |
+
cell_size = 12
|
70 |
+
|
71 |
+
# indices of boxes where there is probably a face
|
72 |
+
inds = np.where(probs > threshold)
|
73 |
+
|
74 |
+
if inds[0].size == 0:
|
75 |
+
return np.array([])
|
76 |
+
|
77 |
+
# transformations of bounding boxes
|
78 |
+
tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)]
|
79 |
+
# they are defined as:
|
80 |
+
# w = x2 - x1 + 1
|
81 |
+
# h = y2 - y1 + 1
|
82 |
+
# x1_true = x1 + tx1*w
|
83 |
+
# x2_true = x2 + tx2*w
|
84 |
+
# y1_true = y1 + ty1*h
|
85 |
+
# y2_true = y2 + ty2*h
|
86 |
+
|
87 |
+
offsets = np.array([tx1, ty1, tx2, ty2])
|
88 |
+
score = probs[inds[0], inds[1]]
|
89 |
+
|
90 |
+
# P-Net is applied to scaled images
|
91 |
+
# so we need to rescale bounding boxes back
|
92 |
+
bounding_boxes = np.vstack([
|
93 |
+
np.round((stride * inds[1] + 1.0) / scale),
|
94 |
+
np.round((stride * inds[0] + 1.0) / scale),
|
95 |
+
np.round((stride * inds[1] + 1.0 + cell_size) / scale),
|
96 |
+
np.round((stride * inds[0] + 1.0 + cell_size) / scale),
|
97 |
+
score, offsets
|
98 |
+
])
|
99 |
+
# why one is added?
|
100 |
+
|
101 |
+
return bounding_boxes.T
|
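To make the index-to-box mapping in `_generate_bboxes` above concrete: a confident output cell at (row i, col j) of the P-Net map corresponds to a 12x12 window taken at stride 2 in the scaled image, and dividing by `scale` maps it back to original-image coordinates. A tiny arithmetic sketch with made-up values:

    import numpy as np

    stride, cell_size, scale = 2, 12, 0.5
    i, j = 7, 4                                             # row/col of a confident output cell
    x1 = np.round((stride * j + 1.0) / scale)               # -> 18.0 in original coordinates
    y1 = np.round((stride * i + 1.0) / scale)               # -> 30.0
    x2 = np.round((stride * j + 1.0 + cell_size) / scale)   # -> 42.0
    y2 = np.round((stride * i + 1.0 + cell_size) / scale)   # -> 54.0
    print(x1, y1, x2, y2)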
models/mtcnn/mtcnn_pytorch/src/get_nets.py
ADDED
@@ -0,0 +1,171 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from collections import OrderedDict
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from configs.paths_config import model_paths
|
8 |
+
PNET_PATH = model_paths["mtcnn_pnet"]
|
9 |
+
ONET_PATH = model_paths["mtcnn_onet"]
|
10 |
+
RNET_PATH = model_paths["mtcnn_rnet"]
|
11 |
+
|
12 |
+
|
13 |
+
class Flatten(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self):
|
16 |
+
super(Flatten, self).__init__()
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
"""
|
20 |
+
Arguments:
|
21 |
+
x: a float tensor with shape [batch_size, c, h, w].
|
22 |
+
Returns:
|
23 |
+
a float tensor with shape [batch_size, c*h*w].
|
24 |
+
"""
|
25 |
+
|
26 |
+
# without this transpose, the pretrained weights don't work
|
27 |
+
x = x.transpose(3, 2).contiguous()
|
28 |
+
|
29 |
+
return x.view(x.size(0), -1)
|
30 |
+
|
31 |
+
|
32 |
+
class PNet(nn.Module):
|
33 |
+
|
34 |
+
def __init__(self):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
# suppose we have input with size HxW, then
|
38 |
+
# after first layer: H - 2,
|
39 |
+
# after pool: ceil((H - 2)/2),
|
40 |
+
# after second conv: ceil((H - 2)/2) - 2,
|
41 |
+
# after last conv: ceil((H - 2)/2) - 4,
|
42 |
+
# and the same for W
|
43 |
+
|
44 |
+
self.features = nn.Sequential(OrderedDict([
|
45 |
+
('conv1', nn.Conv2d(3, 10, 3, 1)),
|
46 |
+
('prelu1', nn.PReLU(10)),
|
47 |
+
('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),
|
48 |
+
|
49 |
+
('conv2', nn.Conv2d(10, 16, 3, 1)),
|
50 |
+
('prelu2', nn.PReLU(16)),
|
51 |
+
|
52 |
+
('conv3', nn.Conv2d(16, 32, 3, 1)),
|
53 |
+
('prelu3', nn.PReLU(32))
|
54 |
+
]))
|
55 |
+
|
56 |
+
self.conv4_1 = nn.Conv2d(32, 2, 1, 1)
|
57 |
+
self.conv4_2 = nn.Conv2d(32, 4, 1, 1)
|
58 |
+
|
59 |
+
weights = np.load(PNET_PATH, allow_pickle=True)[()]
|
60 |
+
for n, p in self.named_parameters():
|
61 |
+
p.data = torch.FloatTensor(weights[n])
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
"""
|
65 |
+
Arguments:
|
66 |
+
x: a float tensor with shape [batch_size, 3, h, w].
|
67 |
+
Returns:
|
68 |
+
b: a float tensor with shape [batch_size, 4, h', w'].
|
69 |
+
a: a float tensor with shape [batch_size, 2, h', w'].
|
70 |
+
"""
|
71 |
+
x = self.features(x)
|
72 |
+
a = self.conv4_1(x)
|
73 |
+
b = self.conv4_2(x)
|
74 |
+
a = F.softmax(a, dim=-1)
|
75 |
+
return b, a
|
76 |
+
|
77 |
+
|
78 |
+
class RNet(nn.Module):
|
79 |
+
|
80 |
+
def __init__(self):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
self.features = nn.Sequential(OrderedDict([
|
84 |
+
('conv1', nn.Conv2d(3, 28, 3, 1)),
|
85 |
+
('prelu1', nn.PReLU(28)),
|
86 |
+
('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
|
87 |
+
|
88 |
+
('conv2', nn.Conv2d(28, 48, 3, 1)),
|
89 |
+
('prelu2', nn.PReLU(48)),
|
90 |
+
('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
|
91 |
+
|
92 |
+
('conv3', nn.Conv2d(48, 64, 2, 1)),
|
93 |
+
('prelu3', nn.PReLU(64)),
|
94 |
+
|
95 |
+
('flatten', Flatten()),
|
96 |
+
('conv4', nn.Linear(576, 128)),
|
97 |
+
('prelu4', nn.PReLU(128))
|
98 |
+
]))
|
99 |
+
|
100 |
+
self.conv5_1 = nn.Linear(128, 2)
|
101 |
+
self.conv5_2 = nn.Linear(128, 4)
|
102 |
+
|
103 |
+
weights = np.load(RNET_PATH, allow_pickle=True)[()]
|
104 |
+
for n, p in self.named_parameters():
|
105 |
+
p.data = torch.FloatTensor(weights[n])
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
"""
|
109 |
+
Arguments:
|
110 |
+
x: a float tensor with shape [batch_size, 3, h, w].
|
111 |
+
Returns:
|
112 |
+
b: a float tensor with shape [batch_size, 4].
|
113 |
+
a: a float tensor with shape [batch_size, 2].
|
114 |
+
"""
|
115 |
+
x = self.features(x)
|
116 |
+
a = self.conv5_1(x)
|
117 |
+
b = self.conv5_2(x)
|
118 |
+
a = F.softmax(a, dim=-1)
|
119 |
+
return b, a
|
120 |
+
|
121 |
+
|
122 |
+
class ONet(nn.Module):
|
123 |
+
|
124 |
+
def __init__(self):
|
125 |
+
super().__init__()
|
126 |
+
|
127 |
+
self.features = nn.Sequential(OrderedDict([
|
128 |
+
('conv1', nn.Conv2d(3, 32, 3, 1)),
|
129 |
+
('prelu1', nn.PReLU(32)),
|
130 |
+
('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
|
131 |
+
|
132 |
+
('conv2', nn.Conv2d(32, 64, 3, 1)),
|
133 |
+
('prelu2', nn.PReLU(64)),
|
134 |
+
('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
|
135 |
+
|
136 |
+
('conv3', nn.Conv2d(64, 64, 3, 1)),
|
137 |
+
('prelu3', nn.PReLU(64)),
|
138 |
+
('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),
|
139 |
+
|
140 |
+
('conv4', nn.Conv2d(64, 128, 2, 1)),
|
141 |
+
('prelu4', nn.PReLU(128)),
|
142 |
+
|
143 |
+
('flatten', Flatten()),
|
144 |
+
('conv5', nn.Linear(1152, 256)),
|
145 |
+
('drop5', nn.Dropout(0.25)),
|
146 |
+
('prelu5', nn.PReLU(256)),
|
147 |
+
]))
|
148 |
+
|
149 |
+
self.conv6_1 = nn.Linear(256, 2)
|
150 |
+
self.conv6_2 = nn.Linear(256, 4)
|
151 |
+
self.conv6_3 = nn.Linear(256, 10)
|
152 |
+
|
153 |
+
weights = np.load(ONET_PATH, allow_pickle=True)[()]
|
154 |
+
for n, p in self.named_parameters():
|
155 |
+
p.data = torch.FloatTensor(weights[n])
|
156 |
+
|
157 |
+
def forward(self, x):
|
158 |
+
"""
|
159 |
+
Arguments:
|
160 |
+
x: a float tensor with shape [batch_size, 3, h, w].
|
161 |
+
Returns:
|
162 |
+
c: a float tensor with shape [batch_size, 10].
|
163 |
+
b: a float tensor with shape [batch_size, 4].
|
164 |
+
a: a float tensor with shape [batch_size, 2].
|
165 |
+
"""
|
166 |
+
x = self.features(x)
|
167 |
+
a = self.conv6_1(x)
|
168 |
+
b = self.conv6_2(x)
|
169 |
+
c = self.conv6_3(x)
|
170 |
+
a = F.softmax(a, dim=-1)
|
171 |
+
return c, b, a
|
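A quick shape check for the three networks above. It assumes the .npy weight files referenced through `configs.paths_config.model_paths` are available (they are added elsewhere in this commit under src/weights/):

    import torch

    pnet, rnet, onet = PNet(), RNet(), ONet()
    with torch.no_grad():
        b, a = pnet(torch.zeros(1, 3, 12, 12))     # offsets [1,4,1,1], probs [1,2,1,1]
        print(b.shape, a.shape)
        b, a = rnet(torch.zeros(1, 3, 24, 24))     # offsets [1,4], probs [1,2]
        print(b.shape, a.shape)
        c, b, a = onet(torch.zeros(1, 3, 48, 48))  # landmarks [1,10], offsets [1,4], probs [1,2]
        print(c.shape, b.shape, a.shape)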
models/mtcnn/mtcnn_pytorch/src/matlab_cp2tform.py
ADDED
@@ -0,0 +1,350 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Tue Jul 11 06:54:28 2017
|
4 |
+
|
5 |
+
@author: zhaoyafei
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from numpy.linalg import inv, norm, lstsq
|
10 |
+
from numpy.linalg import matrix_rank as rank
|
11 |
+
|
12 |
+
|
13 |
+
class MatlabCp2tormException(Exception):
|
14 |
+
def __str__(self):
|
15 |
+
return 'In File {}:{}'.format(
|
16 |
+
__file__, super().__str__())
|
17 |
+
|
18 |
+
|
19 |
+
def tformfwd(trans, uv):
|
20 |
+
"""
|
21 |
+
Function:
|
22 |
+
----------
|
23 |
+
apply affine transform 'trans' to uv
|
24 |
+
|
25 |
+
Parameters:
|
26 |
+
----------
|
27 |
+
@trans: 3x3 np.array
|
28 |
+
transform matrix
|
29 |
+
@uv: Kx2 np.array
|
30 |
+
each row is a pair of coordinates (x, y)
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
----------
|
34 |
+
@xy: Kx2 np.array
|
35 |
+
each row is a pair of transformed coordinates (x, y)
|
36 |
+
"""
|
37 |
+
uv = np.hstack((
|
38 |
+
uv, np.ones((uv.shape[0], 1))
|
39 |
+
))
|
40 |
+
xy = np.dot(uv, trans)
|
41 |
+
xy = xy[:, 0:-1]
|
42 |
+
return xy
|
43 |
+
|
44 |
+
|
45 |
+
def tforminv(trans, uv):
|
46 |
+
"""
|
47 |
+
Function:
|
48 |
+
----------
|
49 |
+
apply the inverse of affine transform 'trans' to uv
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
----------
|
53 |
+
@trans: 3x3 np.array
|
54 |
+
transform matrix
|
55 |
+
@uv: Kx2 np.array
|
56 |
+
each row is a pair of coordinates (x, y)
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
----------
|
60 |
+
@xy: Kx2 np.array
|
61 |
+
each row is a pair of inverse-transformed coordinates (x, y)
|
62 |
+
"""
|
63 |
+
Tinv = inv(trans)
|
64 |
+
xy = tformfwd(Tinv, uv)
|
65 |
+
return xy
|
66 |
+
|
67 |
+
|
68 |
+
def findNonreflectiveSimilarity(uv, xy, options=None):
|
69 |
+
options = {'K': 2}
|
70 |
+
|
71 |
+
K = options['K']
|
72 |
+
M = xy.shape[0]
|
73 |
+
x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
|
74 |
+
y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
|
75 |
+
# print('--->x, y:\n', x, y
|
76 |
+
|
77 |
+
tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
|
78 |
+
tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
|
79 |
+
X = np.vstack((tmp1, tmp2))
|
80 |
+
# print('--->X.shape: ', X.shape
|
81 |
+
# print('X:\n', X
|
82 |
+
|
83 |
+
u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
|
84 |
+
v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
|
85 |
+
U = np.vstack((u, v))
|
86 |
+
# print('--->U.shape: ', U.shape
|
87 |
+
# print('U:\n', U
|
88 |
+
|
89 |
+
# We know that X * r = U
|
90 |
+
if rank(X) >= 2 * K:
|
91 |
+
r, _, _, _ = lstsq(X, U, rcond=None) # Make sure this is what I want
|
92 |
+
r = np.squeeze(r)
|
93 |
+
else:
|
94 |
+
raise Exception('cp2tform:twoUniquePointsReq')
|
95 |
+
|
96 |
+
# print('--->r:\n', r
|
97 |
+
|
98 |
+
sc = r[0]
|
99 |
+
ss = r[1]
|
100 |
+
tx = r[2]
|
101 |
+
ty = r[3]
|
102 |
+
|
103 |
+
Tinv = np.array([
|
104 |
+
[sc, -ss, 0],
|
105 |
+
[ss, sc, 0],
|
106 |
+
[tx, ty, 1]
|
107 |
+
])
|
108 |
+
|
109 |
+
# print('--->Tinv:\n', Tinv
|
110 |
+
|
111 |
+
T = inv(Tinv)
|
112 |
+
# print('--->T:\n', T
|
113 |
+
|
114 |
+
T[:, 2] = np.array([0, 0, 1])
|
115 |
+
|
116 |
+
return T, Tinv
|
117 |
+
|
118 |
+
|
119 |
+
def findSimilarity(uv, xy, options=None):
|
120 |
+
options = {'K': 2}
|
121 |
+
|
122 |
+
# uv = np.array(uv)
|
123 |
+
# xy = np.array(xy)
|
124 |
+
|
125 |
+
# Solve for trans1
|
126 |
+
trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
|
127 |
+
|
128 |
+
# Solve for trans2
|
129 |
+
|
130 |
+
# manually reflect the xy data across the Y-axis
|
131 |
+
xyR = xy
|
132 |
+
xyR[:, 0] = -1 * xyR[:, 0]
|
133 |
+
|
134 |
+
trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
|
135 |
+
|
136 |
+
# manually reflect the tform to undo the reflection done on xyR
|
137 |
+
TreflectY = np.array([
|
138 |
+
[-1, 0, 0],
|
139 |
+
[0, 1, 0],
|
140 |
+
[0, 0, 1]
|
141 |
+
])
|
142 |
+
|
143 |
+
trans2 = np.dot(trans2r, TreflectY)
|
144 |
+
|
145 |
+
# Figure out if trans1 or trans2 is better
|
146 |
+
xy1 = tformfwd(trans1, uv)
|
147 |
+
norm1 = norm(xy1 - xy)
|
148 |
+
|
149 |
+
xy2 = tformfwd(trans2, uv)
|
150 |
+
norm2 = norm(xy2 - xy)
|
151 |
+
|
152 |
+
if norm1 <= norm2:
|
153 |
+
return trans1, trans1_inv
|
154 |
+
else:
|
155 |
+
trans2_inv = inv(trans2)
|
156 |
+
return trans2, trans2_inv
|
157 |
+
|
158 |
+
|
159 |
+
def get_similarity_transform(src_pts, dst_pts, reflective=True):
|
160 |
+
"""
|
161 |
+
Function:
|
162 |
+
----------
|
163 |
+
Find Similarity Transform Matrix 'trans':
|
164 |
+
u = src_pts[:, 0]
|
165 |
+
v = src_pts[:, 1]
|
166 |
+
x = dst_pts[:, 0]
|
167 |
+
y = dst_pts[:, 1]
|
168 |
+
[x, y, 1] = [u, v, 1] * trans
|
169 |
+
|
170 |
+
Parameters:
|
171 |
+
----------
|
172 |
+
@src_pts: Kx2 np.array
|
173 |
+
source points, each row is a pair of coordinates (x, y)
|
174 |
+
@dst_pts: Kx2 np.array
|
175 |
+
destination points, each row is a pair of transformed
|
176 |
+
coordinates (x, y)
|
177 |
+
@reflective: True or False
|
178 |
+
if True:
|
179 |
+
use reflective similarity transform
|
180 |
+
else:
|
181 |
+
use non-reflective similarity transform
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
----------
|
185 |
+
@trans: 3x3 np.array
|
186 |
+
transform matrix from uv to xy
|
187 |
+
trans_inv: 3x3 np.array
|
188 |
+
inverse of trans, transform matrix from xy to uv
|
189 |
+
"""
|
190 |
+
|
191 |
+
if reflective:
|
192 |
+
trans, trans_inv = findSimilarity(src_pts, dst_pts)
|
193 |
+
else:
|
194 |
+
trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
|
195 |
+
|
196 |
+
return trans, trans_inv
|
197 |
+
|
198 |
+
|
199 |
+
def cvt_tform_mat_for_cv2(trans):
|
200 |
+
"""
|
201 |
+
Function:
|
202 |
+
----------
|
203 |
+
Convert Transform Matrix 'trans' into 'cv2_trans' which could be
|
204 |
+
directly used by cv2.warpAffine():
|
205 |
+
u = src_pts[:, 0]
|
206 |
+
v = src_pts[:, 1]
|
207 |
+
x = dst_pts[:, 0]
|
208 |
+
y = dst_pts[:, 1]
|
209 |
+
[x, y].T = cv_trans * [u, v, 1].T
|
210 |
+
|
211 |
+
Parameters:
|
212 |
+
----------
|
213 |
+
@trans: 3x3 np.array
|
214 |
+
transform matrix from uv to xy
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
----------
|
218 |
+
@cv2_trans: 2x3 np.array
|
219 |
+
transform matrix from src_pts to dst_pts, could be directly used
|
220 |
+
for cv2.warpAffine()
|
221 |
+
"""
|
222 |
+
cv2_trans = trans[:, 0:2].T
|
223 |
+
|
224 |
+
return cv2_trans
|
225 |
+
|
226 |
+
|
227 |
+
def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
|
228 |
+
"""
|
229 |
+
Function:
|
230 |
+
----------
|
231 |
+
Find Similarity Transform Matrix 'cv2_trans' which could be
|
232 |
+
directly used by cv2.warpAffine():
|
233 |
+
u = src_pts[:, 0]
|
234 |
+
v = src_pts[:, 1]
|
235 |
+
x = dst_pts[:, 0]
|
236 |
+
y = dst_pts[:, 1]
|
237 |
+
[x, y].T = cv_trans * [u, v, 1].T
|
238 |
+
|
239 |
+
Parameters:
|
240 |
+
----------
|
241 |
+
@src_pts: Kx2 np.array
|
242 |
+
source points, each row is a pair of coordinates (x, y)
|
243 |
+
@dst_pts: Kx2 np.array
|
244 |
+
destination points, each row is a pair of transformed
|
245 |
+
coordinates (x, y)
|
246 |
+
reflective: True or False
|
247 |
+
if True:
|
248 |
+
use reflective similarity transform
|
249 |
+
else:
|
250 |
+
use non-reflective similarity transform
|
251 |
+
|
252 |
+
Returns:
|
253 |
+
----------
|
254 |
+
@cv2_trans: 2x3 np.array
|
255 |
+
transform matrix from src_pts to dst_pts, could be directly used
|
256 |
+
for cv2.warpAffine()
|
257 |
+
"""
|
258 |
+
trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
|
259 |
+
cv2_trans = cvt_tform_mat_for_cv2(trans)
|
260 |
+
|
261 |
+
return cv2_trans
|
262 |
+
|
263 |
+
|
264 |
+
if __name__ == '__main__':
|
265 |
+
"""
|
266 |
+
u = [0, 6, -2]
|
267 |
+
v = [0, 3, 5]
|
268 |
+
x = [-1, 0, 4]
|
269 |
+
y = [-1, -10, 4]
|
270 |
+
|
271 |
+
# In Matlab, run:
|
272 |
+
#
|
273 |
+
# uv = [u'; v'];
|
274 |
+
# xy = [x'; y'];
|
275 |
+
# tform_sim=cp2tform(uv,xy,'similarity');
|
276 |
+
#
|
277 |
+
# trans = tform_sim.tdata.T
|
278 |
+
# ans =
|
279 |
+
# -0.0764 -1.6190 0
|
280 |
+
# 1.6190 -0.0764 0
|
281 |
+
# -3.2156 0.0290 1.0000
|
282 |
+
# trans_inv = tform_sim.tdata.Tinv
|
283 |
+
# ans =
|
284 |
+
#
|
285 |
+
# -0.0291 0.6163 0
|
286 |
+
# -0.6163 -0.0291 0
|
287 |
+
# -0.0756 1.9826 1.0000
|
288 |
+
# xy_m=tformfwd(tform_sim, u,v)
|
289 |
+
#
|
290 |
+
# xy_m =
|
291 |
+
#
|
292 |
+
# -3.2156 0.0290
|
293 |
+
# 1.1833 -9.9143
|
294 |
+
# 5.0323 2.8853
|
295 |
+
# uv_m=tforminv(tform_sim, x,y)
|
296 |
+
#
|
297 |
+
# uv_m =
|
298 |
+
#
|
299 |
+
# 0.5698 1.3953
|
300 |
+
# 6.0872 2.2733
|
301 |
+
# -2.6570 4.3314
|
302 |
+
"""
|
303 |
+
u = [0, 6, -2]
|
304 |
+
v = [0, 3, 5]
|
305 |
+
x = [-1, 0, 4]
|
306 |
+
y = [-1, -10, 4]
|
307 |
+
|
308 |
+
uv = np.array((u, v)).T
|
309 |
+
xy = np.array((x, y)).T
|
310 |
+
|
311 |
+
print('\n--->uv:')
|
312 |
+
print(uv)
|
313 |
+
print('\n--->xy:')
|
314 |
+
print(xy)
|
315 |
+
|
316 |
+
trans, trans_inv = get_similarity_transform(uv, xy)
|
317 |
+
|
318 |
+
print('\n--->trans matrix:')
|
319 |
+
print(trans)
|
320 |
+
|
321 |
+
print('\n--->trans_inv matrix:')
|
322 |
+
print(trans_inv)
|
323 |
+
|
324 |
+
print('\n---> apply transform to uv')
|
325 |
+
print('\nxy_m = uv_augmented * trans')
|
326 |
+
uv_aug = np.hstack((
|
327 |
+
uv, np.ones((uv.shape[0], 1))
|
328 |
+
))
|
329 |
+
xy_m = np.dot(uv_aug, trans)
|
330 |
+
print(xy_m)
|
331 |
+
|
332 |
+
print('\nxy_m = tformfwd(trans, uv)')
|
333 |
+
xy_m = tformfwd(trans, uv)
|
334 |
+
print(xy_m)
|
335 |
+
|
336 |
+
print('\n---> apply inverse transform to xy')
|
337 |
+
print('\nuv_m = xy_augmented * trans_inv')
|
338 |
+
xy_aug = np.hstack((
|
339 |
+
xy, np.ones((xy.shape[0], 1))
|
340 |
+
))
|
341 |
+
uv_m = np.dot(xy_aug, trans_inv)
|
342 |
+
print(uv_m)
|
343 |
+
|
344 |
+
print('\nuv_m = tformfwd(trans_inv, xy)')
|
345 |
+
uv_m = tformfwd(trans_inv, xy)
|
346 |
+
print(uv_m)
|
347 |
+
|
348 |
+
uv_m = tforminv(trans, xy)
|
349 |
+
print('\nuv_m = tforminv(trans, xy)')
|
350 |
+
print(uv_m)
|
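The helpers above are designed to feed cv2.warpAffine directly; a minimal sketch with made-up source and destination points:

    import cv2
    import numpy as np

    src = np.float32([[30.3, 51.7], [65.5, 51.5], [48.0, 71.7],
                      [33.5, 92.4], [62.7, 92.2]])     # e.g. detected 5-point landmarks
    dst = src + np.float32([5.0, -3.0])                # pretend the target template is shifted
    tfm = get_similarity_transform_for_cv2(src, dst)   # 2x3 matrix
    img = np.zeros((112, 96, 3), dtype=np.uint8)       # dummy image
    warped = cv2.warpAffine(img, tfm, (96, 112))
    print(tfm.shape, warped.shape)                     # (2, 3) (112, 96, 3)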
models/mtcnn/mtcnn_pytorch/src/visualization_utils.py
ADDED
@@ -0,0 +1,31 @@
1 |
+
from PIL import ImageDraw
|
2 |
+
|
3 |
+
|
4 |
+
def show_bboxes(img, bounding_boxes, facial_landmarks=[]):
|
5 |
+
"""Draw bounding boxes and facial landmarks.
|
6 |
+
|
7 |
+
Arguments:
|
8 |
+
img: an instance of PIL.Image.
|
9 |
+
bounding_boxes: a float numpy array of shape [n, 5].
|
10 |
+
facial_landmarks: a float numpy array of shape [n, 10].
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
an instance of PIL.Image.
|
14 |
+
"""
|
15 |
+
|
16 |
+
img_copy = img.copy()
|
17 |
+
draw = ImageDraw.Draw(img_copy)
|
18 |
+
|
19 |
+
for b in bounding_boxes:
|
20 |
+
draw.rectangle([
|
21 |
+
(b[0], b[1]), (b[2], b[3])
|
22 |
+
], outline='white')
|
23 |
+
|
24 |
+
for p in facial_landmarks:
|
25 |
+
for i in range(5):
|
26 |
+
draw.ellipse([
|
27 |
+
(p[i] - 1.0, p[i + 5] - 1.0),
|
28 |
+
(p[i] + 1.0, p[i + 5] + 1.0)
|
29 |
+
], outline='blue')
|
30 |
+
|
31 |
+
return img_copy
|
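A self-contained sketch of `show_bboxes` with hand-made inputs (a dummy gray canvas and one fabricated detection), just to illustrate the expected array layout:

    import numpy as np
    from PIL import Image

    img = Image.new('RGB', (320, 240), 'gray')
    boxes = np.array([[60, 50, 160, 170, 0.99]], dtype=np.float32)
    landmarks = np.array([[90, 130, 110, 95, 125,                       # x1..x5
                           95, 95, 115, 145, 145]], dtype=np.float32)   # y1..y5
    out = show_bboxes(img, boxes, landmarks)   # white box, blue landmark dots
    out.save('bboxes_demo.png')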
models/mtcnn/mtcnn_pytorch/src/weights/onet.npy
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:313141c3646bebb73cb8350a2d5fee4c7f044fb96304b46ccc21aeea8b818f83
|
3 |
+
size 2345483
|
models/mtcnn/mtcnn_pytorch/src/weights/pnet.npy
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:03e19e5c473932ab38f5a6308fe6210624006994a687e858d1dcda53c66f18cb
|
3 |
+
size 41271
|
models/mtcnn/mtcnn_pytorch/src/weights/rnet.npy
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5660aad67688edc9e8a3dd4e47ed120932835e06a8a711a423252a6f2c747083
|
3 |
+
size 604651
|
models/psp.py
ADDED
@@ -0,0 +1,147 @@
1 |
+
"""
|
2 |
+
This file defines the core research contribution
|
3 |
+
"""
|
4 |
+
import matplotlib
|
5 |
+
matplotlib.use('Agg')
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
from models.encoders import psp_encoders
|
11 |
+
from models.stylegan2.model import Generator
|
12 |
+
from configs.paths_config import model_paths
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
def get_keys(d, name):
|
16 |
+
if 'state_dict' in d:
|
17 |
+
d = d['state_dict']
|
18 |
+
d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
|
19 |
+
return d_filt
|
20 |
+
|
21 |
+
|
22 |
+
class pSp(nn.Module):
|
23 |
+
|
24 |
+
def __init__(self, opts):
|
25 |
+
super(pSp, self).__init__()
|
26 |
+
self.set_opts(opts)
|
27 |
+
# compute number of style inputs based on the output resolution
|
28 |
+
self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
|
29 |
+
# Define architecture
|
30 |
+
self.encoder = self.set_encoder()
|
31 |
+
self.decoder = Generator(self.opts.output_size, 512, 8)
|
32 |
+
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
33 |
+
# Load weights if needed
|
34 |
+
self.load_weights()
|
35 |
+
|
36 |
+
def set_encoder(self):
|
37 |
+
if self.opts.encoder_type == 'GradualStyleEncoder':
|
38 |
+
encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
|
39 |
+
elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
|
40 |
+
encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
|
41 |
+
elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
|
42 |
+
encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
|
43 |
+
else:
|
44 |
+
raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
|
45 |
+
return encoder
|
46 |
+
|
47 |
+
def load_weights(self):
|
48 |
+
if self.opts.checkpoint_path is not None:
|
49 |
+
print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
|
50 |
+
ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
|
51 |
+
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=False)
|
52 |
+
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=False)
|
53 |
+
self.__load_latent_avg(ckpt)
|
54 |
+
else:
|
55 |
+
print('Loading encoders weights from irse50!')
|
56 |
+
encoder_ckpt = torch.load(model_paths['ir_se50'])
|
57 |
+
# if input to encoder is not an RGB image, do not load the input layer weights
|
58 |
+
if self.opts.label_nc != 0:
|
59 |
+
encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
|
60 |
+
self.encoder.load_state_dict(encoder_ckpt, strict=False)
|
61 |
+
print('Loading decoder weights from pretrained!')
|
62 |
+
ckpt = torch.load(self.opts.stylegan_weights)
|
63 |
+
self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
|
64 |
+
if self.opts.learn_in_w:
|
65 |
+
self.__load_latent_avg(ckpt, repeat=1)
|
66 |
+
else:
|
67 |
+
self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)
|
68 |
+
# for video toonification, we load G0' model
|
69 |
+
if self.opts.toonify_weights is not None: ##### modified
|
70 |
+
ckpt = torch.load(self.opts.toonify_weights)
|
71 |
+
self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
|
72 |
+
self.opts.toonify_weights = None
|
73 |
+
|
74 |
+
# x1: image for first-layer feature f.
|
75 |
+
# x2: image for style latent code w+. If not specified, x2=x1.
|
76 |
+
# inject_latent: for sketch/mask-to-face translation, another latent code to fuse with w+
|
77 |
+
# latent_mask: fuse w+ and inject_latent with the mask (1~7 use w+ and 8~18 use inject_latent)
|
78 |
+
# use_feature: use f. Otherwise, use the original StyleGAN first-layer constant 4x4 feature
|
79 |
+
# first_layer_feature_ind: always 0, meaning the 1st layer of G accepts f
|
80 |
+
# use_skip: use skip connection.
|
81 |
+
# zero_noise: use zero noises.
|
82 |
+
# editing_w: the editing vector v for video face editing
|
83 |
+
def forward(self, x1, x2=None, resize=True, latent_mask=None, randomize_noise=True,
|
84 |
+
inject_latent=None, return_latents=False, alpha=None, use_feature=True,
|
85 |
+
first_layer_feature_ind=0, use_skip=False, zero_noise=False, editing_w=None): ##### modified
|
86 |
+
|
87 |
+
feats = None # f and the skipped encoder features
|
88 |
+
codes, feats = self.encoder(x1, return_feat=True, return_full=use_skip) ##### modified
|
89 |
+
if x2 is not None: ##### modified
|
90 |
+
codes = self.encoder(x2) ##### modified
|
91 |
+
# normalize with respect to the center of an average face
|
92 |
+
if self.opts.start_from_latent_avg:
|
93 |
+
if self.opts.learn_in_w:
|
94 |
+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
|
95 |
+
else:
|
96 |
+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
|
97 |
+
|
98 |
+
# E_W^{1:7}(T(x1)) concatenate E_W^{8:18}(w~)
|
99 |
+
if latent_mask is not None:
|
100 |
+
for i in latent_mask:
|
101 |
+
if inject_latent is not None:
|
102 |
+
if alpha is not None:
|
103 |
+
codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
|
104 |
+
else:
|
105 |
+
codes[:, i] = inject_latent[:, i]
|
106 |
+
else:
|
107 |
+
codes[:, i] = 0
|
108 |
+
|
109 |
+
first_layer_feats, skip_layer_feats, fusion = None, None, None ##### modified
|
110 |
+
if use_feature: ##### modified
|
111 |
+
first_layer_feats = feats[0:2] # use f
|
112 |
+
if use_skip: ##### modified
|
113 |
+
skip_layer_feats = feats[2:] # use skipped encoder feature
|
114 |
+
fusion = self.encoder.fusion # use fusion layer to fuse encoder feature and decoder feature.
|
115 |
+
|
116 |
+
images, result_latent = self.decoder([codes],
|
117 |
+
input_is_latent=True,
|
118 |
+
randomize_noise=randomize_noise,
|
119 |
+
return_latents=return_latents,
|
120 |
+
first_layer_feature=first_layer_feats,
|
121 |
+
first_layer_feature_ind=first_layer_feature_ind,
|
122 |
+
skip_layer_feature=skip_layer_feats,
|
123 |
+
fusion_block=fusion,
|
124 |
+
zero_noise=zero_noise,
|
125 |
+
editing_w=editing_w) ##### modified
|
126 |
+
|
127 |
+
if resize:
|
128 |
+
if self.opts.output_size == 1024: ##### modified
|
129 |
+
images = F.adaptive_avg_pool2d(images, (images.shape[2]//4, images.shape[3]//4)) ##### modified
|
130 |
+
else:
|
131 |
+
images = self.face_pool(images)
|
132 |
+
|
133 |
+
if return_latents:
|
134 |
+
return images, result_latent
|
135 |
+
else:
|
136 |
+
return images
|
137 |
+
|
138 |
+
def set_opts(self, opts):
|
139 |
+
self.opts = opts
|
140 |
+
|
141 |
+
def __load_latent_avg(self, ckpt, repeat=None):
|
142 |
+
if 'latent_avg' in ckpt:
|
143 |
+
self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
|
144 |
+
if repeat is not None:
|
145 |
+
self.latent_avg = self.latent_avg.repeat(repeat, 1)
|
146 |
+
else:
|
147 |
+
self.latent_avg = None
|
models/stylegan2/__init__.py
ADDED
File without changes
|
models/stylegan2/lpips/__init__.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
from __future__ import division
|
4 |
+
from __future__ import print_function
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
#from skimage.measure import compare_ssim
|
8 |
+
from skimage.metrics import structural_similarity as compare_ssim
|
9 |
+
import torch
|
10 |
+
from torch.autograd import Variable
|
11 |
+
|
12 |
+
from models.stylegan2.lpips import dist_model
|
13 |
+
|
14 |
+
class PerceptualLoss(torch.nn.Module):
|
15 |
+
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
|
16 |
+
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
|
17 |
+
super(PerceptualLoss, self).__init__()
|
18 |
+
print('Setting up Perceptual loss...')
|
19 |
+
self.use_gpu = use_gpu
|
20 |
+
self.spatial = spatial
|
21 |
+
self.gpu_ids = gpu_ids
|
22 |
+
self.model = dist_model.DistModel()
|
23 |
+
self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
|
24 |
+
print('...[%s] initialized'%self.model.name())
|
25 |
+
print('...Done')
|
26 |
+
|
27 |
+
def forward(self, pred, target, normalize=False):
|
28 |
+
"""
|
29 |
+
Pred and target are Variables.
|
30 |
+
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
|
31 |
+
If normalize is False, assumes the images are already between [-1,+1]
|
32 |
+
|
33 |
+
Inputs pred and target are Nx3xHxW
|
34 |
+
Output pytorch Variable N long
|
35 |
+
"""
|
36 |
+
|
37 |
+
if normalize:
|
38 |
+
target = 2 * target - 1
|
39 |
+
pred = 2 * pred - 1
|
40 |
+
|
41 |
+
return self.model.forward(target, pred)
|
42 |
+
|
43 |
+
def normalize_tensor(in_feat,eps=1e-10):
|
44 |
+
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
|
45 |
+
return in_feat/(norm_factor+eps)
|
46 |
+
|
47 |
+
def l2(p0, p1, range=255.):
|
48 |
+
return .5*np.mean((p0 / range - p1 / range)**2)
|
49 |
+
|
50 |
+
def psnr(p0, p1, peak=255.):
|
51 |
+
return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
|
52 |
+
|
53 |
+
def dssim(p0, p1, range=255.):
|
54 |
+
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
|
55 |
+
|
56 |
+
def rgb2lab(in_img,mean_cent=False):
|
57 |
+
from skimage import color
|
58 |
+
img_lab = color.rgb2lab(in_img)
|
59 |
+
if(mean_cent):
|
60 |
+
img_lab[:,:,0] = img_lab[:,:,0]-50
|
61 |
+
return img_lab
|
62 |
+
|
63 |
+
def tensor2np(tensor_obj):
|
64 |
+
# change dimension of a tensor object into a numpy array
|
65 |
+
return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
|
66 |
+
|
67 |
+
def np2tensor(np_obj):
|
68 |
+
# change dimenion of np array into tensor array
|
69 |
+
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
70 |
+
|
71 |
+
def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
|
72 |
+
# image tensor to lab tensor
|
73 |
+
from skimage import color
|
74 |
+
|
75 |
+
img = tensor2im(image_tensor)
|
76 |
+
img_lab = color.rgb2lab(img)
|
77 |
+
if(mc_only):
|
78 |
+
img_lab[:,:,0] = img_lab[:,:,0]-50
|
79 |
+
if(to_norm and not mc_only):
|
80 |
+
img_lab[:,:,0] = img_lab[:,:,0]-50
|
81 |
+
img_lab = img_lab/100.
|
82 |
+
|
83 |
+
return np2tensor(img_lab)
|
84 |
+
|
85 |
+
def tensorlab2tensor(lab_tensor,return_inbnd=False):
|
86 |
+
from skimage import color
|
87 |
+
import warnings
|
88 |
+
warnings.filterwarnings("ignore")
|
89 |
+
|
90 |
+
lab = tensor2np(lab_tensor)*100.
|
91 |
+
lab[:,:,0] = lab[:,:,0]+50
|
92 |
+
|
93 |
+
rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
|
94 |
+
if(return_inbnd):
|
95 |
+
# convert back to lab, see if we match
|
96 |
+
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
|
97 |
+
mask = 1.*np.isclose(lab_back,lab,atol=2.)
|
98 |
+
mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
|
99 |
+
return (im2tensor(rgb_back),mask)
|
100 |
+
else:
|
101 |
+
return im2tensor(rgb_back)
|
102 |
+
|
103 |
+
def rgb2lab(input):
|
104 |
+
from skimage import color
|
105 |
+
return color.rgb2lab(input / 255.)
|
106 |
+
|
107 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
|
108 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
109 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
110 |
+
return image_numpy.astype(imtype)
|
111 |
+
|
112 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
|
113 |
+
return torch.Tensor((image / factor - cent)
|
114 |
+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
115 |
+
|
116 |
+
def tensor2vec(vector_tensor):
|
117 |
+
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
|
118 |
+
|
119 |
+
def voc_ap(rec, prec, use_07_metric=False):
|
120 |
+
""" ap = voc_ap(rec, prec, [use_07_metric])
|
121 |
+
Compute VOC AP given precision and recall.
|
122 |
+
If use_07_metric is true, uses the
|
123 |
+
VOC 07 11 point method (default:False).
|
124 |
+
"""
|
125 |
+
if use_07_metric:
|
126 |
+
# 11 point metric
|
127 |
+
ap = 0.
|
128 |
+
for t in np.arange(0., 1.1, 0.1):
|
129 |
+
if np.sum(rec >= t) == 0:
|
130 |
+
p = 0
|
131 |
+
else:
|
132 |
+
p = np.max(prec[rec >= t])
|
133 |
+
ap = ap + p / 11.
|
134 |
+
else:
|
135 |
+
# correct AP calculation
|
136 |
+
# first append sentinel values at the end
|
137 |
+
mrec = np.concatenate(([0.], rec, [1.]))
|
138 |
+
mpre = np.concatenate(([0.], prec, [0.]))
|
139 |
+
|
140 |
+
# compute the precision envelope
|
141 |
+
for i in range(mpre.size - 1, 0, -1):
|
142 |
+
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
|
143 |
+
|
144 |
+
# to calculate area under PR curve, look for points
|
145 |
+
# where X axis (recall) changes value
|
146 |
+
i = np.where(mrec[1:] != mrec[:-1])[0]
|
147 |
+
|
148 |
+
# and sum (\Delta recall) * prec
|
149 |
+
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
|
150 |
+
return ap
|
151 |
+
|
152 |
+
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
|
153 |
+
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
|
154 |
+
image_numpy = image_tensor[0].cpu().float().numpy()
|
155 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
156 |
+
return image_numpy.astype(imtype)
|
157 |
+
|
158 |
+
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
|
159 |
+
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
|
160 |
+
return torch.Tensor((image / factor - cent)
|
161 |
+
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
models/stylegan2/lpips/base_model.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from torch.autograd import Variable
|
5 |
+
from pdb import set_trace as st
|
6 |
+
from IPython import embed
|
7 |
+
|
8 |
+
class BaseModel():
|
9 |
+
def __init__(self):
|
10 |
+
pass;
|
11 |
+
|
12 |
+
def name(self):
|
13 |
+
return 'BaseModel'
|
14 |
+
|
15 |
+
def initialize(self, use_gpu=True, gpu_ids=[0]):
|
16 |
+
self.use_gpu = use_gpu
|
17 |
+
self.gpu_ids = gpu_ids
|
18 |
+
|
19 |
+
def forward(self):
|
20 |
+
pass
|
21 |
+
|
22 |
+
def get_image_paths(self):
|
23 |
+
pass
|
24 |
+
|
25 |
+
def optimize_parameters(self):
|
26 |
+
pass
|
27 |
+
|
28 |
+
def get_current_visuals(self):
|
29 |
+
return self.input
|
30 |
+
|
31 |
+
def get_current_errors(self):
|
32 |
+
return {}
|
33 |
+
|
34 |
+
def save(self, label):
|
35 |
+
pass
|
36 |
+
|
37 |
+
# helper saving function that can be used by subclasses
|
38 |
+
def save_network(self, network, path, network_label, epoch_label):
|
39 |
+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
40 |
+
save_path = os.path.join(path, save_filename)
|
41 |
+
torch.save(network.state_dict(), save_path)
|
42 |
+
|
43 |
+
# helper loading function that can be used by subclasses
|
44 |
+
def load_network(self, network, network_label, epoch_label):
|
45 |
+
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
46 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
47 |
+
print('Loading network from %s'%save_path)
|
48 |
+
network.load_state_dict(torch.load(save_path))
|
49 |
+
|
50 |
+
def update_learning_rate():
|
51 |
+
pass
|
52 |
+
|
53 |
+
def get_image_paths(self):
|
54 |
+
return self.image_paths
|
55 |
+
|
56 |
+
def save_done(self, flag=False):
|
57 |
+
np.save(os.path.join(self.save_dir, 'done_flag'),flag)
|
58 |
+
np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
|
models/stylegan2/lpips/dist_model.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
import os
|
9 |
+
from collections import OrderedDict
|
10 |
+
from torch.autograd import Variable
|
11 |
+
import itertools
|
12 |
+
from models.stylegan2.lpips.base_model import BaseModel
|
13 |
+
from scipy.ndimage import zoom
|
14 |
+
import fractions
|
15 |
+
import functools
|
16 |
+
import skimage.transform
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
from IPython import embed
|
20 |
+
|
21 |
+
from models.stylegan2.lpips import networks_basic as networks
|
22 |
+
import models.stylegan2.lpips as util
|
23 |
+
|
24 |
+
class DistModel(BaseModel):
|
25 |
+
def name(self):
|
26 |
+
return self.model_name
|
27 |
+
|
28 |
+
def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
|
29 |
+
use_gpu=True, printNet=False, spatial=False,
|
30 |
+
is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
|
31 |
+
'''
|
32 |
+
INPUTS
|
33 |
+
model - ['net-lin'] for linearly calibrated network
|
34 |
+
['net'] for off-the-shelf network
|
35 |
+
['L2'] for L2 distance in Lab colorspace
|
36 |
+
['SSIM'] for ssim in RGB colorspace
|
37 |
+
net - ['squeeze','alex','vgg']
|
38 |
+
model_path - if None, will look in weights/[NET_NAME].pth
|
39 |
+
colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
|
40 |
+
use_gpu - bool - whether or not to use a GPU
|
41 |
+
printNet - bool - whether or not to print network architecture out
|
42 |
+
spatial - bool - whether to output an array containing varying distances across spatial dimensions
|
43 |
+
spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
|
44 |
+
spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
|
45 |
+
spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
|
46 |
+
is_train - bool - [True] for training mode
|
47 |
+
lr - float - initial learning rate
|
48 |
+
beta1 - float - initial momentum term for adam
|
49 |
+
version - 0.1 for latest, 0.0 was original (with a bug)
|
50 |
+
gpu_ids - int array - [0] by default, gpus to use
|
51 |
+
'''
|
52 |
+
BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
|
53 |
+
|
54 |
+
self.model = model
|
55 |
+
self.net = net
|
56 |
+
self.is_train = is_train
|
57 |
+
self.spatial = spatial
|
58 |
+
self.gpu_ids = gpu_ids
|
59 |
+
self.model_name = '%s [%s]'%(model,net)
|
60 |
+
|
61 |
+
if(self.model == 'net-lin'): # pretrained net + linear layer
|
62 |
+
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
|
63 |
+
use_dropout=True, spatial=spatial, version=version, lpips=True)
|
64 |
+
kw = {}
|
65 |
+
if not use_gpu:
|
66 |
+
kw['map_location'] = 'cpu'
|
67 |
+
if(model_path is None):
|
68 |
+
import inspect
|
69 |
+
model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
|
70 |
+
|
71 |
+
if(not is_train):
|
72 |
+
print('Loading model from: %s'%model_path)
|
73 |
+
self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
|
74 |
+
|
75 |
+
elif(self.model=='net'): # pretrained network
|
76 |
+
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
|
77 |
+
elif(self.model in ['L2','l2']):
|
78 |
+
self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
|
79 |
+
self.model_name = 'L2'
|
80 |
+
elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
|
81 |
+
self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
|
82 |
+
self.model_name = 'SSIM'
|
83 |
+
else:
|
84 |
+
raise ValueError("Model [%s] not recognized." % self.model)
|
85 |
+
|
86 |
+
self.parameters = list(self.net.parameters())
|
87 |
+
|
88 |
+
if self.is_train: # training mode
|
89 |
+
# extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
|
90 |
+
self.rankLoss = networks.BCERankingLoss()
|
91 |
+
self.parameters += list(self.rankLoss.net.parameters())
|
92 |
+
self.lr = lr
|
93 |
+
self.old_lr = lr
|
94 |
+
self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
|
95 |
+
else: # test mode
|
96 |
+
self.net.eval()
|
97 |
+
|
98 |
+
if(use_gpu):
|
99 |
+
self.net.to(gpu_ids[0])
|
100 |
+
self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
|
101 |
+
if(self.is_train):
|
102 |
+
self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
|
103 |
+
|
104 |
+
if(printNet):
|
105 |
+
print('---------- Networks initialized -------------')
|
106 |
+
networks.print_network(self.net)
|
107 |
+
print('-----------------------------------------------')
|
108 |
+
|
109 |
+
def forward(self, in0, in1, retPerLayer=False):
|
110 |
+
''' Function computes the distance between image patches in0 and in1
|
111 |
+
INPUTS
|
112 |
+
in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
|
113 |
+
OUTPUT
|
114 |
+
computed distances between in0 and in1
|
115 |
+
'''
|
116 |
+
|
117 |
+
return self.net.forward(in0, in1, retPerLayer=retPerLayer)
|
118 |
+
|
119 |
+
# ***** TRAINING FUNCTIONS *****
|
120 |
+
def optimize_parameters(self):
|
121 |
+
self.forward_train()
|
122 |
+
self.optimizer_net.zero_grad()
|
123 |
+
self.backward_train()
|
124 |
+
self.optimizer_net.step()
|
125 |
+
self.clamp_weights()
|
126 |
+
|
127 |
+
def clamp_weights(self):
|
128 |
+
for module in self.net.modules():
|
129 |
+
if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
|
130 |
+
module.weight.data = torch.clamp(module.weight.data,min=0)
|
131 |
+
|
132 |
+
def set_input(self, data):
|
133 |
+
self.input_ref = data['ref']
|
134 |
+
self.input_p0 = data['p0']
|
135 |
+
self.input_p1 = data['p1']
|
136 |
+
self.input_judge = data['judge']
|
137 |
+
|
138 |
+
if(self.use_gpu):
|
139 |
+
self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
|
140 |
+
self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
|
141 |
+
self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
|
142 |
+
self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
|
143 |
+
|
144 |
+
self.var_ref = Variable(self.input_ref,requires_grad=True)
|
145 |
+
self.var_p0 = Variable(self.input_p0,requires_grad=True)
|
146 |
+
self.var_p1 = Variable(self.input_p1,requires_grad=True)
|
147 |
+
|
148 |
+
def forward_train(self): # run forward pass
|
149 |
+
# print(self.net.module.scaling_layer.shift)
|
150 |
+
# print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
|
151 |
+
|
152 |
+
self.d0 = self.forward(self.var_ref, self.var_p0)
|
153 |
+
self.d1 = self.forward(self.var_ref, self.var_p1)
|
154 |
+
self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
|
155 |
+
|
156 |
+
self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
|
157 |
+
|
158 |
+
self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
|
159 |
+
|
160 |
+
return self.loss_total
|
161 |
+
|
162 |
+
def backward_train(self):
|
163 |
+
torch.mean(self.loss_total).backward()
|
164 |
+
|
165 |
+
def compute_accuracy(self,d0,d1,judge):
|
166 |
+
''' d0, d1 are Variables, judge is a Tensor '''
|
167 |
+
d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
|
168 |
+
judge_per = judge.cpu().numpy().flatten()
|
169 |
+
return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
|
170 |
+
|
171 |
+
def get_current_errors(self):
|
172 |
+
retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
|
173 |
+
('acc_r', self.acc_r)])
|
174 |
+
|
175 |
+
for key in retDict.keys():
|
176 |
+
retDict[key] = np.mean(retDict[key])
|
177 |
+
|
178 |
+
return retDict
|
179 |
+
|
180 |
+
def get_current_visuals(self):
|
181 |
+
zoom_factor = 256/self.var_ref.data.size()[2]
|
182 |
+
|
183 |
+
ref_img = util.tensor2im(self.var_ref.data)
|
184 |
+
p0_img = util.tensor2im(self.var_p0.data)
|
185 |
+
p1_img = util.tensor2im(self.var_p1.data)
|
186 |
+
|
187 |
+
ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
|
188 |
+
p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
|
189 |
+
p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
|
190 |
+
|
191 |
+
return OrderedDict([('ref', ref_img_vis),
|
192 |
+
('p0', p0_img_vis),
|
193 |
+
('p1', p1_img_vis)])
|
194 |
+
|
195 |
+
def save(self, path, label):
|
196 |
+
if(self.use_gpu):
|
197 |
+
self.save_network(self.net.module, path, '', label)
|
198 |
+
else:
|
199 |
+
self.save_network(self.net, path, '', label)
|
200 |
+
self.save_network(self.rankLoss.net, path, 'rank', label)
|
201 |
+
|
202 |
+
def update_learning_rate(self,nepoch_decay):
|
203 |
+
lrd = self.lr / nepoch_decay
|
204 |
+
lr = self.old_lr - lrd
|
205 |
+
|
206 |
+
for param_group in self.optimizer_net.param_groups:
|
207 |
+
param_group['lr'] = lr
|
208 |
+
|
209 |
+
print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
|
210 |
+
self.old_lr = lr
|
211 |
+
|
212 |
+
def score_2afc_dataset(data_loader, func, name=''):
|
213 |
+
''' Function computes Two Alternative Forced Choice (2AFC) score using
|
214 |
+
distance function 'func' in dataset 'data_loader'
|
215 |
+
INPUTS
|
216 |
+
data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
|
217 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
218 |
+
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
|
219 |
+
OUTPUTS
|
220 |
+
[0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
|
221 |
+
[1] - dictionary with following elements
|
222 |
+
d0s,d1s - N arrays containing distances between reference patch to perturbed patches
|
223 |
+
gts - N array in [0,1], preferred patch selected by human evaluators
|
224 |
+
(closer to "0" for left patch p0, "1" for right patch p1,
|
225 |
+
"0.6" means 60pct people preferred right patch, 40pct preferred left)
|
226 |
+
scores - N array in [0,1], corresponding to what percentage function agreed with humans
|
227 |
+
CONSTS
|
228 |
+
N - number of test triplets in data_loader
|
229 |
+
'''
|
230 |
+
|
231 |
+
d0s = []
|
232 |
+
d1s = []
|
233 |
+
gts = []
|
234 |
+
|
235 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
236 |
+
d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
|
237 |
+
d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
|
238 |
+
gts+=data['judge'].cpu().numpy().flatten().tolist()
|
239 |
+
|
240 |
+
d0s = np.array(d0s)
|
241 |
+
d1s = np.array(d1s)
|
242 |
+
gts = np.array(gts)
|
243 |
+
scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
|
244 |
+
|
245 |
+
return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
|
246 |
+
|
247 |
+
def score_jnd_dataset(data_loader, func, name=''):
|
248 |
+
''' Function computes JND score using distance function 'func' in dataset 'data_loader'
|
249 |
+
INPUTS
|
250 |
+
data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
|
251 |
+
func - callable distance function - calling d=func(in0,in1) should take 2
|
252 |
+
pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
|
253 |
+
OUTPUTS
|
254 |
+
[0] - JND score in [0,1], mAP score (area under precision-recall curve)
|
255 |
+
[1] - dictionary with following elements
|
256 |
+
ds - N array containing distances between two patches shown to human evaluator
|
257 |
+
sames - N array containing fraction of people who thought the two patches were identical
|
258 |
+
CONSTS
|
259 |
+
N - number of test triplets in data_loader
|
260 |
+
'''
|
261 |
+
|
262 |
+
ds = []
|
263 |
+
gts = []
|
264 |
+
|
265 |
+
for data in tqdm(data_loader.load_data(), desc=name):
|
266 |
+
ds+=func(data['p0'],data['p1']).data.cpu().numpy().tolist()
|
267 |
+
gts+=data['same'].cpu().numpy().flatten().tolist()
|
268 |
+
|
269 |
+
sames = np.array(gts)
|
270 |
+
ds = np.array(ds)
|
271 |
+
|
272 |
+
sorted_inds = np.argsort(ds)
|
273 |
+
ds_sorted = ds[sorted_inds]
|
274 |
+
sames_sorted = sames[sorted_inds]
|
275 |
+
|
276 |
+
TPs = np.cumsum(sames_sorted)
|
277 |
+
FPs = np.cumsum(1-sames_sorted)
|
278 |
+
FNs = np.sum(sames_sorted)-TPs
|
279 |
+
|
280 |
+
precs = TPs/(TPs+FPs)
|
281 |
+
recs = TPs/(TPs+FNs)
|
282 |
+
score = util.voc_ap(recs,precs)
|
283 |
+
|
284 |
+
return(score, dict(ds=ds,sames=sames))
|
models/stylegan2/lpips/networks_basic.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.init as init
|
8 |
+
from torch.autograd import Variable
|
9 |
+
import numpy as np
|
10 |
+
from pdb import set_trace as st
|
11 |
+
from skimage import color
|
12 |
+
from IPython import embed
|
13 |
+
from models.stylegan2.lpips import pretrained_networks as pn
|
14 |
+
|
15 |
+
import models.stylegan2.lpips as util
|
16 |
+
|
17 |
+
def spatial_average(in_tens, keepdim=True):
|
18 |
+
return in_tens.mean([2,3],keepdim=keepdim)
|
19 |
+
|
20 |
+
def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
|
21 |
+
in_H = in_tens.shape[2]
|
22 |
+
scale_factor = 1.*out_H/in_H
|
23 |
+
|
24 |
+
return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
|
25 |
+
|
26 |
+
# Learned perceptual metric
|
27 |
+
class PNetLin(nn.Module):
|
28 |
+
def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
|
29 |
+
super(PNetLin, self).__init__()
|
30 |
+
|
31 |
+
self.pnet_type = pnet_type
|
32 |
+
self.pnet_tune = pnet_tune
|
33 |
+
self.pnet_rand = pnet_rand
|
34 |
+
self.spatial = spatial
|
35 |
+
self.lpips = lpips
|
36 |
+
self.version = version
|
37 |
+
self.scaling_layer = ScalingLayer()
|
38 |
+
|
39 |
+
if(self.pnet_type in ['vgg','vgg16']):
|
40 |
+
net_type = pn.vgg16
|
41 |
+
self.chns = [64,128,256,512,512]
|
42 |
+
elif(self.pnet_type=='alex'):
|
43 |
+
net_type = pn.alexnet
|
44 |
+
self.chns = [64,192,384,256,256]
|
45 |
+
elif(self.pnet_type=='squeeze'):
|
46 |
+
net_type = pn.squeezenet
|
47 |
+
self.chns = [64,128,256,384,384,512,512]
|
48 |
+
self.L = len(self.chns)
|
49 |
+
|
50 |
+
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
|
51 |
+
|
52 |
+
if(lpips):
|
53 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
54 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
55 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
56 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
57 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
58 |
+
self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
|
59 |
+
if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
|
60 |
+
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
|
61 |
+
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
|
62 |
+
self.lins+=[self.lin5,self.lin6]
|
63 |
+
|
64 |
+
def forward(self, in0, in1, retPerLayer=False):
|
65 |
+
# v0.0 - original release had a bug, where input was not scaled
|
66 |
+
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
|
67 |
+
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
|
68 |
+
feats0, feats1, diffs = {}, {}, {}
|
69 |
+
|
70 |
+
for kk in range(self.L):
|
71 |
+
feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
|
72 |
+
diffs[kk] = (feats0[kk]-feats1[kk])**2
|
73 |
+
|
74 |
+
if(self.lpips):
|
75 |
+
if(self.spatial):
|
76 |
+
res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
|
77 |
+
else:
|
78 |
+
res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
|
79 |
+
else:
|
80 |
+
if(self.spatial):
|
81 |
+
res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
|
82 |
+
else:
|
83 |
+
res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
|
84 |
+
|
85 |
+
val = res[0]
|
86 |
+
for l in range(1,self.L):
|
87 |
+
val += res[l]
|
88 |
+
|
89 |
+
if(retPerLayer):
|
90 |
+
return (val, res)
|
91 |
+
else:
|
92 |
+
return val
|
93 |
+
|
94 |
+
class ScalingLayer(nn.Module):
|
95 |
+
def __init__(self):
|
96 |
+
super(ScalingLayer, self).__init__()
|
97 |
+
self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
|
98 |
+
self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
|
99 |
+
|
100 |
+
def forward(self, inp):
|
101 |
+
return (inp - self.shift) / self.scale
|
102 |
+
|
103 |
+
|
104 |
+
class NetLinLayer(nn.Module):
|
105 |
+
''' A single linear layer which does a 1x1 conv '''
|
106 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
107 |
+
super(NetLinLayer, self).__init__()
|
108 |
+
|
109 |
+
layers = [nn.Dropout(),] if(use_dropout) else []
|
110 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
|
111 |
+
self.model = nn.Sequential(*layers)
|
112 |
+
|
113 |
+
|
114 |
+
class Dist2LogitLayer(nn.Module):
|
115 |
+
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
|
116 |
+
def __init__(self, chn_mid=32, use_sigmoid=True):
|
117 |
+
super(Dist2LogitLayer, self).__init__()
|
118 |
+
|
119 |
+
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
|
120 |
+
layers += [nn.LeakyReLU(0.2,True),]
|
121 |
+
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
|
122 |
+
layers += [nn.LeakyReLU(0.2,True),]
|
123 |
+
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
|
124 |
+
if(use_sigmoid):
|
125 |
+
layers += [nn.Sigmoid(),]
|
126 |
+
self.model = nn.Sequential(*layers)
|
127 |
+
|
128 |
+
def forward(self,d0,d1,eps=0.1):
|
129 |
+
return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
|
130 |
+
|
131 |
+
class BCERankingLoss(nn.Module):
|
132 |
+
def __init__(self, chn_mid=32):
|
133 |
+
super(BCERankingLoss, self).__init__()
|
134 |
+
self.net = Dist2LogitLayer(chn_mid=chn_mid)
|
135 |
+
# self.parameters = list(self.net.parameters())
|
136 |
+
self.loss = torch.nn.BCELoss()
|
137 |
+
|
138 |
+
def forward(self, d0, d1, judge):
|
139 |
+
per = (judge+1.)/2.
|
140 |
+
self.logit = self.net.forward(d0,d1)
|
141 |
+
return self.loss(self.logit, per)
|
142 |
+
|
143 |
+
# L2, DSSIM metrics
|
144 |
+
class FakeNet(nn.Module):
|
145 |
+
def __init__(self, use_gpu=True, colorspace='Lab'):
|
146 |
+
super(FakeNet, self).__init__()
|
147 |
+
self.use_gpu = use_gpu
|
148 |
+
self.colorspace=colorspace
|
149 |
+
|
150 |
+
class L2(FakeNet):
|
151 |
+
|
152 |
+
def forward(self, in0, in1, retPerLayer=None):
|
153 |
+
assert(in0.size()[0]==1) # currently only supports batchSize 1
|
154 |
+
|
155 |
+
if(self.colorspace=='RGB'):
|
156 |
+
(N,C,X,Y) = in0.size()
|
157 |
+
value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
|
158 |
+
return value
|
159 |
+
elif(self.colorspace=='Lab'):
|
160 |
+
value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
|
161 |
+
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
|
162 |
+
ret_var = Variable( torch.Tensor((value,) ) )
|
163 |
+
if(self.use_gpu):
|
164 |
+
ret_var = ret_var.cuda()
|
165 |
+
return ret_var
|
166 |
+
|
167 |
+
class DSSIM(FakeNet):
|
168 |
+
|
169 |
+
def forward(self, in0, in1, retPerLayer=None):
|
170 |
+
assert(in0.size()[0]==1) # currently only supports batchSize 1
|
171 |
+
|
172 |
+
if(self.colorspace=='RGB'):
|
173 |
+
value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
|
174 |
+
elif(self.colorspace=='Lab'):
|
175 |
+
value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
|
176 |
+
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
|
177 |
+
ret_var = Variable( torch.Tensor((value,) ) )
|
178 |
+
if(self.use_gpu):
|
179 |
+
ret_var = ret_var.cuda()
|
180 |
+
return ret_var
|
181 |
+
|
182 |
+
def print_network(net):
|
183 |
+
num_params = 0
|
184 |
+
for param in net.parameters():
|
185 |
+
num_params += param.numel()
|
186 |
+
print('Network',net)
|
187 |
+
print('Total number of parameters: %d' % num_params)
|
models/stylegan2/lpips/pretrained_networks.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
from torchvision import models as tv
|
4 |
+
from IPython import embed
|
5 |
+
|
6 |
+
class squeezenet(torch.nn.Module):
|
7 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
8 |
+
super(squeezenet, self).__init__()
|
9 |
+
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
|
10 |
+
self.slice1 = torch.nn.Sequential()
|
11 |
+
self.slice2 = torch.nn.Sequential()
|
12 |
+
self.slice3 = torch.nn.Sequential()
|
13 |
+
self.slice4 = torch.nn.Sequential()
|
14 |
+
self.slice5 = torch.nn.Sequential()
|
15 |
+
self.slice6 = torch.nn.Sequential()
|
16 |
+
self.slice7 = torch.nn.Sequential()
|
17 |
+
self.N_slices = 7
|
18 |
+
for x in range(2):
|
19 |
+
self.slice1.add_module(str(x), pretrained_features[x])
|
20 |
+
for x in range(2,5):
|
21 |
+
self.slice2.add_module(str(x), pretrained_features[x])
|
22 |
+
for x in range(5, 8):
|
23 |
+
self.slice3.add_module(str(x), pretrained_features[x])
|
24 |
+
for x in range(8, 10):
|
25 |
+
self.slice4.add_module(str(x), pretrained_features[x])
|
26 |
+
for x in range(10, 11):
|
27 |
+
self.slice5.add_module(str(x), pretrained_features[x])
|
28 |
+
for x in range(11, 12):
|
29 |
+
self.slice6.add_module(str(x), pretrained_features[x])
|
30 |
+
for x in range(12, 13):
|
31 |
+
self.slice7.add_module(str(x), pretrained_features[x])
|
32 |
+
if not requires_grad:
|
33 |
+
for param in self.parameters():
|
34 |
+
param.requires_grad = False
|
35 |
+
|
36 |
+
def forward(self, X):
|
37 |
+
h = self.slice1(X)
|
38 |
+
h_relu1 = h
|
39 |
+
h = self.slice2(h)
|
40 |
+
h_relu2 = h
|
41 |
+
h = self.slice3(h)
|
42 |
+
h_relu3 = h
|
43 |
+
h = self.slice4(h)
|
44 |
+
h_relu4 = h
|
45 |
+
h = self.slice5(h)
|
46 |
+
h_relu5 = h
|
47 |
+
h = self.slice6(h)
|
48 |
+
h_relu6 = h
|
49 |
+
h = self.slice7(h)
|
50 |
+
h_relu7 = h
|
51 |
+
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
|
52 |
+
out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
|
53 |
+
|
54 |
+
return out
|
55 |
+
|
56 |
+
|
57 |
+
class alexnet(torch.nn.Module):
|
58 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
59 |
+
super(alexnet, self).__init__()
|
60 |
+
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
|
61 |
+
self.slice1 = torch.nn.Sequential()
|
62 |
+
self.slice2 = torch.nn.Sequential()
|
63 |
+
self.slice3 = torch.nn.Sequential()
|
64 |
+
self.slice4 = torch.nn.Sequential()
|
65 |
+
self.slice5 = torch.nn.Sequential()
|
66 |
+
self.N_slices = 5
|
67 |
+
for x in range(2):
|
68 |
+
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
69 |
+
for x in range(2, 5):
|
70 |
+
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
71 |
+
for x in range(5, 8):
|
72 |
+
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
73 |
+
for x in range(8, 10):
|
74 |
+
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
75 |
+
for x in range(10, 12):
|
76 |
+
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
77 |
+
if not requires_grad:
|
78 |
+
for param in self.parameters():
|
79 |
+
param.requires_grad = False
|
80 |
+
|
81 |
+
def forward(self, X):
|
82 |
+
h = self.slice1(X)
|
83 |
+
h_relu1 = h
|
84 |
+
h = self.slice2(h)
|
85 |
+
h_relu2 = h
|
86 |
+
h = self.slice3(h)
|
87 |
+
h_relu3 = h
|
88 |
+
h = self.slice4(h)
|
89 |
+
h_relu4 = h
|
90 |
+
h = self.slice5(h)
|
91 |
+
h_relu5 = h
|
92 |
+
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
|
93 |
+
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
94 |
+
|
95 |
+
return out
|
96 |
+
|
97 |
+
class vgg16(torch.nn.Module):
|
98 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
99 |
+
super(vgg16, self).__init__()
|
100 |
+
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
|
101 |
+
self.slice1 = torch.nn.Sequential()
|
102 |
+
self.slice2 = torch.nn.Sequential()
|
103 |
+
self.slice3 = torch.nn.Sequential()
|
104 |
+
self.slice4 = torch.nn.Sequential()
|
105 |
+
self.slice5 = torch.nn.Sequential()
|
106 |
+
self.N_slices = 5
|
107 |
+
for x in range(4):
|
108 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
109 |
+
for x in range(4, 9):
|
110 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
111 |
+
for x in range(9, 16):
|
112 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
113 |
+
for x in range(16, 23):
|
114 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
115 |
+
for x in range(23, 30):
|
116 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
117 |
+
if not requires_grad:
|
118 |
+
for param in self.parameters():
|
119 |
+
param.requires_grad = False
|
120 |
+
|
121 |
+
def forward(self, X):
|
122 |
+
h = self.slice1(X)
|
123 |
+
h_relu1_2 = h
|
124 |
+
h = self.slice2(h)
|
125 |
+
h_relu2_2 = h
|
126 |
+
h = self.slice3(h)
|
127 |
+
h_relu3_3 = h
|
128 |
+
h = self.slice4(h)
|
129 |
+
h_relu4_3 = h
|
130 |
+
h = self.slice5(h)
|
131 |
+
h_relu5_3 = h
|
132 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
133 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
134 |
+
|
135 |
+
return out
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
class resnet(torch.nn.Module):
|
140 |
+
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
141 |
+
super(resnet, self).__init__()
|
142 |
+
if(num==18):
|
143 |
+
self.net = tv.resnet18(pretrained=pretrained)
|
144 |
+
elif(num==34):
|
145 |
+
self.net = tv.resnet34(pretrained=pretrained)
|
146 |
+
elif(num==50):
|
147 |
+
self.net = tv.resnet50(pretrained=pretrained)
|
148 |
+
elif(num==101):
|
149 |
+
self.net = tv.resnet101(pretrained=pretrained)
|
150 |
+
elif(num==152):
|
151 |
+
self.net = tv.resnet152(pretrained=pretrained)
|
152 |
+
self.N_slices = 5
|
153 |
+
|
154 |
+
self.conv1 = self.net.conv1
|
155 |
+
self.bn1 = self.net.bn1
|
156 |
+
self.relu = self.net.relu
|
157 |
+
self.maxpool = self.net.maxpool
|
158 |
+
self.layer1 = self.net.layer1
|
159 |
+
self.layer2 = self.net.layer2
|
160 |
+
self.layer3 = self.net.layer3
|
161 |
+
self.layer4 = self.net.layer4
|
162 |
+
|
163 |
+
def forward(self, X):
|
164 |
+
h = self.conv1(X)
|
165 |
+
h = self.bn1(h)
|
166 |
+
h = self.relu(h)
|
167 |
+
h_relu1 = h
|
168 |
+
h = self.maxpool(h)
|
169 |
+
h = self.layer1(h)
|
170 |
+
h_conv2 = h
|
171 |
+
h = self.layer2(h)
|
172 |
+
h_conv3 = h
|
173 |
+
h = self.layer3(h)
|
174 |
+
h_conv4 = h
|
175 |
+
h = self.layer4(h)
|
176 |
+
h_conv5 = h
|
177 |
+
|
178 |
+
outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
|
179 |
+
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
180 |
+
|
181 |
+
return out
|
models/stylegan2/lpips/weights/v0.0/alex.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5
|
3 |
+
size 5455
|
models/stylegan2/lpips/weights/v0.0/squeeze.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf
|
3 |
+
size 10057
|
models/stylegan2/lpips/weights/v0.0/vgg.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c
|
3 |
+
size 6735
|
models/stylegan2/lpips/weights/v0.1/alex.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
|
3 |
+
size 6009
|
models/stylegan2/lpips/weights/v0.1/squeeze.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
|
3 |
+
size 10811
|
models/stylegan2/lpips/weights/v0.1/vgg.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
|
3 |
+
size 7289
|
models/stylegan2/model.py
ADDED
@@ -0,0 +1,768 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
9 |
+
|
10 |
+
|
11 |
+
class PixelNorm(nn.Module):
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
def forward(self, input):
|
16 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
17 |
+
|
18 |
+
|
19 |
+
def make_kernel(k):
|
20 |
+
k = torch.tensor(k, dtype=torch.float32)
|
21 |
+
|
22 |
+
if k.ndim == 1:
|
23 |
+
k = k[None, :] * k[:, None]
|
24 |
+
|
25 |
+
k /= k.sum()
|
26 |
+
|
27 |
+
return k
|
28 |
+
|
29 |
+
|
30 |
+
class Upsample(nn.Module):
|
31 |
+
def __init__(self, kernel, factor=2):
|
32 |
+
super().__init__()
|
33 |
+
|
34 |
+
self.factor = factor
|
35 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
36 |
+
self.register_buffer('kernel', kernel)
|
37 |
+
|
38 |
+
p = kernel.shape[0] - factor
|
39 |
+
|
40 |
+
pad0 = (p + 1) // 2 + factor - 1
|
41 |
+
pad1 = p // 2
|
42 |
+
|
43 |
+
self.pad = (pad0, pad1)
|
44 |
+
|
45 |
+
def forward(self, input):
|
46 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
47 |
+
|
48 |
+
return out
|
49 |
+
|
50 |
+
|
51 |
+
class Downsample(nn.Module):
|
52 |
+
def __init__(self, kernel, factor=2):
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
self.factor = factor
|
56 |
+
kernel = make_kernel(kernel)
|
57 |
+
self.register_buffer('kernel', kernel)
|
58 |
+
|
59 |
+
p = kernel.shape[0] - factor
|
60 |
+
|
61 |
+
pad0 = (p + 1) // 2
|
62 |
+
pad1 = p // 2
|
63 |
+
|
64 |
+
self.pad = (pad0, pad1)
|
65 |
+
|
66 |
+
def forward(self, input):
|
67 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
68 |
+
|
69 |
+
return out
|
70 |
+
|
71 |
+
|
72 |
+
class Blur(nn.Module):
|
73 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
kernel = make_kernel(kernel)
|
77 |
+
|
78 |
+
if upsample_factor > 1:
|
79 |
+
kernel = kernel * (upsample_factor ** 2)
|
80 |
+
|
81 |
+
self.register_buffer('kernel', kernel)
|
82 |
+
|
83 |
+
self.pad = pad
|
84 |
+
|
85 |
+
def forward(self, input):
|
86 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
87 |
+
|
88 |
+
return out
|
89 |
+
|
90 |
+
|
91 |
+
class EqualConv2d(nn.Module):
|
92 |
+
def __init__(
|
93 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dilation=1 ## modified
|
94 |
+
):
|
95 |
+
super().__init__()
|
96 |
+
|
97 |
+
self.weight = nn.Parameter(
|
98 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
99 |
+
)
|
100 |
+
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
101 |
+
|
102 |
+
self.stride = stride
|
103 |
+
self.padding = padding
|
104 |
+
self.dilation = dilation ## modified
|
105 |
+
|
106 |
+
if bias:
|
107 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
108 |
+
|
109 |
+
else:
|
110 |
+
self.bias = None
|
111 |
+
|
112 |
+
def forward(self, input):
|
113 |
+
out = F.conv2d(
|
114 |
+
input,
|
115 |
+
self.weight * self.scale,
|
116 |
+
bias=self.bias,
|
117 |
+
stride=self.stride,
|
118 |
+
padding=self.padding,
|
119 |
+
dilation=self.dilation, ## modified
|
120 |
+
)
|
121 |
+
|
122 |
+
return out
|
123 |
+
|
124 |
+
def __repr__(self):
|
125 |
+
return (
|
126 |
+
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
|
127 |
+
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding}, dilation={self.dilation})" ## modified
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
class EqualLinear(nn.Module):
|
132 |
+
def __init__(
|
133 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
134 |
+
):
|
135 |
+
super().__init__()
|
136 |
+
|
137 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
138 |
+
|
139 |
+
if bias:
|
140 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
141 |
+
|
142 |
+
else:
|
143 |
+
self.bias = None
|
144 |
+
|
145 |
+
self.activation = activation
|
146 |
+
|
147 |
+
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
148 |
+
self.lr_mul = lr_mul
|
149 |
+
|
150 |
+
def forward(self, input):
|
151 |
+
if self.activation:
|
152 |
+
out = F.linear(input, self.weight * self.scale)
|
153 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
154 |
+
|
155 |
+
else:
|
156 |
+
out = F.linear(
|
157 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
158 |
+
)
|
159 |
+
|
160 |
+
return out
|
161 |
+
|
162 |
+
def __repr__(self):
|
163 |
+
return (
|
164 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
|
165 |
+
)
|
166 |
+
|
167 |
+
|
168 |
+
class ScaledLeakyReLU(nn.Module):
|
169 |
+
def __init__(self, negative_slope=0.2):
|
170 |
+
super().__init__()
|
171 |
+
|
172 |
+
self.negative_slope = negative_slope
|
173 |
+
|
174 |
+
def forward(self, input):
|
175 |
+
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
176 |
+
|
177 |
+
return out * math.sqrt(2)
|
178 |
+
|
179 |
+
|
180 |
+
class ModulatedConv2d(nn.Module):
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
in_channel,
|
184 |
+
out_channel,
|
185 |
+
kernel_size,
|
186 |
+
style_dim,
|
187 |
+
demodulate=True,
|
188 |
+
upsample=False,
|
189 |
+
downsample=False,
|
190 |
+
blur_kernel=[1, 3, 3, 1],
|
191 |
+
dilation=1, ##### modified
|
192 |
+
):
|
193 |
+
super().__init__()
|
194 |
+
|
195 |
+
self.eps = 1e-8
|
196 |
+
self.kernel_size = kernel_size
|
197 |
+
self.in_channel = in_channel
|
198 |
+
self.out_channel = out_channel
|
199 |
+
self.upsample = upsample
|
200 |
+
self.downsample = downsample
|
201 |
+
self.dilation = dilation ##### modified
|
202 |
+
|
203 |
+
if upsample:
|
204 |
+
factor = 2
|
205 |
+
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
206 |
+
pad0 = (p + 1) // 2 + factor - 1
|
207 |
+
pad1 = p // 2 + 1
|
208 |
+
|
209 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
210 |
+
|
211 |
+
# to simulate transconv + blur
|
212 |
+
# we use dilated transposed conv with blur kernel as weight + dilated transconv
|
213 |
+
if dilation > 1: ##### modified
|
214 |
+
blur_weight = torch.randn(1, 1, 3, 3) * 0 + 1
|
215 |
+
blur_weight[:,:,0,1] = 2
|
216 |
+
blur_weight[:,:,1,0] = 2
|
217 |
+
blur_weight[:,:,1,2] = 2
|
218 |
+
blur_weight[:,:,2,1] = 2
|
219 |
+
blur_weight[:,:,1,1] = 4
|
220 |
+
blur_weight = blur_weight / 16.0
|
221 |
+
self.register_buffer("blur_weight", blur_weight)
|
222 |
+
|
223 |
+
if downsample:
|
224 |
+
factor = 2
|
225 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
226 |
+
pad0 = (p + 1) // 2
|
227 |
+
pad1 = p // 2
|
228 |
+
|
229 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
230 |
+
|
231 |
+
fan_in = in_channel * kernel_size ** 2
|
232 |
+
self.scale = 1 / math.sqrt(fan_in)
|
233 |
+
self.padding = kernel_size // 2 + dilation - 1 ##### modified
|
234 |
+
|
235 |
+
self.weight = nn.Parameter(
|
236 |
+
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
237 |
+
)
|
238 |
+
|
239 |
+
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
240 |
+
|
241 |
+
self.demodulate = demodulate
|
242 |
+
|
243 |
+
def __repr__(self):
|
244 |
+
return (
|
245 |
+
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
|
246 |
+
f'upsample={self.upsample}, downsample={self.downsample})'
|
247 |
+
)
|
248 |
+
|
249 |
+
def forward(self, input, style):
|
250 |
+
batch, in_channel, height, width = input.shape
|
251 |
+
|
252 |
+
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
253 |
+
weight = self.scale * self.weight * style
|
254 |
+
|
255 |
+
if self.demodulate:
|
256 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
257 |
+
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
258 |
+
|
259 |
+
weight = weight.view(
|
260 |
+
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
261 |
+
)
|
262 |
+
|
263 |
+
if self.upsample:
|
264 |
+
input = input.view(1, batch * in_channel, height, width)
|
265 |
+
weight = weight.view(
|
266 |
+
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
267 |
+
)
|
268 |
+
weight = weight.transpose(1, 2).reshape(
|
269 |
+
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
270 |
+
)
|
271 |
+
|
272 |
+
if self.dilation > 1: ##### modified
|
273 |
+
# to simulate out = self.blur(out)
|
274 |
+
out = F.conv_transpose2d(
|
275 |
+
input, self.blur_weight.repeat(batch*in_channel,1,1,1), padding=0, groups=batch*in_channel, dilation=self.dilation//2)
|
276 |
+
# to simulate the next line
|
277 |
+
out = F.conv_transpose2d(
|
278 |
+
out, weight, padding=self.dilation, groups=batch, dilation=self.dilation//2)
|
279 |
+
_, _, height, width = out.shape
|
280 |
+
out = out.view(batch, self.out_channel, height, width)
|
281 |
+
return out
|
282 |
+
|
283 |
+
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
284 |
+
_, _, height, width = out.shape
|
285 |
+
out = out.view(batch, self.out_channel, height, width)
|
286 |
+
out = self.blur(out)
|
287 |
+
|
288 |
+
elif self.downsample:
|
289 |
+
input = self.blur(input)
|
290 |
+
_, _, height, width = input.shape
|
291 |
+
input = input.view(1, batch * in_channel, height, width)
|
292 |
+
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
293 |
+
_, _, height, width = out.shape
|
294 |
+
out = out.view(batch, self.out_channel, height, width)
|
295 |
+
|
296 |
+
else:
|
297 |
+
input = input.view(1, batch * in_channel, height, width)
|
298 |
+
out = F.conv2d(input, weight, padding=self.padding, groups=batch, dilation=self.dilation) ##### modified
|
299 |
+
_, _, height, width = out.shape
|
300 |
+
out = out.view(batch, self.out_channel, height, width)
|
301 |
+
|
302 |
+
return out
|
303 |
+
|
304 |
+
|
305 |
+
class NoiseInjection(nn.Module):
|
306 |
+
def __init__(self):
|
307 |
+
super().__init__()
|
308 |
+
|
309 |
+
self.weight = nn.Parameter(torch.zeros(1))
|
310 |
+
|
311 |
+
def forward(self, image, noise=None):
|
312 |
+
if noise is None:
|
313 |
+
batch, _, height, width = image.shape
|
314 |
+
noise = image.new_empty(batch, 1, height, width).normal_()
|
315 |
+
else: ##### modified, to make the resolution matches
|
316 |
+
batch, _, height, width = image.shape
|
317 |
+
_, _, height1, width1 = noise.shape
|
318 |
+
if height != height1 or width != width1:
|
319 |
+
noise = F.adaptive_avg_pool2d(noise, (height, width))
|
320 |
+
|
321 |
+
return image + self.weight * noise
|
322 |
+
|
323 |
+
|
324 |
+
class ConstantInput(nn.Module):
|
325 |
+
def __init__(self, channel, size=4):
|
326 |
+
super().__init__()
|
327 |
+
|
328 |
+
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
329 |
+
|
330 |
+
def forward(self, input):
|
331 |
+
batch = input.shape[0]
|
332 |
+
out = self.input.repeat(batch, 1, 1, 1)
|
333 |
+
|
334 |
+
return out
|
335 |
+
|
336 |
+
|
337 |
+
class StyledConv(nn.Module):
|
338 |
+
def __init__(
|
339 |
+
self,
|
340 |
+
in_channel,
|
341 |
+
out_channel,
|
342 |
+
kernel_size,
|
343 |
+
style_dim,
|
344 |
+
upsample=False,
|
345 |
+
blur_kernel=[1, 3, 3, 1],
|
346 |
+
demodulate=True,
|
347 |
+
dilation=1, ##### modified
|
348 |
+
):
|
349 |
+
super().__init__()
|
350 |
+
|
351 |
+
self.conv = ModulatedConv2d(
|
352 |
+
in_channel,
|
353 |
+
out_channel,
|
354 |
+
kernel_size,
|
355 |
+
style_dim,
|
356 |
+
upsample=upsample,
|
357 |
+
blur_kernel=blur_kernel,
|
358 |
+
demodulate=demodulate,
|
359 |
+
dilation=dilation, ##### modified
|
360 |
+
)
|
361 |
+
|
362 |
+
self.noise = NoiseInjection()
|
363 |
+
self.activate = FusedLeakyReLU(out_channel)
|
364 |
+
|
365 |
+
def forward(self, input, style, noise=None):
|
366 |
+
out = self.conv(input, style)
|
367 |
+
out = self.noise(out, noise=noise)
|
368 |
+
out = self.activate(out)
|
369 |
+
|
370 |
+
return out
|
371 |
+
|
372 |
+
|
373 |
+
class ToRGB(nn.Module):
|
374 |
+
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1], dilation=1): ##### modified
|
375 |
+
super().__init__()
|
376 |
+
|
377 |
+
if upsample:
|
378 |
+
self.upsample = Upsample(blur_kernel)
|
379 |
+
|
380 |
+
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
381 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
382 |
+
|
383 |
+
self.dilation = dilation ##### modified
|
384 |
+
if dilation > 1: ##### modified
|
385 |
+
blur_weight = torch.randn(1, 1, 3, 3) * 0 + 1
|
386 |
+
blur_weight[:,:,0,1] = 2
|
387 |
+
blur_weight[:,:,1,0] = 2
|
388 |
+
blur_weight[:,:,1,2] = 2
|
389 |
+
blur_weight[:,:,2,1] = 2
|
390 |
+
blur_weight[:,:,1,1] = 4
|
391 |
+
blur_weight = blur_weight / 16.0
|
392 |
+
self.register_buffer("blur_weight", blur_weight)
|
393 |
+
|
394 |
+
def forward(self, input, style, skip=None):
|
395 |
+
out = self.conv(input, style)
|
396 |
+
out = out + self.bias
|
397 |
+
|
398 |
+
if skip is not None:
|
399 |
+
if self.dilation == 1:
|
400 |
+
skip = self.upsample(skip)
|
401 |
+
else: ##### modified, to simulate skip = self.upsample(skip)
|
402 |
+
batch, in_channel, _, _ = skip.shape
|
403 |
+
skip = F.conv2d(skip, self.blur_weight.repeat(in_channel,1,1,1),
|
404 |
+
padding=self.dilation//2, groups=in_channel, dilation=self.dilation//2)
|
405 |
+
|
406 |
+
out = out + skip
|
407 |
+
|
408 |
+
return out
|
409 |
+
|
410 |
+
|
411 |
+
class Generator(nn.Module):
|
412 |
+
def __init__(
|
413 |
+
self,
|
414 |
+
size,
|
415 |
+
style_dim,
|
416 |
+
n_mlp,
|
417 |
+
channel_multiplier=2,
|
418 |
+
blur_kernel=[1, 3, 3, 1],
|
419 |
+
lr_mlp=0.01,
|
420 |
+
):
|
421 |
+
super().__init__()
|
422 |
+
|
423 |
+
self.size = size
|
424 |
+
|
425 |
+
self.style_dim = style_dim
|
426 |
+
|
427 |
+
layers = [PixelNorm()]
|
428 |
+
|
429 |
+
for i in range(n_mlp):
|
430 |
+
layers.append(
|
431 |
+
EqualLinear(
|
432 |
+
style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
|
433 |
+
)
|
434 |
+
)
|
435 |
+
|
436 |
+
self.style = nn.Sequential(*layers)
|
437 |
+
|
438 |
+
self.channels = {
|
439 |
+
4: 512,
|
440 |
+
8: 512,
|
441 |
+
16: 512,
|
442 |
+
32: 512,
|
443 |
+
64: 256 * channel_multiplier,
|
444 |
+
128: 128 * channel_multiplier,
|
445 |
+
256: 64 * channel_multiplier,
|
446 |
+
512: 32 * channel_multiplier,
|
447 |
+
1024: 16 * channel_multiplier,
|
448 |
+
}
|
449 |
+
|
450 |
+
self.input = ConstantInput(self.channels[4])
|
451 |
+
self.conv1 = StyledConv(
|
452 |
+
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel, dilation=8 ##### modified
|
453 |
+
)
|
454 |
+
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
455 |
+
|
456 |
+
self.log_size = int(math.log(size, 2))
|
457 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
458 |
+
|
459 |
+
self.convs = nn.ModuleList()
|
460 |
+
self.upsamples = nn.ModuleList()
|
461 |
+
self.to_rgbs = nn.ModuleList()
|
462 |
+
self.noises = nn.Module()
|
463 |
+
|
464 |
+
in_channel = self.channels[4]
|
465 |
+
|
466 |
+
for layer_idx in range(self.num_layers):
|
467 |
+
res = (layer_idx + 5) // 2
|
468 |
+
shape = [1, 1, 2 ** res, 2 ** res]
|
469 |
+
self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
|
470 |
+
|
471 |
+
for i in range(3, self.log_size + 1):
|
472 |
+
out_channel = self.channels[2 ** i]
|
473 |
+
|
474 |
+
self.convs.append(
|
475 |
+
StyledConv(
|
476 |
+
in_channel,
|
477 |
+
out_channel,
|
478 |
+
3,
|
479 |
+
style_dim,
|
480 |
+
upsample=True,
|
481 |
+
blur_kernel=blur_kernel,
|
482 |
+
dilation=max(1, 32 // (2**(i-1))) ##### modified
|
483 |
+
)
|
484 |
+
)
|
485 |
+
|
486 |
+
self.convs.append(
|
487 |
+
StyledConv(
|
488 |
+
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel, dilation=max(1, 32 // (2**i)) ##### modified
|
489 |
+
)
|
490 |
+
)
|
491 |
+
|
492 |
+
self.to_rgbs.append(ToRGB(out_channel, style_dim, dilation=max(1, 32 // (2**(i-1))))) ##### modified
|
493 |
+
|
494 |
+
in_channel = out_channel
|
495 |
+
|
496 |
+
self.n_latent = self.log_size * 2 - 2
|
497 |
+
|
498 |
+
def make_noise(self):
|
499 |
+
device = self.input.input.device
|
500 |
+
|
501 |
+
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
502 |
+
|
503 |
+
for i in range(3, self.log_size + 1):
|
504 |
+
for _ in range(2):
|
505 |
+
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
506 |
+
|
507 |
+
return noises
|
508 |
+
|
509 |
+
def mean_latent(self, n_latent):
|
510 |
+
latent_in = torch.randn(
|
511 |
+
n_latent, self.style_dim, device=self.input.input.device
|
512 |
+
)
|
513 |
+
latent = self.style(latent_in).mean(0, keepdim=True)
|
514 |
+
|
515 |
+
return latent
|
516 |
+
|
517 |
+
def get_latent(self, input):
|
518 |
+
return self.style(input)
|
519 |
+
|
520 |
+
# styles is the latent code w+
|
521 |
+
# first_layer_feature is the first-layer input feature f
|
522 |
+
# first_layer_feature_ind indicates which layer of G accepts f (should always be 0, the first layer)
|
523 |
+
# skip_layer_feature holds the encoder features sent via the skip connections
|
524 |
+
# fusion_block is the network that fuses the encoder and decoder features
|
525 |
+
# zero_noise forces the noise to be zero (to avoid flickering in videos)
|
526 |
+
# editing_w is the editing vector v used in video face editing
|
527 |
+
def forward(
|
528 |
+
self,
|
529 |
+
styles,
|
530 |
+
return_latents=False,
|
531 |
+
return_features=False,
|
532 |
+
inject_index=None,
|
533 |
+
truncation=1,
|
534 |
+
truncation_latent=None,
|
535 |
+
input_is_latent=False,
|
536 |
+
noise=None,
|
537 |
+
randomize_noise=True,
|
538 |
+
first_layer_feature = None, ##### modified
|
539 |
+
first_layer_feature_ind = 0, ##### modified
|
540 |
+
skip_layer_feature = None, ##### modified
|
541 |
+
fusion_block = None, ##### modified
|
542 |
+
zero_noise = False, ##### modified
|
543 |
+
editing_w = None, ##### modified
|
544 |
+
):
|
545 |
+
if not input_is_latent:
|
546 |
+
styles = [self.style(s) for s in styles]
|
547 |
+
|
548 |
+
if zero_noise:
|
549 |
+
noise = [
|
550 |
+
getattr(self.noises, f'noise_{i}') * 0.0 for i in range(self.num_layers)
|
551 |
+
]
|
552 |
+
elif noise is None:
|
553 |
+
if randomize_noise:
|
554 |
+
noise = [None] * self.num_layers
|
555 |
+
else:
|
556 |
+
noise = [
|
557 |
+
getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
|
558 |
+
]
|
559 |
+
|
560 |
+
if truncation < 1:
|
561 |
+
style_t = []
|
562 |
+
|
563 |
+
for style in styles:
|
564 |
+
style_t.append(
|
565 |
+
truncation_latent + truncation * (style - truncation_latent)
|
566 |
+
)
|
567 |
+
|
568 |
+
styles = style_t
|
569 |
+
|
570 |
+
if len(styles) < 2:
|
571 |
+
inject_index = self.n_latent
|
572 |
+
|
573 |
+
if styles[0].ndim < 3:
|
574 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
575 |
+
else:
|
576 |
+
latent = styles[0]
|
577 |
+
|
578 |
+
else:
|
579 |
+
if inject_index is None:
|
580 |
+
inject_index = random.randint(1, self.n_latent - 1)
|
581 |
+
|
582 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
583 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
584 |
+
|
585 |
+
latent = torch.cat([latent, latent2], 1)
|
586 |
+
|
587 |
+
# w+ + v for video face editing
|
588 |
+
if editing_w is not None: ##### modified
|
589 |
+
latent = latent + editing_w
|
590 |
+
|
591 |
+
# the original StyleGAN
|
592 |
+
if first_layer_feature is None: ##### modified
|
593 |
+
out = self.input(latent)
|
594 |
+
out = F.adaptive_avg_pool2d(out, 32) ##### modified
|
595 |
+
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
596 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
597 |
+
# the default StyleGANEX, replacing the first layer of G
|
598 |
+
elif first_layer_feature_ind == 0: ##### modified
|
599 |
+
out = first_layer_feature[0] ##### modified
|
600 |
+
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
601 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
602 |
+
# maybe we can also use the second layer of G to accept f?
|
603 |
+
else: ##### modified
|
604 |
+
out = first_layer_feature[0] ##### modified
|
605 |
+
skip = first_layer_feature[1] ##### modified
|
606 |
+
|
607 |
+
i = 1
|
608 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
609 |
+
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
610 |
+
):
|
611 |
+
# these layers accept the skipped encoder features; the fusion block fuses the encoder and decoder features
|
612 |
+
if skip_layer_feature and fusion_block and i//2 < len(skip_layer_feature) and i//2 < len(fusion_block):
|
613 |
+
if editing_w is None:
|
614 |
+
out, skip = fusion_block[i//2](skip_layer_feature[i//2], out, skip)
|
615 |
+
else:
|
616 |
+
out, skip = fusion_block[i//2](skip_layer_feature[i//2], out, skip, editing_w[:,i])
|
617 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
618 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
619 |
+
skip = to_rgb(out, latent[:, i + 2], skip)
|
620 |
+
|
621 |
+
i += 2
|
622 |
+
|
623 |
+
image = skip
|
624 |
+
|
625 |
+
if return_latents:
|
626 |
+
return image, latent
|
627 |
+
elif return_features:
|
628 |
+
return image, out
|
629 |
+
else:
|
630 |
+
return image, None
|
631 |
+
|
632 |
+
|
633 |
+
class ConvLayer(nn.Sequential):
|
634 |
+
def __init__(
|
635 |
+
self,
|
636 |
+
in_channel,
|
637 |
+
out_channel,
|
638 |
+
kernel_size,
|
639 |
+
downsample=False,
|
640 |
+
blur_kernel=[1, 3, 3, 1],
|
641 |
+
bias=True,
|
642 |
+
activate=True,
|
643 |
+
dilation=1, ## modified
|
644 |
+
):
|
645 |
+
layers = []
|
646 |
+
|
647 |
+
if downsample:
|
648 |
+
factor = 2
|
649 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
650 |
+
pad0 = (p + 1) // 2
|
651 |
+
pad1 = p // 2
|
652 |
+
|
653 |
+
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
654 |
+
|
655 |
+
stride = 2
|
656 |
+
self.padding = 0
|
657 |
+
|
658 |
+
else:
|
659 |
+
stride = 1
|
660 |
+
self.padding = kernel_size // 2 + dilation-1 ## modified
|
661 |
+
|
662 |
+
layers.append(
|
663 |
+
EqualConv2d(
|
664 |
+
in_channel,
|
665 |
+
out_channel,
|
666 |
+
kernel_size,
|
667 |
+
padding=self.padding,
|
668 |
+
stride=stride,
|
669 |
+
bias=bias and not activate,
|
670 |
+
dilation=dilation, ## modified
|
671 |
+
)
|
672 |
+
)
|
673 |
+
|
674 |
+
if activate:
|
675 |
+
if bias:
|
676 |
+
layers.append(FusedLeakyReLU(out_channel))
|
677 |
+
|
678 |
+
else:
|
679 |
+
layers.append(ScaledLeakyReLU(0.2))
|
680 |
+
|
681 |
+
super().__init__(*layers)
|
682 |
+
|
683 |
+
|
684 |
+
class ResBlock(nn.Module):
|
685 |
+
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
686 |
+
super().__init__()
|
687 |
+
|
688 |
+
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
689 |
+
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
690 |
+
|
691 |
+
self.skip = ConvLayer(
|
692 |
+
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
693 |
+
)
|
694 |
+
|
695 |
+
def forward(self, input):
|
696 |
+
out = self.conv1(input)
|
697 |
+
out = self.conv2(out)
|
698 |
+
|
699 |
+
skip = self.skip(input)
|
700 |
+
out = (out + skip) / math.sqrt(2)
|
701 |
+
|
702 |
+
return out
|
703 |
+
|
704 |
+
|
705 |
+
class Discriminator(nn.Module):
|
706 |
+
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], img_channel=3):
|
707 |
+
super().__init__()
|
708 |
+
|
709 |
+
channels = {
|
710 |
+
4: 512,
|
711 |
+
8: 512,
|
712 |
+
16: 512,
|
713 |
+
32: 512,
|
714 |
+
64: 256 * channel_multiplier,
|
715 |
+
128: 128 * channel_multiplier,
|
716 |
+
256: 64 * channel_multiplier,
|
717 |
+
512: 32 * channel_multiplier,
|
718 |
+
1024: 16 * channel_multiplier,
|
719 |
+
}
|
720 |
+
|
721 |
+
convs = [ConvLayer(img_channel, channels[size], 1)]
|
722 |
+
|
723 |
+
log_size = int(math.log(size, 2))
|
724 |
+
|
725 |
+
in_channel = channels[size]
|
726 |
+
|
727 |
+
for i in range(log_size, 2, -1):
|
728 |
+
out_channel = channels[2 ** (i - 1)]
|
729 |
+
|
730 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
731 |
+
|
732 |
+
in_channel = out_channel
|
733 |
+
|
734 |
+
self.convs = nn.Sequential(*convs)
|
735 |
+
|
736 |
+
self.stddev_group = 4
|
737 |
+
self.stddev_feat = 1
|
738 |
+
|
739 |
+
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
740 |
+
self.final_linear = nn.Sequential(
|
741 |
+
EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
|
742 |
+
EqualLinear(channels[4], 1),
|
743 |
+
)
|
744 |
+
|
745 |
+
self.size = size ##### modified
|
746 |
+
|
747 |
+
def forward(self, input):
|
748 |
+
# for input that does not match the target size, we crop a patch of the target size at a random location.
|
749 |
+
_, _, h, w = input.shape ##### modified
|
750 |
+
i, j = torch.randint(0, h+1-self.size, size=(1,)).item(), torch.randint(0, w+1-self.size, size=(1,)).item() ##### modified
|
751 |
+
out = self.convs(input[:,:,i:i+self.size,j:j+self.size]) ##### modified
|
752 |
+
|
753 |
+
batch, channel, height, width = out.shape
|
754 |
+
group = min(batch, self.stddev_group)
|
755 |
+
stddev = out.view(
|
756 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
757 |
+
)
|
758 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
759 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
760 |
+
stddev = stddev.repeat(group, 1, height, width)
|
761 |
+
out = torch.cat([out, stddev], 1)
|
762 |
+
|
763 |
+
out = self.final_conv(out)
|
764 |
+
|
765 |
+
out = out.view(batch, -1)
|
766 |
+
out = self.final_linear(out)
|
767 |
+
|
768 |
+
return out
|
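The forward() comments above spell out the StyleGANEX-specific inputs: the w+ code, the first-layer feature f that replaces the constant input, the skipped encoder features with their fusion blocks, and the editing vector v. As a minimal sketch only (editor's illustration, assuming the full models/stylegan2/model.py from this upload is importable; random tensors stand in for the encoder outputs), the decoder can be driven like this:

import torch
from models.stylegan2.model import Generator

g = Generator(size=1024, style_dim=512, n_mlp=8).eval()

w_plus = torch.randn(1, g.n_latent, g.style_dim)   # a w+ code; normally produced by the encoder
f = torch.randn(1, g.channels[4], 32, 32)          # first-layer feature f fed to the first layer

with torch.no_grad():
    image, _ = g([w_plus], input_is_latent=True,
                 first_layer_feature=[f],           # first_layer_feature_ind defaults to 0
                 zero_noise=True)                   # deterministic noise, e.g. for video frames

With the dilation schedule above, only the layers whose dilation has decayed to 1 actually double the resolution, so a 32x32 feature f should decode to a 1024x1024 image at size=1024.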
models/stylegan2/op/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
2 |
+
from .upfirdn2d import upfirdn2d
|
models/stylegan2/op/conv2d_gradfix.py
ADDED
@@ -0,0 +1,227 @@
1 |
+
import contextlib
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import autograd
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
enabled = True
|
9 |
+
weight_gradients_disabled = False
|
10 |
+
|
11 |
+
|
12 |
+
@contextlib.contextmanager
|
13 |
+
def no_weight_gradients():
|
14 |
+
global weight_gradients_disabled
|
15 |
+
|
16 |
+
old = weight_gradients_disabled
|
17 |
+
weight_gradients_disabled = True
|
18 |
+
yield
|
19 |
+
weight_gradients_disabled = old
|
20 |
+
|
21 |
+
|
22 |
+
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
23 |
+
if could_use_op(input):
|
24 |
+
return conv2d_gradfix(
|
25 |
+
transpose=False,
|
26 |
+
weight_shape=weight.shape,
|
27 |
+
stride=stride,
|
28 |
+
padding=padding,
|
29 |
+
output_padding=0,
|
30 |
+
dilation=dilation,
|
31 |
+
groups=groups,
|
32 |
+
).apply(input, weight, bias)
|
33 |
+
|
34 |
+
return F.conv2d(
|
35 |
+
input=input,
|
36 |
+
weight=weight,
|
37 |
+
bias=bias,
|
38 |
+
stride=stride,
|
39 |
+
padding=padding,
|
40 |
+
dilation=dilation,
|
41 |
+
groups=groups,
|
42 |
+
)
|
43 |
+
|
44 |
+
|
45 |
+
def conv_transpose2d(
|
46 |
+
input,
|
47 |
+
weight,
|
48 |
+
bias=None,
|
49 |
+
stride=1,
|
50 |
+
padding=0,
|
51 |
+
output_padding=0,
|
52 |
+
groups=1,
|
53 |
+
dilation=1,
|
54 |
+
):
|
55 |
+
if could_use_op(input):
|
56 |
+
return conv2d_gradfix(
|
57 |
+
transpose=True,
|
58 |
+
weight_shape=weight.shape,
|
59 |
+
stride=stride,
|
60 |
+
padding=padding,
|
61 |
+
output_padding=output_padding,
|
62 |
+
groups=groups,
|
63 |
+
dilation=dilation,
|
64 |
+
).apply(input, weight, bias)
|
65 |
+
|
66 |
+
return F.conv_transpose2d(
|
67 |
+
input=input,
|
68 |
+
weight=weight,
|
69 |
+
bias=bias,
|
70 |
+
stride=stride,
|
71 |
+
padding=padding,
|
72 |
+
output_padding=output_padding,
|
73 |
+
dilation=dilation,
|
74 |
+
groups=groups,
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def could_use_op(input):
|
79 |
+
if (not enabled) or (not torch.backends.cudnn.enabled):
|
80 |
+
return False
|
81 |
+
|
82 |
+
if input.device.type != "cuda":
|
83 |
+
return False
|
84 |
+
|
85 |
+
if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
|
86 |
+
return True
|
87 |
+
|
88 |
+
warnings.warn(
|
89 |
+
f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
|
90 |
+
)
|
91 |
+
|
92 |
+
return False
|
93 |
+
|
94 |
+
|
95 |
+
def ensure_tuple(xs, ndim):
|
96 |
+
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
97 |
+
|
98 |
+
return xs
|
99 |
+
|
100 |
+
|
101 |
+
conv2d_gradfix_cache = dict()
|
102 |
+
|
103 |
+
|
104 |
+
def conv2d_gradfix(
|
105 |
+
transpose, weight_shape, stride, padding, output_padding, dilation, groups
|
106 |
+
):
|
107 |
+
ndim = 2
|
108 |
+
weight_shape = tuple(weight_shape)
|
109 |
+
stride = ensure_tuple(stride, ndim)
|
110 |
+
padding = ensure_tuple(padding, ndim)
|
111 |
+
output_padding = ensure_tuple(output_padding, ndim)
|
112 |
+
dilation = ensure_tuple(dilation, ndim)
|
113 |
+
|
114 |
+
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
115 |
+
if key in conv2d_gradfix_cache:
|
116 |
+
return conv2d_gradfix_cache[key]
|
117 |
+
|
118 |
+
common_kwargs = dict(
|
119 |
+
stride=stride, padding=padding, dilation=dilation, groups=groups
|
120 |
+
)
|
121 |
+
|
122 |
+
def calc_output_padding(input_shape, output_shape):
|
123 |
+
if transpose:
|
124 |
+
return [0, 0]
|
125 |
+
|
126 |
+
return [
|
127 |
+
input_shape[i + 2]
|
128 |
+
- (output_shape[i + 2] - 1) * stride[i]
|
129 |
+
- (1 - 2 * padding[i])
|
130 |
+
- dilation[i] * (weight_shape[i + 2] - 1)
|
131 |
+
for i in range(ndim)
|
132 |
+
]
|
133 |
+
|
134 |
+
class Conv2d(autograd.Function):
|
135 |
+
@staticmethod
|
136 |
+
def forward(ctx, input, weight, bias):
|
137 |
+
if not transpose:
|
138 |
+
out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
139 |
+
|
140 |
+
else:
|
141 |
+
out = F.conv_transpose2d(
|
142 |
+
input=input,
|
143 |
+
weight=weight,
|
144 |
+
bias=bias,
|
145 |
+
output_padding=output_padding,
|
146 |
+
**common_kwargs,
|
147 |
+
)
|
148 |
+
|
149 |
+
ctx.save_for_backward(input, weight)
|
150 |
+
|
151 |
+
return out
|
152 |
+
|
153 |
+
@staticmethod
|
154 |
+
def backward(ctx, grad_output):
|
155 |
+
input, weight = ctx.saved_tensors
|
156 |
+
grad_input, grad_weight, grad_bias = None, None, None
|
157 |
+
|
158 |
+
if ctx.needs_input_grad[0]:
|
159 |
+
p = calc_output_padding(
|
160 |
+
input_shape=input.shape, output_shape=grad_output.shape
|
161 |
+
)
|
162 |
+
grad_input = conv2d_gradfix(
|
163 |
+
transpose=(not transpose),
|
164 |
+
weight_shape=weight_shape,
|
165 |
+
output_padding=p,
|
166 |
+
**common_kwargs,
|
167 |
+
).apply(grad_output, weight, None)
|
168 |
+
|
169 |
+
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
170 |
+
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
171 |
+
|
172 |
+
if ctx.needs_input_grad[2]:
|
173 |
+
grad_bias = grad_output.sum((0, 2, 3))
|
174 |
+
|
175 |
+
return grad_input, grad_weight, grad_bias
|
176 |
+
|
177 |
+
class Conv2dGradWeight(autograd.Function):
|
178 |
+
@staticmethod
|
179 |
+
def forward(ctx, grad_output, input):
|
180 |
+
op = torch._C._jit_get_operation(
|
181 |
+
"aten::cudnn_convolution_backward_weight"
|
182 |
+
if not transpose
|
183 |
+
else "aten::cudnn_convolution_transpose_backward_weight"
|
184 |
+
)
|
185 |
+
flags = [
|
186 |
+
torch.backends.cudnn.benchmark,
|
187 |
+
torch.backends.cudnn.deterministic,
|
188 |
+
torch.backends.cudnn.allow_tf32,
|
189 |
+
]
|
190 |
+
grad_weight = op(
|
191 |
+
weight_shape,
|
192 |
+
grad_output,
|
193 |
+
input,
|
194 |
+
padding,
|
195 |
+
stride,
|
196 |
+
dilation,
|
197 |
+
groups,
|
198 |
+
*flags,
|
199 |
+
)
|
200 |
+
ctx.save_for_backward(grad_output, input)
|
201 |
+
|
202 |
+
return grad_weight
|
203 |
+
|
204 |
+
@staticmethod
|
205 |
+
def backward(ctx, grad_grad_weight):
|
206 |
+
grad_output, input = ctx.saved_tensors
|
207 |
+
grad_grad_output, grad_grad_input = None, None
|
208 |
+
|
209 |
+
if ctx.needs_input_grad[0]:
|
210 |
+
grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
|
211 |
+
|
212 |
+
if ctx.needs_input_grad[1]:
|
213 |
+
p = calc_output_padding(
|
214 |
+
input_shape=input.shape, output_shape=grad_output.shape
|
215 |
+
)
|
216 |
+
grad_grad_input = conv2d_gradfix(
|
217 |
+
transpose=(not transpose),
|
218 |
+
weight_shape=weight_shape,
|
219 |
+
output_padding=p,
|
220 |
+
**common_kwargs,
|
221 |
+
).apply(grad_output, grad_grad_weight, None)
|
222 |
+
|
223 |
+
return grad_grad_output, grad_grad_input
|
224 |
+
|
225 |
+
conv2d_gradfix_cache[key] = Conv2d
|
226 |
+
|
227 |
+
return Conv2d
|
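conv2d_gradfix here mirrors the StyleGAN2-ADA helper: the custom autograd path is taken only on CUDA with PyTorch 1.7/1.8, and everything else falls back to plain F.conv2d / F.conv_transpose2d (with the warning above). A small usage sketch (editor's illustration, not part of the upload; on CPU the context manager is effectively a no-op because of the fallback):

import torch
from models.stylegan2.op import conv2d_gradfix

x = torch.randn(2, 3, 64, 64, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)

y = conv2d_gradfix.conv2d(x, w, padding=1)   # drop-in replacement for F.conv2d

# typical regularization pass: differentiate w.r.t. the input while skipping weight gradients
with conv2d_gradfix.no_weight_gradients():
    (grad_x,) = torch.autograd.grad(y.sum(), x, create_graph=True)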
models/stylegan2/op/fused_act.py
ADDED
@@ -0,0 +1,34 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class FusedLeakyReLU(nn.Module):
|
7 |
+
def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
if bias:
|
11 |
+
self.bias = nn.Parameter(torch.zeros(channel))
|
12 |
+
|
13 |
+
else:
|
14 |
+
self.bias = None
|
15 |
+
|
16 |
+
self.negative_slope = negative_slope
|
17 |
+
self.scale = scale
|
18 |
+
|
19 |
+
def forward(self, inputs):
|
20 |
+
return fused_leaky_relu(inputs, self.bias, self.negative_slope, self.scale)
|
21 |
+
|
22 |
+
|
23 |
+
def fused_leaky_relu(inputs, bias=None, negative_slope=0.2, scale=2 ** 0.5):
|
24 |
+
if bias is not None:
|
25 |
+
rest_dim = [1] * (inputs.ndim - bias.ndim - 1)
|
26 |
+
return (
|
27 |
+
F.leaky_relu(
|
28 |
+
inputs + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
|
29 |
+
)
|
30 |
+
* scale
|
31 |
+
)
|
32 |
+
|
33 |
+
else:
|
34 |
+
return F.leaky_relu(inputs, negative_slope=negative_slope) * scale
|
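This FusedLeakyReLU is the pure-PyTorch stand-in for the CUDA op in op_ori: it adds a learnable, zero-initialized per-channel bias, applies a leaky ReLU, and rescales by sqrt(2). A short sketch (editor's illustration):

import torch
from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu

act = FusedLeakyReLU(channel=8)
x = torch.randn(2, 8, 16, 16)

y_module = act(x)
y_functional = fused_leaky_relu(x, act.bias)   # functional form, same result
assert torch.allclose(y_module, y_functional)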
models/stylegan2/op/readme.md
ADDED
@@ -0,0 +1,12 @@
1 |
+
Code from [rosinality-stylegan2-pytorch-cpu](https://github.com/senior-sigan/rosinality-stylegan2-pytorch-cpu)
|
2 |
+
|
3 |
+
Scripts to convert rosinality/stylegan2-pytorch to the CPU-compatible format
|
4 |
+
|
5 |
+
If you would like to use the CPU for testing, or have a problem with the cpp extension (fused and upfirdn2d), please make the following changes:
|
6 |
+
|
7 |
+
Change `model.stylegan.op` to `model.stylegan.op_cpu`
|
8 |
+
https://github.com/williamyang1991/VToonify/blob/01b383efc00007f9b069585db41a7d31a77a8806/util.py#L14
|
9 |
+
|
10 |
+
https://github.com/williamyang1991/VToonify/blob/01b383efc00007f9b069585db41a7d31a77a8806/model/simple_augment.py#L12
|
11 |
+
|
12 |
+
https://github.com/williamyang1991/VToonify/blob/01b383efc00007f9b069585db41a7d31a77a8806/model/stylegan/model.py#L11
|
models/stylegan2/op/upfirdn2d.py
ADDED
@@ -0,0 +1,61 @@
1 |
+
from collections import abc
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def upfirdn2d(inputs, kernel, up=1, down=1, pad=(0, 0)):
|
8 |
+
if not isinstance(up, abc.Iterable):
|
9 |
+
up = (up, up)
|
10 |
+
|
11 |
+
if not isinstance(down, abc.Iterable):
|
12 |
+
down = (down, down)
|
13 |
+
|
14 |
+
if len(pad) == 2:
|
15 |
+
pad = (pad[0], pad[1], pad[0], pad[1])
|
16 |
+
|
17 |
+
return upfirdn2d_native(inputs, kernel, *up, *down, *pad)
|
18 |
+
|
19 |
+
|
20 |
+
def upfirdn2d_native(
|
21 |
+
inputs, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
22 |
+
):
|
23 |
+
_, channel, in_h, in_w = inputs.shape
|
24 |
+
inputs = inputs.reshape(-1, in_h, in_w, 1)
|
25 |
+
|
26 |
+
_, in_h, in_w, minor = inputs.shape
|
27 |
+
kernel_h, kernel_w = kernel.shape
|
28 |
+
|
29 |
+
out = inputs.view(-1, in_h, 1, in_w, 1, minor)
|
30 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
31 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
32 |
+
|
33 |
+
out = F.pad(
|
34 |
+
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
35 |
+
)
|
36 |
+
out = out[
|
37 |
+
:,
|
38 |
+
max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
|
39 |
+
max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
|
40 |
+
:,
|
41 |
+
]
|
42 |
+
|
43 |
+
out = out.permute(0, 3, 1, 2)
|
44 |
+
out = out.reshape(
|
45 |
+
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
46 |
+
)
|
47 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
48 |
+
out = F.conv2d(out, w)
|
49 |
+
out = out.reshape(
|
50 |
+
-1,
|
51 |
+
minor,
|
52 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
53 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
54 |
+
)
|
55 |
+
out = out.permute(0, 2, 3, 1)
|
56 |
+
out = out[:, ::down_y, ::down_x, :]
|
57 |
+
|
58 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
|
59 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
|
60 |
+
|
61 |
+
return out.view(-1, channel, out_h, out_w)
|
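upfirdn2d above is the pad / upsample / FIR-filter / downsample primitive behind the Blur and Upsample modules in model.py. A minimal sketch with the [1, 3, 3, 1] kernel used throughout (editor's illustration; the pad values follow the StyleGAN2 convention for a length-4 kernel):

import torch
from models.stylegan2.op import upfirdn2d

k = torch.tensor([1., 3., 3., 1.])
k = k[None, :] * k[:, None]
k = k / k.sum()                                          # normalized 4x4 binomial blur kernel

x = torch.randn(1, 3, 32, 32)

y_up = upfirdn2d(x, k * 4, up=2, down=1, pad=(2, 1))     # blur + 2x upsample   -> (1, 3, 64, 64)
y_dn = upfirdn2d(x, k, up=1, down=2, pad=(1, 1))         # blur + 2x downsample -> (1, 3, 16, 16)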
models/stylegan2/op_ori/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
2 |
+
from .upfirdn2d import upfirdn2d
|
models/stylegan2/op_ori/fused_act.py
ADDED
@@ -0,0 +1,85 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.autograd import Function
|
6 |
+
from torch.utils.cpp_extension import load
|
7 |
+
|
8 |
+
module_path = os.path.dirname(__file__)
|
9 |
+
fused = load(
|
10 |
+
'fused',
|
11 |
+
sources=[
|
12 |
+
os.path.join(module_path, 'fused_bias_act.cpp'),
|
13 |
+
os.path.join(module_path, 'fused_bias_act_kernel.cu'),
|
14 |
+
],
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
class FusedLeakyReLUFunctionBackward(Function):
|
19 |
+
@staticmethod
|
20 |
+
def forward(ctx, grad_output, out, negative_slope, scale):
|
21 |
+
ctx.save_for_backward(out)
|
22 |
+
ctx.negative_slope = negative_slope
|
23 |
+
ctx.scale = scale
|
24 |
+
|
25 |
+
empty = grad_output.new_empty(0)
|
26 |
+
|
27 |
+
grad_input = fused.fused_bias_act(
|
28 |
+
grad_output, empty, out, 3, 1, negative_slope, scale
|
29 |
+
)
|
30 |
+
|
31 |
+
dim = [0]
|
32 |
+
|
33 |
+
if grad_input.ndim > 2:
|
34 |
+
dim += list(range(2, grad_input.ndim))
|
35 |
+
|
36 |
+
grad_bias = grad_input.sum(dim).detach()
|
37 |
+
|
38 |
+
return grad_input, grad_bias
|
39 |
+
|
40 |
+
@staticmethod
|
41 |
+
def backward(ctx, gradgrad_input, gradgrad_bias):
|
42 |
+
out, = ctx.saved_tensors
|
43 |
+
gradgrad_out = fused.fused_bias_act(
|
44 |
+
gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
|
45 |
+
)
|
46 |
+
|
47 |
+
return gradgrad_out, None, None, None
|
48 |
+
|
49 |
+
|
50 |
+
class FusedLeakyReLUFunction(Function):
|
51 |
+
@staticmethod
|
52 |
+
def forward(ctx, input, bias, negative_slope, scale):
|
53 |
+
empty = input.new_empty(0)
|
54 |
+
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
|
55 |
+
ctx.save_for_backward(out)
|
56 |
+
ctx.negative_slope = negative_slope
|
57 |
+
ctx.scale = scale
|
58 |
+
|
59 |
+
return out
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
def backward(ctx, grad_output):
|
63 |
+
out, = ctx.saved_tensors
|
64 |
+
|
65 |
+
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
|
66 |
+
grad_output, out, ctx.negative_slope, ctx.scale
|
67 |
+
)
|
68 |
+
|
69 |
+
return grad_input, grad_bias, None, None
|
70 |
+
|
71 |
+
|
72 |
+
class FusedLeakyReLU(nn.Module):
|
73 |
+
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
self.bias = nn.Parameter(torch.zeros(channel))
|
77 |
+
self.negative_slope = negative_slope
|
78 |
+
self.scale = scale
|
79 |
+
|
80 |
+
def forward(self, input):
|
81 |
+
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
82 |
+
|
83 |
+
|
84 |
+
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
85 |
+
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
models/stylegan2/op_ori/fused_bias_act.cpp
ADDED
@@ -0,0 +1,21 @@
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
|
4 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
5 |
+
int act, int grad, float alpha, float scale);
|
6 |
+
|
7 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
8 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
9 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
10 |
+
|
11 |
+
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
12 |
+
int act, int grad, float alpha, float scale) {
|
13 |
+
CHECK_CUDA(input);
|
14 |
+
CHECK_CUDA(bias);
|
15 |
+
|
16 |
+
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
|
17 |
+
}
|
18 |
+
|
19 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
20 |
+
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
|
21 |
+
}
|
models/stylegan2/op_ori/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,99 @@
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAContext.h>
|
12 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
13 |
+
|
14 |
+
#include <cuda.h>
|
15 |
+
#include <cuda_runtime.h>
|
16 |
+
|
17 |
+
|
18 |
+
template <typename scalar_t>
|
19 |
+
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
|
20 |
+
int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
|
21 |
+
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
|
22 |
+
|
23 |
+
scalar_t zero = 0.0;
|
24 |
+
|
25 |
+
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
|
26 |
+
scalar_t x = p_x[xi];
|
27 |
+
|
28 |
+
if (use_bias) {
|
29 |
+
x += p_b[(xi / step_b) % size_b];
|
30 |
+
}
|
31 |
+
|
32 |
+
scalar_t ref = use_ref ? p_ref[xi] : zero;
|
33 |
+
|
34 |
+
scalar_t y;
|
35 |
+
|
36 |
+
switch (act * 10 + grad) {
|
37 |
+
default:
|
38 |
+
case 10: y = x; break;
|
39 |
+
case 11: y = x; break;
|
40 |
+
case 12: y = 0.0; break;
|
41 |
+
|
42 |
+
case 30: y = (x > 0.0) ? x : x * alpha; break;
|
43 |
+
case 31: y = (ref > 0.0) ? x : x * alpha; break;
|
44 |
+
case 32: y = 0.0; break;
|
45 |
+
}
|
46 |
+
|
47 |
+
out[xi] = y * scale;
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
+
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
53 |
+
int act, int grad, float alpha, float scale) {
|
54 |
+
int curDevice = -1;
|
55 |
+
cudaGetDevice(&curDevice);
|
56 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
57 |
+
|
58 |
+
auto x = input.contiguous();
|
59 |
+
auto b = bias.contiguous();
|
60 |
+
auto ref = refer.contiguous();
|
61 |
+
|
62 |
+
int use_bias = b.numel() ? 1 : 0;
|
63 |
+
int use_ref = ref.numel() ? 1 : 0;
|
64 |
+
|
65 |
+
int size_x = x.numel();
|
66 |
+
int size_b = b.numel();
|
67 |
+
int step_b = 1;
|
68 |
+
|
69 |
+
for (int i = 1 + 1; i < x.dim(); i++) {
|
70 |
+
step_b *= x.size(i);
|
71 |
+
}
|
72 |
+
|
73 |
+
int loop_x = 4;
|
74 |
+
int block_size = 4 * 32;
|
75 |
+
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
|
76 |
+
|
77 |
+
auto y = torch::empty_like(x);
|
78 |
+
|
79 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
|
80 |
+
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
81 |
+
y.data_ptr<scalar_t>(),
|
82 |
+
x.data_ptr<scalar_t>(),
|
83 |
+
b.data_ptr<scalar_t>(),
|
84 |
+
ref.data_ptr<scalar_t>(),
|
85 |
+
act,
|
86 |
+
grad,
|
87 |
+
alpha,
|
88 |
+
scale,
|
89 |
+
loop_x,
|
90 |
+
size_x,
|
91 |
+
step_b,
|
92 |
+
size_b,
|
93 |
+
use_bias,
|
94 |
+
use_ref
|
95 |
+
);
|
96 |
+
});
|
97 |
+
|
98 |
+
return y;
|
99 |
+
}
|
models/stylegan2/op_ori/upfirdn2d.cpp
ADDED
@@ -0,0 +1,23 @@
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
|
4 |
+
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
5 |
+
int up_x, int up_y, int down_x, int down_y,
|
6 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
|
7 |
+
|
8 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
9 |
+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
10 |
+
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
11 |
+
|
12 |
+
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
|
13 |
+
int up_x, int up_y, int down_x, int down_y,
|
14 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
|
15 |
+
CHECK_CUDA(input);
|
16 |
+
CHECK_CUDA(kernel);
|
17 |
+
|
18 |
+
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
|
19 |
+
}
|
20 |
+
|
21 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
22 |
+
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
|
23 |
+
}
|
models/stylegan2/op_ori/upfirdn2d.py
ADDED
@@ -0,0 +1,184 @@
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.autograd import Function
|
5 |
+
from torch.utils.cpp_extension import load
|
6 |
+
|
7 |
+
module_path = os.path.dirname(__file__)
|
8 |
+
upfirdn2d_op = load(
|
9 |
+
'upfirdn2d',
|
10 |
+
sources=[
|
11 |
+
os.path.join(module_path, 'upfirdn2d.cpp'),
|
12 |
+
os.path.join(module_path, 'upfirdn2d_kernel.cu'),
|
13 |
+
],
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
class UpFirDn2dBackward(Function):
|
18 |
+
@staticmethod
|
19 |
+
def forward(
|
20 |
+
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
|
21 |
+
):
|
22 |
+
up_x, up_y = up
|
23 |
+
down_x, down_y = down
|
24 |
+
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
|
25 |
+
|
26 |
+
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
|
27 |
+
|
28 |
+
grad_input = upfirdn2d_op.upfirdn2d(
|
29 |
+
grad_output,
|
30 |
+
grad_kernel,
|
31 |
+
down_x,
|
32 |
+
down_y,
|
33 |
+
up_x,
|
34 |
+
up_y,
|
35 |
+
g_pad_x0,
|
36 |
+
g_pad_x1,
|
37 |
+
g_pad_y0,
|
38 |
+
g_pad_y1,
|
39 |
+
)
|
40 |
+
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
|
41 |
+
|
42 |
+
ctx.save_for_backward(kernel)
|
43 |
+
|
44 |
+
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
45 |
+
|
46 |
+
ctx.up_x = up_x
|
47 |
+
ctx.up_y = up_y
|
48 |
+
ctx.down_x = down_x
|
49 |
+
ctx.down_y = down_y
|
50 |
+
ctx.pad_x0 = pad_x0
|
51 |
+
ctx.pad_x1 = pad_x1
|
52 |
+
ctx.pad_y0 = pad_y0
|
53 |
+
ctx.pad_y1 = pad_y1
|
54 |
+
ctx.in_size = in_size
|
55 |
+
ctx.out_size = out_size
|
56 |
+
|
57 |
+
return grad_input
|
58 |
+
|
59 |
+
@staticmethod
|
60 |
+
def backward(ctx, gradgrad_input):
|
61 |
+
kernel, = ctx.saved_tensors
|
62 |
+
|
63 |
+
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
|
64 |
+
|
65 |
+
gradgrad_out = upfirdn2d_op.upfirdn2d(
|
66 |
+
gradgrad_input,
|
67 |
+
kernel,
|
68 |
+
ctx.up_x,
|
69 |
+
ctx.up_y,
|
70 |
+
ctx.down_x,
|
71 |
+
ctx.down_y,
|
72 |
+
ctx.pad_x0,
|
73 |
+
ctx.pad_x1,
|
74 |
+
ctx.pad_y0,
|
75 |
+
ctx.pad_y1,
|
76 |
+
)
|
77 |
+
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
|
78 |
+
gradgrad_out = gradgrad_out.view(
|
79 |
+
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
|
80 |
+
)
|
81 |
+
|
82 |
+
return gradgrad_out, None, None, None, None, None, None, None, None
|
83 |
+
|
84 |
+
|
85 |
+
class UpFirDn2d(Function):
|
86 |
+
@staticmethod
|
87 |
+
def forward(ctx, input, kernel, up, down, pad):
|
88 |
+
up_x, up_y = up
|
89 |
+
down_x, down_y = down
|
90 |
+
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
91 |
+
|
92 |
+
kernel_h, kernel_w = kernel.shape
|
93 |
+
batch, channel, in_h, in_w = input.shape
|
94 |
+
ctx.in_size = input.shape
|
95 |
+
|
96 |
+
input = input.reshape(-1, in_h, in_w, 1)
|
97 |
+
|
98 |
+
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
|
99 |
+
|
100 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
101 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
102 |
+
ctx.out_size = (out_h, out_w)
|
103 |
+
|
104 |
+
ctx.up = (up_x, up_y)
|
105 |
+
ctx.down = (down_x, down_y)
|
106 |
+
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
|
107 |
+
|
108 |
+
g_pad_x0 = kernel_w - pad_x0 - 1
|
109 |
+
g_pad_y0 = kernel_h - pad_y0 - 1
|
110 |
+
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
|
111 |
+
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
|
112 |
+
|
113 |
+
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
|
114 |
+
|
115 |
+
out = upfirdn2d_op.upfirdn2d(
|
116 |
+
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
117 |
+
)
|
118 |
+
# out = out.view(major, out_h, out_w, minor)
|
119 |
+
out = out.view(-1, channel, out_h, out_w)
|
120 |
+
|
121 |
+
return out
|
122 |
+
|
123 |
+
@staticmethod
|
124 |
+
def backward(ctx, grad_output):
|
125 |
+
kernel, grad_kernel = ctx.saved_tensors
|
126 |
+
|
127 |
+
grad_input = UpFirDn2dBackward.apply(
|
128 |
+
grad_output,
|
129 |
+
kernel,
|
130 |
+
grad_kernel,
|
131 |
+
ctx.up,
|
132 |
+
ctx.down,
|
133 |
+
ctx.pad,
|
134 |
+
ctx.g_pad,
|
135 |
+
ctx.in_size,
|
136 |
+
ctx.out_size,
|
137 |
+
)
|
138 |
+
|
139 |
+
return grad_input, None, None, None, None
|
140 |
+
|
141 |
+
|
142 |
+
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
143 |
+
out = UpFirDn2d.apply(
|
144 |
+
input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
|
145 |
+
)
|
146 |
+
|
147 |
+
return out
|
148 |
+
|
149 |
+
|
150 |
+
def upfirdn2d_native(
|
151 |
+
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
152 |
+
):
|
153 |
+
_, in_h, in_w, minor = input.shape
|
154 |
+
kernel_h, kernel_w = kernel.shape
|
155 |
+
|
156 |
+
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
157 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
158 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
159 |
+
|
160 |
+
out = F.pad(
|
161 |
+
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
162 |
+
)
|
163 |
+
out = out[
|
164 |
+
:,
|
165 |
+
max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
|
166 |
+
max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
|
167 |
+
:,
|
168 |
+
]
|
169 |
+
|
170 |
+
out = out.permute(0, 3, 1, 2)
|
171 |
+
out = out.reshape(
|
172 |
+
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
173 |
+
)
|
174 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
175 |
+
out = F.conv2d(out, w)
|
176 |
+
out = out.reshape(
|
177 |
+
-1,
|
178 |
+
minor,
|
179 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
180 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
181 |
+
)
|
182 |
+
out = out.permute(0, 2, 3, 1)
|
183 |
+
|
184 |
+
return out[:, ::down_y, ::down_x, :]
|
models/stylegan2/op_ori/upfirdn2d_kernel.cu
ADDED
@@ -0,0 +1,272 @@
1 |
+
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
2 |
+
//
|
3 |
+
// This work is made available under the Nvidia Source Code License-NC.
|
4 |
+
// To view a copy of this license, visit
|
5 |
+
// https://nvlabs.github.io/stylegan2/license.html
|
6 |
+
|
7 |
+
#include <torch/types.h>
|
8 |
+
|
9 |
+
#include <ATen/ATen.h>
|
10 |
+
#include <ATen/AccumulateType.h>
|
11 |
+
#include <ATen/cuda/CUDAContext.h>
|
12 |
+
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
13 |
+
|
14 |
+
#include <cuda.h>
|
15 |
+
#include <cuda_runtime.h>
|
16 |
+
|
17 |
+
|
18 |
+
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
|
19 |
+
int c = a / b;
|
20 |
+
|
21 |
+
if (c * b > a) {
|
22 |
+
c--;
|
23 |
+
}
|
24 |
+
|
25 |
+
return c;
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
struct UpFirDn2DKernelParams {
|
30 |
+
int up_x;
|
31 |
+
int up_y;
|
32 |
+
int down_x;
|
33 |
+
int down_y;
|
34 |
+
int pad_x0;
|
35 |
+
int pad_x1;
|
36 |
+
int pad_y0;
|
37 |
+
int pad_y1;
|
38 |
+
|
39 |
+
int major_dim;
|
40 |
+
int in_h;
|
41 |
+
int in_w;
|
42 |
+
int minor_dim;
|
43 |
+
int kernel_h;
|
44 |
+
int kernel_w;
|
45 |
+
int out_h;
|
46 |
+
int out_w;
|
47 |
+
int loop_major;
|
48 |
+
int loop_x;
|
49 |
+
};
|
50 |
+
|
51 |
+
|
52 |
+
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
|
53 |
+
__global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
|
54 |
+
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
|
55 |
+
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
|
56 |
+
|
57 |
+
__shared__ volatile float sk[kernel_h][kernel_w];
|
58 |
+
__shared__ volatile float sx[tile_in_h][tile_in_w];
|
59 |
+
|
60 |
+
int minor_idx = blockIdx.x;
|
61 |
+
int tile_out_y = minor_idx / p.minor_dim;
|
62 |
+
minor_idx -= tile_out_y * p.minor_dim;
|
63 |
+
tile_out_y *= tile_out_h;
|
64 |
+
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
|
65 |
+
int major_idx_base = blockIdx.z * p.loop_major;
|
66 |
+
|
67 |
+
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
|
68 |
+
return;
|
69 |
+
}
|
70 |
+
|
71 |
+
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
|
72 |
+
int ky = tap_idx / kernel_w;
|
73 |
+
int kx = tap_idx - ky * kernel_w;
|
74 |
+
scalar_t v = 0.0;
|
75 |
+
|
76 |
+
if (kx < p.kernel_w & ky < p.kernel_h) {
|
77 |
+
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
|
78 |
+
}
|
79 |
+
|
80 |
+
sk[ky][kx] = v;
|
81 |
+
}
|
82 |
+
|
83 |
+
for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
|
84 |
+
for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
|
85 |
+
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
|
86 |
+
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
|
87 |
+
int tile_in_x = floor_div(tile_mid_x, up_x);
|
88 |
+
int tile_in_y = floor_div(tile_mid_y, up_y);
|
89 |
+
|
90 |
+
__syncthreads();
|
91 |
+
|
92 |
+
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
|
93 |
+
int rel_in_y = in_idx / tile_in_w;
|
94 |
+
int rel_in_x = in_idx - rel_in_y * tile_in_w;
|
95 |
+
int in_x = rel_in_x + tile_in_x;
|
96 |
+
int in_y = rel_in_y + tile_in_y;
|
97 |
+
|
98 |
+
scalar_t v = 0.0;
|
99 |
+
|
100 |
+
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
|
101 |
+
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
|
102 |
+
}
|
103 |
+
|
104 |
+
sx[rel_in_y][rel_in_x] = v;
|
105 |
+
}
|
106 |
+
|
107 |
+
__syncthreads();
|
108 |
+
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
|
109 |
+
int rel_out_y = out_idx / tile_out_w;
|
110 |
+
int rel_out_x = out_idx - rel_out_y * tile_out_w;
|
111 |
+
int out_x = rel_out_x + tile_out_x;
|
112 |
+
int out_y = rel_out_y + tile_out_y;
|
113 |
+
|
114 |
+
int mid_x = tile_mid_x + rel_out_x * down_x;
|
115 |
+
int mid_y = tile_mid_y + rel_out_y * down_y;
|
116 |
+
int in_x = floor_div(mid_x, up_x);
|
117 |
+
int in_y = floor_div(mid_y, up_y);
|
118 |
+
int rel_in_x = in_x - tile_in_x;
|
119 |
+
int rel_in_y = in_y - tile_in_y;
|
120 |
+
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
|
121 |
+
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
|
122 |
+
|
123 |
+
scalar_t v = 0.0;
|
124 |
+
|
125 |
+
#pragma unroll
|
126 |
+
for (int y = 0; y < kernel_h / up_y; y++)
|
127 |
+
#pragma unroll
|
128 |
+
for (int x = 0; x < kernel_w / up_x; x++)
|
129 |
+
v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
|
130 |
+
|
131 |
+
if (out_x < p.out_w & out_y < p.out_h) {
|
132 |
+
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
|
133 |
+
}
|
134 |
+
}
|
135 |
+
}
|
136 |
+
}
|
137 |
+
}
|
138 |
+
|
139 |
+
|
140 |
+
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
141 |
+
int up_x, int up_y, int down_x, int down_y,
|
142 |
+
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
|
143 |
+
int curDevice = -1;
|
144 |
+
cudaGetDevice(&curDevice);
|
145 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
146 |
+
|
147 |
+
UpFirDn2DKernelParams p;
|
148 |
+
|
149 |
+
auto x = input.contiguous();
|
150 |
+
auto k = kernel.contiguous();
|
151 |
+
|
152 |
+
p.major_dim = x.size(0);
|
153 |
+
p.in_h = x.size(1);
|
154 |
+
p.in_w = x.size(2);
|
155 |
+
p.minor_dim = x.size(3);
|
156 |
+
p.kernel_h = k.size(0);
|
157 |
+
p.kernel_w = k.size(1);
|
158 |
+
p.up_x = up_x;
|
159 |
+
p.up_y = up_y;
|
160 |
+
p.down_x = down_x;
|
161 |
+
p.down_y = down_y;
|
162 |
+
p.pad_x0 = pad_x0;
|
163 |
+
p.pad_x1 = pad_x1;
|
164 |
+
p.pad_y0 = pad_y0;
|
165 |
+
p.pad_y1 = pad_y1;
|
166 |
+
|
167 |
+
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
|
168 |
+
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
|
169 |
+
|
170 |
+
auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
|
171 |
+
|
172 |
+
int mode = -1;
|
173 |
+
|
174 |
+
int tile_out_h;
|
175 |
+
int tile_out_w;
|
176 |
+
|
177 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
|
178 |
+
mode = 1;
|
179 |
+
tile_out_h = 16;
|
180 |
+
tile_out_w = 64;
|
181 |
+
}
|
182 |
+
|
183 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
|
184 |
+
mode = 2;
|
185 |
+
tile_out_h = 16;
|
186 |
+
tile_out_w = 64;
|
187 |
+
}
|
188 |
+
|
189 |
+
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
|
190 |
+
mode = 3;
|
191 |
+
tile_out_h = 16;
|
192 |
+
tile_out_w = 64;
|
193 |
+
}
|
194 |
+
|
195 |
+
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
|
196 |
+
mode = 4;
|
197 |
+
tile_out_h = 16;
|
198 |
+
tile_out_w = 64;
|
199 |
+
}
|
200 |
+
|
201 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
|
202 |
+
mode = 5;
|
203 |
+
tile_out_h = 8;
|
204 |
+
tile_out_w = 32;
|
205 |
+
}
|
206 |
+
|
207 |
+
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
|
208 |
+
mode = 6;
|
209 |
+
tile_out_h = 8;
|
210 |
+
tile_out_w = 32;
|
211 |
+
}
|
212 |
+
|
213 |
+
dim3 block_size;
|
214 |
+
dim3 grid_size;
|
215 |
+
|
216 |
+
if (tile_out_h > 0 && tile_out_w) {
|
217 |
+
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
218 |
+
p.loop_x = 1;
|
219 |
+
block_size = dim3(32 * 8, 1, 1);
|
220 |
+
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
|
221 |
+
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
|
222 |
+
(p.major_dim - 1) / p.loop_major + 1);
|
223 |
+
}
|
224 |
+
|
225 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
|
226 |
+
switch (mode) {
|
227 |
+
case 1:
|
228 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
229 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
230 |
+
);
|
231 |
+
|
232 |
+
break;
|
233 |
+
|
234 |
+
case 2:
|
235 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
236 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
237 |
+
);
|
238 |
+
|
239 |
+
break;
|
240 |
+
|
241 |
+
case 3:
|
242 |
+
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
243 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
244 |
+
);
|
245 |
+
|
246 |
+
break;
|
247 |
+
|
248 |
+
case 4:
|
249 |
+
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
|
250 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
251 |
+
);
|
252 |
+
|
253 |
+
break;
|
254 |
+
|
255 |
+
case 5:
|
256 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
|
257 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
258 |
+
);
|
259 |
+
|
260 |
+
break;
|
261 |
+
|
262 |
+
case 6:
|
263 |
+
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
|
264 |
+
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
|
265 |
+
);
|
266 |
+
|
267 |
+
break;
|
268 |
+
}
|
269 |
+
});
|
270 |
+
|
271 |
+
return out;
|
272 |
+
}
|
models/stylegan2/simple_augment.py
ADDED
@@ -0,0 +1,478 @@
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import autograd
|
5 |
+
from torch.nn import functional as F
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from torch import distributed as dist
|
9 |
+
#from distributed import reduce_sum
|
10 |
+
from models.stylegan2.op2 import upfirdn2d
|
11 |
+
|
12 |
+
def reduce_sum(tensor):
|
13 |
+
if not dist.is_available():
|
14 |
+
return tensor
|
15 |
+
|
16 |
+
if not dist.is_initialized():
|
17 |
+
return tensor
|
18 |
+
|
19 |
+
tensor = tensor.clone()
|
20 |
+
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
|
21 |
+
|
22 |
+
return tensor
|
23 |
+
|
24 |
+
|
# Adaptive discriminator augmentation (ADA) controller: accumulates the sign of
# the discriminator's predictions on real images and nudges the augmentation
# probability ada_aug_p so that the sign statistic r_t tracks ada_aug_target.
class AdaptiveAugment:
    def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
        self.ada_aug_target = ada_aug_target
        self.ada_aug_len = ada_aug_len
        self.update_every = update_every

        self.ada_update = 0
        self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
        self.r_t_stat = 0
        self.ada_aug_p = 0

    @torch.no_grad()
    def tune(self, real_pred):
        self.ada_aug_buf += torch.tensor(
            (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
            device=real_pred.device,
        )
        self.ada_update += 1

        if self.ada_update % self.update_every == 0:
            self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
            pred_signs, n_pred = self.ada_aug_buf.tolist()

            self.r_t_stat = pred_signs / n_pred

            if self.r_t_stat > self.ada_aug_target:
                sign = 1

            else:
                sign = -1

            self.ada_aug_p += sign * n_pred / self.ada_aug_len
            self.ada_aug_p = min(1, max(0, self.ada_aug_p))
            self.ada_aug_buf.mul_(0)
            self.ada_update = 0

        return self.ada_aug_p

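# Editor's sketch (not part of the uploaded file): how AdaptiveAugment is
# typically driven from a training loop -- after each discriminator step the
# real-image logits are fed to tune(), and the returned probability is used by
# augment() further below.  The hyperparameter values and the random stand-in
# for D(real) are illustrative assumptions, not values from this repository.
def _sketch_adaptive_augment(device="cpu"):
    ada = AdaptiveAugment(
        ada_aug_target=0.6, ada_aug_len=500 * 1000, update_every=8, device=device
    )
    ada_aug_p = 0.0
    for _ in range(32):
        real_pred = torch.randn(16, device=device)  # stand-in for D(real_img_aug)
        ada_aug_p = ada.tune(real_pred)
    return ada_aug_p
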
# Symlet-6 wavelet filter taps, used by random_apply_affine as the separable
# antialiasing kernel for its 2x up/down-sampling.
SYM6 = (
    0.015404109327027373,
    0.0034907120842174702,
    -0.11799011114819057,
    -0.048311742585633,
    0.4910559419267466,
    0.787641141030194,
    0.3379294217276218,
    -0.07263752278646252,
    -0.021060292512300564,
    0.04472490177066578,
    0.0017677118642428036,
    -0.007800708325034148,
)

def translate_mat(t_x, t_y, device="cpu"):
    batch = t_x.shape[0]

    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
    translate = torch.stack((t_x, t_y), 1)
    mat[:, :2, 2] = translate

    return mat


def rotate_mat(theta, device="cpu"):
    batch = theta.shape[0]

    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
    sin_t = torch.sin(theta)
    cos_t = torch.cos(theta)
    rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
    mat[:, :2, :2] = rot

    return mat


def scale_mat(s_x, s_y, device="cpu"):
    batch = s_x.shape[0]

    mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
    mat[:, 0, 0] = s_x
    mat[:, 1, 1] = s_y

    return mat

# Batched 4x4 homogeneous transforms; sample_color composes them into
# per-sample color matrices (brightness, contrast, luma flip, hue, saturation).
def translate3d_mat(t_x, t_y, t_z):
    batch = t_x.shape[0]

    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    translate = torch.stack((t_x, t_y, t_z), 1)
    mat[:, :3, 3] = translate

    return mat


def rotate3d_mat(axis, theta):
    batch = theta.shape[0]

    u_x, u_y, u_z = axis

    eye = torch.eye(3).unsqueeze(0)
    cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
    outer = torch.tensor(axis)
    outer = (outer.unsqueeze(1) * outer).unsqueeze(0)

    sin_t = torch.sin(theta).view(-1, 1, 1)
    cos_t = torch.cos(theta).view(-1, 1, 1)

    rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer

    eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    eye_4[:, :3, :3] = rot

    return eye_4


def scale3d_mat(s_x, s_y, s_z):
    batch = s_x.shape[0]

    mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    mat[:, 0, 0] = s_x
    mat[:, 1, 1] = s_y
    mat[:, 2, 2] = s_z

    return mat


def luma_flip_mat(axis, i):
    batch = i.shape[0]

    eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    axis = torch.tensor(axis + (0,))
    flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)

    return eye - flip


def saturation_mat(axis, i):
    batch = i.shape[0]

    eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
    axis = torch.tensor(axis + (0,))
    axis = torch.ger(axis, axis)
    saturate = axis + (eye - axis) * i.view(-1, 1, 1)

    return saturate

def lognormal_sample(size, mean=0, std=1, device="cpu"):
    return torch.empty(size, device=device).log_normal_(mean=mean, std=std)


def category_sample(size, categories, device="cpu"):
    category = torch.tensor(categories, device=device)
    sample = torch.randint(high=len(categories), size=(size,), device=device)

    return category[sample]


def uniform_sample(size, low, high, device="cpu"):
    return torch.empty(size, device=device).uniform_(low, high)


def normal_sample(size, mean=0, std=1, device="cpu"):
    return torch.empty(size, device=device).normal_(mean, std)


def bernoulli_sample(size, p, device="cpu"):
    return torch.empty(size, device=device).bernoulli_(p)


# Per-sample gate: with probability p compose `transform` onto `prev`,
# otherwise compose the identity (i.e. keep `prev` unchanged).
def random_mat_apply(p, transform, prev, eye, device="cpu"):
    size = transform.shape[0]
    select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
    select_transform = select * transform + (1 - select) * eye

    return select_transform @ prev

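# Editor's sketch (not part of the uploaded file): illustrates the Bernoulli
# gating above -- with p=0.5 roughly half of the batch receives the candidate
# translation while the rest keep the identity.  All names used here are
# defined earlier in this file.
def _sketch_random_mat_apply(p=0.5, size=8, device="cpu"):
    eye = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
    candidate = translate_mat(
        torch.full((size,), 0.25, device=device),
        torch.zeros(size, device=device),
        device=device,
    )
    out = random_mat_apply(p, candidate, eye, eye, device=device)
    # Fraction of samples that actually got translated (approximately p).
    return (out[:, 0, 2] != 0).float().mean()
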
def sample_affine(p, size, height, width, device="cpu"):
    G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
    eye = G

    # flip
    #param = category_sample(size, (0, 1))
    #Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
    #G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')

    # 90 rotate
    #param = category_sample(size, (0, 3))
    #Gc = rotate_mat(-math.pi / 2 * param, device=device)
    #G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')

    # integer translate
    param = uniform_sample(size, -0.125, 0.125)
    param_height = torch.round(param * height) / height
    param_width = torch.round(param * width) / width
    Gc = translate_mat(param_width, param_height, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')

    # isotropic scale
    param = lognormal_sample(size, std=0.1 * math.log(2))
    Gc = scale_mat(param, param, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('isotropic scale', G, scale_mat(param, param), sep='\n')

    p_rot = 1 - math.sqrt(1 - p)

    # pre-rotate
    param = uniform_sample(size, -math.pi * 0.25, math.pi * 0.25)
    Gc = rotate_mat(-param, device=device)
    G = random_mat_apply(p_rot, Gc, G, eye, device=device)
    # print('pre-rotate', G, rotate_mat(-param), sep='\n')

    # anisotropic scale
    param = lognormal_sample(size, std=0.1 * math.log(2))
    Gc = scale_mat(param, 1 / param, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')

    # post-rotate
    param = uniform_sample(size, -math.pi * 0.25, math.pi * 0.25)
    Gc = rotate_mat(-param, device=device)
    G = random_mat_apply(p_rot, Gc, G, eye, device=device)
    # print('post-rotate', G, rotate_mat(-param), sep='\n')

    # fractional translate
    param = normal_sample(size, std=0.125)
    Gc = translate_mat(param, param, device=device)
    G = random_mat_apply(p, Gc, G, eye, device=device)
    # print('fractional translate', G, translate_mat(param, param), sep='\n')

    return G

def sample_color(p, size):
    C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
    eye = C
    axis_val = 1 / math.sqrt(3)
    axis = (axis_val, axis_val, axis_val)

    # brightness
    param = normal_sample(size, std=0.2)
    Cc = translate3d_mat(param, param, param)
    C = random_mat_apply(p, Cc, C, eye)

    # contrast
    param = lognormal_sample(size, std=0.5 * math.log(2))
    Cc = scale3d_mat(param, param, param)
    C = random_mat_apply(p, Cc, C, eye)

    # luma flip
    param = category_sample(size, (0, 1))
    Cc = luma_flip_mat(axis, param)
    C = random_mat_apply(p, Cc, C, eye)

    # hue rotation
    param = uniform_sample(size, -math.pi, math.pi)
    Cc = rotate3d_mat(axis, param)
    C = random_mat_apply(p, Cc, C, eye)

    # saturation
    param = lognormal_sample(size, std=1 * math.log(2))
    Cc = saturation_mat(axis, param)
    C = random_mat_apply(p, Cc, C, eye)

    return C

def make_grid(shape, x0, x1, y0, y1, device):
    n, c, h, w = shape
    grid = torch.empty(n, h, w, 3, device=device)
    grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
    grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
    grid[:, :, :, 2] = 1

    return grid


def affine_grid(grid, mat):
    n, h, w, _ = grid.shape
    return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)


# Padding needed so that the transformed image corners (G applied to the four
# corners) stay inside the canvas, plus a margin for the resampling filter.
def get_padding(G, height, width, kernel_size):
    device = G.device

    cx = (width - 1) / 2
    cy = (height - 1) / 2
    cp = torch.tensor(
        [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
    )
    cp = G @ cp.T

    pad_k = kernel_size // 4

    pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
    pad = torch.cat((-pad, pad)).max(1).values
    pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
    pad = pad.max(torch.tensor([0, 0] * 2, device=device))
    pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))

    pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)

    return pad_x1, pad_x2, pad_y1, pad_y2


# Sample an inverse affine transform (unless G is given) and reflect-pad the
# image so the subsequent warp does not read outside the canvas.
def try_sample_affine_and_pad(img, p, kernel_size, G=None):
    batch, _, height, width = img.shape

    G_try = G

    if G is None:
        G_try = torch.inverse(sample_affine(p, batch, height, width))

    pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)

    img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")

    return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)

|
352 |
+
class GridSampleForward(autograd.Function):
|
353 |
+
@staticmethod
|
354 |
+
def forward(ctx, input, grid):
|
355 |
+
out = F.grid_sample(
|
356 |
+
input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
|
357 |
+
)
|
358 |
+
ctx.save_for_backward(input, grid)
|
359 |
+
|
360 |
+
return out
|
361 |
+
|
362 |
+
@staticmethod
|
363 |
+
def backward(ctx, grad_output):
|
364 |
+
input, grid = ctx.saved_tensors
|
365 |
+
grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
|
366 |
+
|
367 |
+
return grad_input, grad_grid
|
368 |
+
|
369 |
+
|
370 |
+
class GridSampleBackward(autograd.Function):
|
371 |
+
@staticmethod
|
372 |
+
def forward(ctx, grad_output, input, grid):
|
373 |
+
op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
|
374 |
+
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
375 |
+
ctx.save_for_backward(grid)
|
376 |
+
|
377 |
+
return grad_input, grad_grid
|
378 |
+
|
379 |
+
@staticmethod
|
380 |
+
def backward(ctx, grad_grad_input, grad_grad_grid):
|
381 |
+
grid, = ctx.saved_tensors
|
382 |
+
grad_grad_output = None
|
383 |
+
|
384 |
+
if ctx.needs_input_grad[0]:
|
385 |
+
grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
|
386 |
+
|
387 |
+
return grad_grad_output, None, None
|
388 |
+
|
389 |
+
|
390 |
+
grid_sample = GridSampleForward.apply
|
391 |
+
|
392 |
+
|
def scale_mat_single(s_x, s_y):
    return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)


def translate_mat_single(t_x, t_y):
    return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)


def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
    kernel = antialiasing_kernel
    len_k = len(kernel)

    kernel = torch.as_tensor(kernel).to(img)
    # kernel = torch.ger(kernel, kernel).to(img)
    kernel_flip = torch.flip(kernel, (0,))

    img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
        img, p, len_k, G
    )

    G_inv = (
        translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
        @ G
    )
    up_pad = (
        (len_k + 2 - 1) // 2,
        (len_k - 2) // 2,
        (len_k + 2 - 1) // 2,
        (len_k - 2) // 2,
    )
    # Upsample 2x with the separable antialiasing filter before warping.
    img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
    img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
    G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
    G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
    batch_size, channel, height, width = img.shape
    pad_k = len_k // 4
    shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
    G_inv = (
        scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
        @ G_inv
        @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
    )
    # Warp at the doubled resolution; G_inv maps output coordinates back to
    # input coordinates for grid_sample.
    grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
    img_affine = grid_sample(img_2x, grid)
    d_p = -pad_k * 2
    down_pad = (
        d_p + (len_k - 2 + 1) // 2,
        d_p + (len_k - 2) // 2,
        d_p + (len_k - 2 + 1) // 2,
        d_p + (len_k - 2) // 2,
    )
    # Low-pass filter and downsample back to the original resolution.
    img_down = upfirdn2d(
        img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
    )
    img_down = upfirdn2d(
        img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
    )

    return img_down, G

# Apply a 4x4 homogeneous color matrix to an NCHW image batch.
def apply_color(img, mat):
    batch = img.shape[0]
    img = img.permute(0, 2, 3, 1)
    mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
    mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
    img = img @ mat_mul + mat_add
    img = img.permute(0, 3, 1, 2)

    return img


def random_apply_color(img, p, C=None):
    if C is None:
        C = sample_color(p, img.shape[0])

    img = apply_color(img, C.to(img))

    return img, C


# Entry point: geometric then color augmentation, each applied with probability
# p per sample; returns the augmented batch together with the (G, C) matrices
# so the exact same transforms can be replayed via transform_matrix.
def augment(img, p, transform_matrix=(None, None)):
    img, G = random_apply_affine(img, p, transform_matrix[0])
    img, C = random_apply_color(img, p, transform_matrix[1])

    return img, (G, C)
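# Editor's sketch (not part of the uploaded file): typical use of augment() for
# adaptive discriminator augmentation -- real and generated batches are
# augmented independently with the same probability, and the returned (G, C)
# can be passed back to replay identical transforms.  `discriminator`,
# `real_img`, `fake_img`, and `ada_aug_p` are assumed to come from the caller's
# training loop; they are not defined in this repository.
def _sketch_augment_usage(discriminator, real_img, fake_img, ada_aug_p):
    real_aug, real_params = augment(real_img, ada_aug_p)
    fake_aug, _ = augment(fake_img, ada_aug_p)
    real_pred = discriminator(real_aug)
    fake_pred = discriminator(fake_aug)
    # Replaying the stored geometric + color transforms on the original batch
    # reproduces real_aug (up to floating-point nondeterminism).
    real_aug_replayed, _ = augment(real_img, ada_aug_p, real_params)
    return real_pred, fake_pred, real_aug_replayed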