gavinyuan committed
Commit a104d3f
1 Parent(s): 523fb10

update: app.py import FSGenerator
app.py CHANGED
@@ -14,7 +14,7 @@ import numpy as np
 from PIL import Image
 import tqdm
 
-# from modules.networks.faceshifter import FSGenerator
+from modules.networks.faceshifter import FSGenerator
 # from inference.alignment import norm_crop, norm_crop_with_M, paste_back
 # from inference.utils import save, get_5_from_98, get_detector, get_lmk
 # from inference.PIPNet.lib.tools import get_lmk_model, demo_image
modules/layers/discriminator.py ADDED
@@ -0,0 +1,153 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision
+
+
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(0.0, 0.02)
+    elif classname.find("BatchNorm2d") != -1:
+        m.weight.data.normal_(1.0, 0.02)
+        m.bias.data.fill_(0)
+
+
+class MultiscaleDiscriminator(nn.Module):
+    def __init__(
+        self,
+        input_nc,
+        ndf=64,
+        n_layers=3,
+        norm_layer=nn.BatchNorm2d,
+        use_sigmoid=False,
+        num_D=3,
+        getIntermFeat=False,
+        finetune=False,
+    ):
+        super(MultiscaleDiscriminator, self).__init__()
+        self.num_D = num_D
+        self.n_layers = n_layers
+        self.getIntermFeat = getIntermFeat
+
+        for i in range(num_D):
+            netD = NLayerDiscriminator(
+                input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat
+            )
+            if getIntermFeat:
+                for j in range(n_layers + 2):
+                    setattr(
+                        self,
+                        "scale" + str(i) + "_layer" + str(j),
+                        getattr(netD, "model" + str(j)),
+                    )
+            else:
+                setattr(self, "layer" + str(i), netD.model)
+
+        self.downsample = nn.AvgPool2d(
+            3, stride=2, padding=[1, 1], count_include_pad=False
+        )
+        # weights_init inspects each submodule's class name, so it must be
+        # applied recursively; a bare weights_init(self) call would be a no-op.
+        self.apply(weights_init)
+
+        if finetune:
+            self.requires_grad_(False)
+            for name, param in self.named_parameters():
+                if 'layer0' in name:
+                    param.requires_grad = True
+
+    def singleD_forward(self, model, input):
+        if self.getIntermFeat:
+            result = [input]
+            for i in range(len(model)):
+                result.append(model[i](result[-1]))
+            return result[1:]
+        else:
+            return [model(input)]
+
+    def forward(self, input):
+        num_D = self.num_D
+        result = []
+        input_downsampled = input
+        for i in range(num_D):
+            if self.getIntermFeat:
+                model = [
+                    getattr(self, "scale" + str(num_D - 1 - i) + "_layer" + str(j))
+                    for j in range(self.n_layers + 2)
+                ]
+            else:
+                model = getattr(self, "layer" + str(num_D - 1 - i))
+            result.append(self.singleD_forward(model, input_downsampled))
+            if i != (num_D - 1):
+                input_downsampled = self.downsample(input_downsampled)
+        return result
+
+
+# Defines the PatchGAN discriminator with the specified arguments.
+class NLayerDiscriminator(nn.Module):
+    def __init__(
+        self,
+        input_nc,
+        ndf=64,
+        n_layers=3,
+        norm_layer=nn.BatchNorm2d,
+        use_sigmoid=False,
+        getIntermFeat=False,
+    ):
+        super(NLayerDiscriminator, self).__init__()
+        self.getIntermFeat = getIntermFeat
+        self.n_layers = n_layers
+
+        kw = 4
+        padw = int(np.ceil((kw - 1.0) / 2))
+        sequence = [
+            [
+                nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+                nn.LeakyReLU(0.2, True),
+            ]
+        ]
+
+        nf = ndf
+        for n in range(1, n_layers):
+            nf_prev = nf
+            nf = min(nf * 2, 512)
+            sequence += [
+                [
+                    nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
+                    norm_layer(nf),
+                    nn.LeakyReLU(0.2, True),
+                ]
+            ]
+
+        nf_prev = nf
+        nf = min(nf * 2, 512)
+        sequence += [
+            [
+                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
+                norm_layer(nf),
+                nn.LeakyReLU(0.2, True),
+            ]
+        ]
+
+        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+
+        if use_sigmoid:
+            sequence += [[nn.Sigmoid()]]
+
+        if getIntermFeat:
+            for n in range(len(sequence)):
+                setattr(self, "model" + str(n), nn.Sequential(*sequence[n]))
+        else:
+            sequence_stream = []
+            for n in range(len(sequence)):
+                sequence_stream += sequence[n]
+            self.model = nn.Sequential(*sequence_stream)
+
+    def forward(self, input):
+        if self.getIntermFeat:
+            res = [input]
+            for n in range(self.n_layers + 2):
+                model = getattr(self, "model" + str(n))
+                res.append(model(res[-1]))
+            return res[1:]
+        else:
+            return self.model(input)
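
Usage sketch (editor's addition, not part of the commit): run the multiscale discriminator on random input and print the final patch-logit map per scale; shapes assume 256x256 RGB.

# Editor's sketch: exercising MultiscaleDiscriminator on random data.
import torch

netD = MultiscaleDiscriminator(input_nc=3, num_D=3, getIntermFeat=False)
x = torch.randn(2, 3, 256, 256)
with torch.no_grad():
    outs = netD(x)                # list with one entry per scale
for scale, out in enumerate(outs):
    print(scale, out[-1].shape)   # last element is the patch-logit map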
modules/layers/faceshifter/hear_layers.py ADDED
@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+
+
+def conv4x4(in_c, out_c):
+    return nn.Sequential(
+        nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
+        nn.BatchNorm2d(out_c),
+        nn.LeakyReLU(0.1, inplace=True),
+    )
+
+
+def deconv4x4(in_c, out_c):
+    return nn.Sequential(
+        nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1),
+        nn.BatchNorm2d(out_c),
+        nn.LeakyReLU(0.1, inplace=True),
+    )
+
+
+class Hear_Net(nn.Module):
+    def __init__(self):
+        super(Hear_Net, self).__init__()
+        self.down1 = conv4x4(6, 64)
+        self.down2 = conv4x4(64, 128)
+        self.down3 = conv4x4(128, 256)
+        self.down4 = conv4x4(256, 512)
+        self.down5 = conv4x4(512, 512)
+
+        self.up1 = deconv4x4(512, 512)
+        self.up2 = deconv4x4(512 * 2, 256)
+        self.up3 = deconv4x4(256 * 2, 128)
+        self.up4 = deconv4x4(128 * 2, 64)
+        self.up5 = nn.Conv2d(64 * 2, 3, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x):  # input: (B, 6, 256, 256)
+        c1 = self.down1(x)
+        c2 = self.down2(c1)
+        c3 = self.down3(c2)
+        c4 = self.down4(c3)
+        c5 = self.down5(c4)
+
+        m1 = self.up1(c5)
+        m1 = torch.cat((c4, m1), dim=1)
+        m2 = self.up2(m1)
+        m2 = torch.cat((c3, m2), dim=1)
+        m3 = self.up3(m2)
+        m3 = torch.cat((c2, m3), dim=1)
+        m4 = self.up4(m3)
+        m4 = torch.cat((c1, m4), dim=1)
+
+        out = nn.functional.interpolate(m4, scale_factor=2, mode='bilinear', align_corners=True)
+        out = self.up5(out)
+        return torch.tanh(out)  # output: (B, 3, 256, 256)
+
+
+if __name__ == '__main__':
+    y_cat = torch.randn(5, 6, 256, 256)
+    hear = Hear_Net()
+    y_st = hear(y_cat)
+    print(y_st.shape)
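
For orientation, a sketch of how this network is fed in FaceShifter-style pipelines (an assumption based on the FaceShifter paper, not something this commit wires up): the 6 input channels are the stage-1 swap result concatenated with the heuristic error map of the target.

# Editor's sketch; Y_st and Y_tt are stand-ins for real stage-1 outputs.
import torch

Xt = torch.randn(2, 3, 256, 256)    # target frames
Y_st = torch.randn(2, 3, 256, 256)  # stage-1 swap result
Y_tt = torch.randn(2, 3, 256, 256)  # target reconstructed from itself
delta_Yt = Xt - Y_tt                # heuristic occlusion/error map
hear = Hear_Net()
refined = hear(torch.cat((Y_st, delta_Yt), dim=1))
print(refined.shape)                # torch.Size([2, 3, 256, 256])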
modules/layers/faceshifter/layers.py ADDED
@@ -0,0 +1,388 @@
+"""
+This file is only for testing mask regularization.
+If it works, it will be merged with `layers.py`.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AADLayer(nn.Module):
+    def __init__(self, c_x, attr_c, c_id=256):
+        super(AADLayer, self).__init__()
+        self.attr_c = attr_c
+        self.c_id = c_id
+        self.c_x = c_x
+
+        self.conv1 = nn.Conv2d(
+            attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True
+        )
+        self.conv2 = nn.Conv2d(
+            attr_c, c_x, kernel_size=1, stride=1, padding=0, bias=True
+        )
+        self.fc1 = nn.Linear(c_id, c_x)
+        self.fc2 = nn.Linear(c_id, c_x)
+        self.norm = nn.InstanceNorm2d(c_x, affine=False)
+
+        self.conv_h = nn.Conv2d(c_x, 1, kernel_size=1, stride=1, padding=0, bias=True)
+
+    def forward(self, h_in, z_attr, z_id):
+        # h_in:   (B, c_x, n, n)
+        # z_id:   (B, c_id)
+        # z_attr: (B, attr_c, n, n)
+        h = self.norm(h_in)
+        gamma_attr = self.conv1(z_attr)
+        beta_attr = self.conv2(z_attr)
+
+        gamma_id = self.fc1(z_id)
+        beta_id = self.fc2(z_id)
+        A = gamma_attr * h + beta_attr
+        gamma_id = gamma_id.reshape(h.shape[0], self.c_x, 1, 1).expand_as(h)
+        beta_id = beta_id.reshape(h.shape[0], self.c_x, 1, 1).expand_as(h)
+        I = gamma_id * h + beta_id
+
+        M = torch.sigmoid(self.conv_h(h))
+
+        out = (torch.ones_like(M).to(M.device) - M) * A + M * I
+        return out, torch.mean(torch.ones_like(M).to(M.device) - M, dim=[1, 2, 3])
+
+
+class AAD_ResBlk(nn.Module):
+    def __init__(self, cin, cout, c_attr, c_id=256):
+        super(AAD_ResBlk, self).__init__()
+        self.cin = cin
+        self.cout = cout
+
+        self.AAD1 = AADLayer(cin, c_attr, c_id)
+        self.conv1 = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.AAD2 = AADLayer(cin, c_attr, c_id)
+        self.conv2 = nn.Conv2d(
+            cin, cout, kernel_size=3, stride=1, padding=1, bias=False
+        )
+        self.relu2 = nn.ReLU(inplace=True)
+
+        if cin != cout:
+            self.AAD3 = AADLayer(cin, c_attr, c_id)
+            self.conv3 = nn.Conv2d(
+                cin, cout, kernel_size=3, stride=1, padding=1, bias=False
+            )
+            self.relu3 = nn.ReLU(inplace=True)
+
+    def forward(self, h, z_attr, z_id):
+        x, m1_ = self.AAD1(h, z_attr, z_id)
+        x = self.relu1(x)
+        x = self.conv1(x)
+
+        x, m2_ = self.AAD2(x, z_attr, z_id)
+        x = self.relu2(x)
+        x = self.conv2(x)
+
+        m = m1_ + m2_
+
+        if self.cin != self.cout:
+            h, m3_ = self.AAD3(h, z_attr, z_id)
+            h = self.relu3(h)
+            h = self.conv3(h)
+            m += m3_
+        x = x + h
+
+        return x, m
+
+
+def weight_init(m):
+    if isinstance(m, nn.Linear):
+        m.weight.data.normal_(0, 0.001)
+        m.bias.data.zero_()
+    if isinstance(m, nn.Conv2d):
+        nn.init.xavier_normal_(m.weight.data)
+
+    if isinstance(m, nn.ConvTranspose2d):
+        nn.init.xavier_normal_(m.weight.data)
+
+
+def conv4x4(in_c, out_c, norm=nn.BatchNorm2d):
+    return nn.Sequential(
+        nn.Conv2d(
+            in_channels=in_c,
+            out_channels=out_c,
+            kernel_size=4,
+            stride=2,
+            padding=1,
+            bias=False,
+        ),
+        norm(out_c),
+        nn.LeakyReLU(0.1, inplace=True),
+    )
+
+
+class deconv4x4(nn.Module):
+    def __init__(self, in_c, out_c, norm=nn.BatchNorm2d):
+        super(deconv4x4, self).__init__()
+        self.deconv = nn.ConvTranspose2d(
+            in_channels=in_c,
+            out_channels=out_c,
+            kernel_size=4,
+            stride=2,
+            padding=1,
+            bias=False,
+        )
+        self.bn = norm(out_c)
+        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
+
+    def forward(self, input, skip):
+        x = self.deconv(input)
+        x = self.bn(x)
+        x = self.lrelu(x)
+        return torch.cat((x, skip), dim=1)
+
+
+class MLAttrEncoder(nn.Module):
+    def __init__(self, finetune=False, downup=False):
+        super(MLAttrEncoder, self).__init__()
+
+        self.downup = downup
+        if self.downup:
+            self.conv00 = conv4x4(3, 16)
+            self.conv01 = conv4x4(16, 32)
+            self.deconv7 = deconv4x4(64, 16)
+
+        self.conv1 = conv4x4(3, 32)
+        self.conv2 = conv4x4(32, 64)
+        self.conv3 = conv4x4(64, 128)
+        self.conv4 = conv4x4(128, 256)
+        self.conv5 = conv4x4(256, 512)
+        self.conv6 = conv4x4(512, 1024)
+        self.conv7 = conv4x4(1024, 1024)
+
+        self.deconv1 = deconv4x4(1024, 1024)
+        self.deconv2 = deconv4x4(2048, 512)
+        self.deconv3 = deconv4x4(1024, 256)
+        self.deconv4 = deconv4x4(512, 128)
+        self.deconv5 = deconv4x4(256, 64)
+        self.deconv6 = deconv4x4(128, 32)
+
+        self.apply(weight_init)
+
+        self.finetune = finetune
+        if finetune:
+            for name, param in self.named_parameters():
+                param.requires_grad = False
+            if self.downup:
+                self.conv00.requires_grad_(True)
+                self.conv01.requires_grad_(True)
+                self.deconv7.requires_grad_(True)
+
+    def forward(self, Xt):
+        if self.downup:
+            feat0 = self.conv00(Xt)     # (16, 256, 256)
+            feat1 = self.conv01(feat0)  # (32, 128, 128)
+        else:
+            feat0 = None
+            feat1 = self.conv1(Xt)
+            # 32x128x128
+
+        feat2 = self.conv2(feat1)
+        # 64x64x64
+        feat3 = self.conv3(feat2)
+        # 128x32x32
+        feat4 = self.conv4(feat3)
+        # 256x16x16
+        feat5 = self.conv5(feat4)
+        # 512x8x8
+        feat6 = self.conv6(feat5)
+        # 1024x4x4
+
+        if self.downup:
+            z_attr1 = self.conv7(feat6)
+            # 1024x2x2
+            z_attr2 = self.deconv1(z_attr1, feat6)
+            z_attr3 = self.deconv2(z_attr2, feat5)
+            z_attr4 = self.deconv3(z_attr3, feat4)
+            z_attr5 = self.deconv4(z_attr4, feat3)
+            z_attr6 = self.deconv5(z_attr5, feat2)
+            z_attr7 = self.deconv6(z_attr6, feat1)  # (128,64,64)+(32,128,128)->(64,128,128)
+            z_attr8 = self.deconv7(z_attr7, feat0)  # (64,128,128)+(16,256,256)->(32,256,256)
+            z_attr9 = F.interpolate(
+                z_attr8, scale_factor=2, mode="bilinear", align_corners=True
+            )  # (32,512,512)
+            return (
+                z_attr1,
+                z_attr2,
+                z_attr3,
+                z_attr4,
+                z_attr5,
+                z_attr6,
+                z_attr7,
+                z_attr8,
+                z_attr9,
+            )
+        else:
+            z_attr1 = self.conv7(feat6)
+            # 1024x2x2
+            z_attr2 = self.deconv1(z_attr1, feat6)
+            z_attr3 = self.deconv2(z_attr2, feat5)
+            z_attr4 = self.deconv3(z_attr3, feat4)
+            z_attr5 = self.deconv4(z_attr4, feat3)
+            z_attr6 = self.deconv5(z_attr5, feat2)
+            z_attr7 = self.deconv6(z_attr6, feat1)
+            z_attr8 = F.interpolate(
+                z_attr7, scale_factor=2, mode="bilinear", align_corners=True
+            )
+            return (
+                z_attr1,
+                z_attr2,
+                z_attr3,
+                z_attr4,
+                z_attr5,
+                z_attr6,
+                z_attr7,
+                z_attr8,
+            )
+
+
+class AADGenerator(nn.Module):
+    def __init__(self, c_id=256, finetune=False, downup=False):
+        super(AADGenerator, self).__init__()
+        self.up1 = nn.ConvTranspose2d(c_id, 1024, kernel_size=2, stride=1, padding=0)
+        self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, c_id)
+        self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, c_id)
+        self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, c_id)
+        self.AADBlk4 = AAD_ResBlk(1024, 512, 512, c_id)
+        self.AADBlk5 = AAD_ResBlk(512, 256, 256, c_id)
+        self.AADBlk6 = AAD_ResBlk(256, 128, 128, c_id)
+        self.AADBlk7 = AAD_ResBlk(128, 64, 64, c_id)
+        self.AADBlk8 = AAD_ResBlk(64, 3, 64, c_id)
+
+        self.downup = downup
+        if downup:
+            self.AADBlk8_0 = AAD_ResBlk(64, 32, 32, c_id)
+            self.AADBlk8_1 = AAD_ResBlk(32, 3, 32, c_id)
+
+        self.apply(weight_init)
+
+        if finetune:
+            # note: this branch assumes downup=True, since only then do
+            # AADBlk8_0 / AADBlk8_1 exist
+            for name, param in self.named_parameters():
+                param.requires_grad = False
+            self.AADBlk8_0.requires_grad_(True)
+            self.AADBlk8_1.requires_grad_(True)
+
+    def forward(self, z_attr, z_id):
+        m = self.up1(z_id.reshape(z_id.shape[0], -1, 1, 1))
+        scale = z_attr[0].shape[2] // 2  # adaptive support for 512x512, 1024x1024
+        m = F.interpolate(m, scale_factor=scale, mode='bilinear', align_corners=True)
+        m2, m2_ = self.AADBlk1(m, z_attr[0], z_id)
+        m2 = F.interpolate(m2, scale_factor=2, mode="bilinear", align_corners=True)
+        m3, m3_ = self.AADBlk2(m2, z_attr[1], z_id)
+        m3 = F.interpolate(m3, scale_factor=2, mode="bilinear", align_corners=True)
+        m4, m4_ = self.AADBlk3(m3, z_attr[2], z_id)
+        m4 = F.interpolate(m4, scale_factor=2, mode="bilinear", align_corners=True)
+        m5, m5_ = self.AADBlk4(m4, z_attr[3], z_id)
+        m5 = F.interpolate(m5, scale_factor=2, mode="bilinear", align_corners=True)
+        m6, m6_ = self.AADBlk5(m5, z_attr[4], z_id)
+        m6 = F.interpolate(m6, scale_factor=2, mode="bilinear", align_corners=True)
+        m7, m7_ = self.AADBlk6(m6, z_attr[5], z_id)
+        m7 = F.interpolate(m7, scale_factor=2, mode="bilinear", align_corners=True)
+        m8, m8_ = self.AADBlk7(m7, z_attr[6], z_id)
+        m8 = F.interpolate(m8, scale_factor=2, mode="bilinear", align_corners=True)
+
+        if self.downup:
+            y0, m9_ = self.AADBlk8_0(m8, z_attr[7], z_id)
+            y0 = F.interpolate(y0, scale_factor=2, mode='bilinear', align_corners=True)
+            y1, m10_ = self.AADBlk8_1(y0, z_attr[8], z_id)
+            y = torch.tanh(y1)
+        else:
+            y, m9_ = self.AADBlk8(m8, z_attr[7], z_id)
+            y = torch.tanh(y)
+        return y  # , m  # yuange
+
+
+class AEI_Net(nn.Module):
+    def __init__(self, c_id=512, finetune=False, downup=False):
+        super(AEI_Net, self).__init__()
+        self.encoder = MLAttrEncoder(finetune=finetune, downup=downup)
+        self.generator = AADGenerator(c_id, finetune=finetune, downup=downup)
+
+    def forward(self, Xt, z_id):
+        attr = self.encoder(Xt)
+        Y = self.generator(attr, z_id)  # yuange
+        return Y, attr
+
+    def get_attr(self, X):
+        return self.encoder(X)
+
+    def trainable_params(self):
+        train_params = []
+        for param in self.parameters():
+            if param.requires_grad:
+                train_params.append(param)
+        return train_params
+
+
+if __name__ == "__main__":
+    aie = AEI_Net(512).eval()
+    x = aie(torch.randn(1, 3, 512, 512), torch.randn(1, 512))
+
+    # def numel(m: torch.nn.Module, only_trainable: bool = False):
+    #     """
+    #     returns the total number of parameters used by `m` (only counting
+    #     shared parameters once); if `only_trainable` is True, then only
+    #     includes parameters with `requires_grad = True`
+    #     """
+    #     parameters = list(m.parameters())
+    #     if only_trainable:
+    #         parameters = [p for p in parameters if p.requires_grad]
+    #     unique = {p.data_ptr(): p for p in parameters}.values()
+    #     return sum(p.numel() for p in unique)
+    #
+    # print(numel(aie, True))
+    # print(x[0].size())
+    # print(len(x[-1]))
+
+    # Profiling lives under the __main__ guard so importing this module
+    # does not trigger a forward pass.
+    import thop
+
+    img = torch.randn(1, 3, 256, 256)
+    latent = torch.randn(1, 512)
+    net = aie
+    flops, params = thop.profile(net, inputs=(img, latent), verbose=False)
+    print('#Params=%.2fM, GFLOPS=%.2f' % (params / 1e6, flops / 1e9))
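
A shape trace (editor's addition): for a 256x256 input with downup=False, the encoder emits an 8-level attribute pyramid from 1024x2x2 up to 64x256x256, and the generator decodes it back to image resolution.

# Editor's sketch: end-to-end shapes of AEI_Net at 256x256.
import torch

net = AEI_Net(c_id=512).eval()
with torch.no_grad():
    y, attr = net(torch.randn(1, 3, 256, 256), torch.randn(1, 512))
print(y.shape)             # torch.Size([1, 3, 256, 256])
for a in attr:
    print(tuple(a.shape))  # (1,1024,2,2) ... (1,64,128,128), (1,64,256,256)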
modules/layers/simswap/base_model.py ADDED
@@ -0,0 +1,140 @@
+import os
+import torch
+import sys
+
+
+class BaseModel(torch.nn.Module):
+    def name(self):
+        return 'BaseModel'
+
+    def initialize(self, opt):
+        self.opt = opt
+        self.gpu_ids = opt.gpu_ids
+        self.isTrain = opt.isTrain
+        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
+        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+
+    def set_input(self, input):
+        self.input = input
+
+    def forward(self):
+        pass
+
+    # used at test time, no backprop
+    def test(self):
+        pass
+
+    def get_image_paths(self):
+        pass
+
+    def optimize_parameters(self):
+        pass
+
+    def get_current_visuals(self):
+        return self.input
+
+    def get_current_errors(self):
+        return {}
+
+    def save(self, label):
+        pass
+
+    # helper saving function that can be used by subclasses
+    def save_network(self, network, network_label, epoch_label, gpu_ids=None):
+        save_filename = '{}_net_{}.pth'.format(epoch_label, network_label)
+        save_path = os.path.join(self.save_dir, save_filename)
+        torch.save(network.cpu().state_dict(), save_path)
+        if torch.cuda.is_available():
+            network.cuda()
+
+    def save_optim(self, network, network_label, epoch_label, gpu_ids=None):
+        save_filename = '{}_optim_{}.pth'.format(epoch_label, network_label)
+        save_path = os.path.join(self.save_dir, save_filename)
+        torch.save(network.state_dict(), save_path)
+
+    # helper loading function that can be used by subclasses
+    def load_network(self, network, network_label, epoch_label, save_dir=''):
+        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+        if not save_dir:
+            save_dir = self.save_dir
+        save_path = os.path.join(save_dir, save_filename)
+        if not os.path.isfile(save_path):
+            print('%s does not exist yet!' % save_path)
+            if network_label == 'G':
+                # raising a bare string is invalid in Python 3
+                raise RuntimeError('Generator must exist!')
+        else:
+            try:
+                network.load_state_dict(torch.load(save_path))
+            except Exception:
+                pretrained_dict = torch.load(save_path)
+                model_dict = network.state_dict()
+                try:
+                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+                    network.load_state_dict(pretrained_dict)
+                    if self.opt.verbose:
+                        print('Pretrained network %s has excessive layers; only loading layers that are used' % network_label)
+                except Exception:
+                    print('Pretrained network %s has fewer layers; the following are not initialized:' % network_label)
+                    for k, v in pretrained_dict.items():
+                        if v.size() == model_dict[k].size():
+                            model_dict[k] = v
+
+                    if sys.version_info >= (3, 0):
+                        not_initialized = set()
+                    else:
+                        from sets import Set
+                        not_initialized = Set()
+
+                    for k, v in model_dict.items():
+                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
+                            not_initialized.add(k.split('.')[0])
+
+                    print(sorted(not_initialized))
+                    network.load_state_dict(model_dict)
+
+    # helper loading function that can be used by subclasses
+    def load_optim(self, network, network_label, epoch_label, save_dir=''):
+        save_filename = '%s_optim_%s.pth' % (epoch_label, network_label)
+        if not save_dir:
+            save_dir = self.save_dir
+        save_path = os.path.join(save_dir, save_filename)
+        if not os.path.isfile(save_path):
+            print('%s does not exist yet!' % save_path)
+            if network_label == 'G':
+                raise RuntimeError('Generator must exist!')
+        else:
+            try:
+                network.load_state_dict(torch.load(save_path, map_location=torch.device("cpu")))
+            except Exception:
+                pretrained_dict = torch.load(save_path, map_location=torch.device("cpu"))
+                model_dict = network.state_dict()
+                try:
+                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+                    network.load_state_dict(pretrained_dict)
+                    if self.opt.verbose:
+                        print('Pretrained network %s has excessive layers; only loading layers that are used' % network_label)
+                except Exception:
+                    print('Pretrained network %s has fewer layers; the following are not initialized:' % network_label)
+                    for k, v in pretrained_dict.items():
+                        if v.size() == model_dict[k].size():
+                            model_dict[k] = v
+
+                    if sys.version_info >= (3, 0):
+                        not_initialized = set()
+                    else:
+                        from sets import Set
+                        not_initialized = Set()
+
+                    for k, v in model_dict.items():
+                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
+                            not_initialized.add(k.split('.')[0])
+
+                    print(sorted(not_initialized))
+                    network.load_state_dict(model_dict)
+
+    def update_learning_rate(self):
+        pass
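
A minimal subclass sketch (editor's addition; the opt namespace below is illustrative, mirroring the usual pix2pixHD-style options such as gpu_ids, isTrain, checkpoints_dir, name, and verbose):

# Editor's sketch: the contract BaseModel expects from subclasses.
from types import SimpleNamespace
import torch.nn as nn

class ToyModel(BaseModel):
    def name(self):
        return 'ToyModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.netG = nn.Linear(4, 4)

opt = SimpleNamespace(gpu_ids=[], isTrain=True, verbose=False,
                      checkpoints_dir='./checkpoints', name='toy')
model = ToyModel()
model.initialize(opt)
# model.save_network(model.netG, 'G', 'latest')  # -> ./checkpoints/toy/latest_net_G.pth
# model.load_network(model.netG, 'G', 'latest')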
modules/layers/simswap/fs_networks_fix.py ADDED
@@ -0,0 +1,223 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import kornia  # kept for the (commented) warp_affine alignment path below
+
+
+class InstanceNorm(nn.Module):
+    def __init__(self, epsilon=1e-8):
+        """
+        @notice: avoid in-place ops.
+        https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
+        """
+        super(InstanceNorm, self).__init__()
+        self.epsilon = epsilon
+
+    def forward(self, x):
+        x = x - torch.mean(x, (2, 3), True)
+        tmp = torch.mul(x, x)  # or x ** 2
+        tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
+        return x * tmp
+
+
+class ApplyStyle(nn.Module):
+    """
+    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
+    """
+    def __init__(self, latent_size, channels):
+        super(ApplyStyle, self).__init__()
+        self.linear = nn.Linear(latent_size, channels * 2)
+
+    def forward(self, x, latent):
+        style = self.linear(latent)  # style => [batch_size, n_channels*2]
+        shape = [-1, 2, x.size(1), 1, 1]
+        style = style.view(shape)    # [batch_size, 2, n_channels, ...]
+        # x = x * (style[:, 0] + 1.) + style[:, 1]
+        x = x * (style[:, 0] * 1 + 1.) + style[:, 1] * 1
+        return x
+
+
+class ResnetBlock_Adain(nn.Module):
+    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
+        super(ResnetBlock_Adain, self).__init__()
+
+        p = 0
+        conv1 = []
+        if padding_type == 'reflect':
+            conv1 += [nn.ReflectionPad2d(1)]
+        elif padding_type == 'replicate':
+            conv1 += [nn.ReplicationPad2d(1)]
+        elif padding_type == 'zero':
+            p = 1
+        else:
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+        conv1 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
+        self.conv1 = nn.Sequential(*conv1)
+        self.style1 = ApplyStyle(latent_size, dim)
+        self.act1 = activation
+
+        p = 0
+        conv2 = []
+        if padding_type == 'reflect':
+            conv2 += [nn.ReflectionPad2d(1)]
+        elif padding_type == 'replicate':
+            conv2 += [nn.ReplicationPad2d(1)]
+        elif padding_type == 'zero':
+            p = 1
+        else:
+            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+        conv2 += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), InstanceNorm()]
+        self.conv2 = nn.Sequential(*conv2)
+        self.style2 = ApplyStyle(latent_size, dim)
+
+    def forward(self, x, dlatents_in_slice):
+        y = self.conv1(x)
+        y = self.style1(y, dlatents_in_slice)
+        y = self.act1(y)
+        y = self.conv2(y)
+        y = self.style2(y, dlatents_in_slice)
+        out = x + y
+        return out
+
+
+class Generator_Adain_Upsample(nn.Module):
+    def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
+                 norm_layer=nn.BatchNorm2d,
+                 padding_type='reflect',
+                 mouth_net_param: dict = None,
+                 ):
+        assert n_blocks >= 0
+        super(Generator_Adain_Upsample, self).__init__()
+
+        self.latent_size = latent_size
+
+        # guard against the default None so the module can be built without
+        # a MouthNet configuration
+        self.mouth_net_param = mouth_net_param
+        if mouth_net_param and mouth_net_param.get('use'):
+            self.latent_size += mouth_net_param.get('feature_dim')
+
+        activation = nn.ReLU(True)
+
+        self.deep = deep
+
+        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
+                                         norm_layer(64), activation)
+        ### downsample
+        self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+                                   norm_layer(128), activation)
+        self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
+                                   norm_layer(256), activation)
+        self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
+                                   norm_layer(512), activation)
+
+        if self.deep:
+            self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+                                       norm_layer(512), activation)
+
+        ### resnet blocks
+        BN = []
+        for i in range(n_blocks):
+            BN += [
+                ResnetBlock_Adain(512, latent_size=self.latent_size,
+                                  padding_type=padding_type, activation=activation)]
+        self.BottleNeck = nn.Sequential(*BN)
+
+        if self.deep:
+            self.up4 = nn.Sequential(
+                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+                nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
+                nn.BatchNorm2d(512), activation
+            )
+        self.up3 = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(256), activation
+        )
+        self.up2 = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(128), activation
+        )
+        self.up1 = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(64), activation
+        )
+        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0))
+
+        self.register_buffer(
+            name="trans_matrix",
+            tensor=torch.tensor(
+                [
+                    [
+                        [1.07695457, -0.03625215, -1.56352194],
+                        [0.03625215, 1.07695457, -5.32134629],
+                    ]
+                ],
+                requires_grad=False,
+            ).float(),
+        )
+
+    def forward(self, source, target, net_arc, mouth_net=None):
+        x = target  # 3*224*224
+        if net_arc is None:
+            id_vector = source
+        else:
+            with torch.no_grad():
+                ''' 1. get id '''
+                # M = self.trans_matrix.repeat(source.size()[0], 1, 1)
+                # source = kornia.geometry.transform.warp_affine(source, M, (256, 256))
+                resize_input = F.interpolate(source, size=112, mode="bilinear", align_corners=True)
+                id_vector = F.normalize(net_arc(resize_input), dim=-1, p=2)
+
+                ''' 2. get mouth feature '''
+                if mouth_net is not None:
+                    w1, h1, w2, h2 = self.mouth_net_param.get('crop_param')
+                    mouth_input = resize_input[:, :, h1:h2, w1:w2]
+                    mouth_feat = mouth_net(mouth_input)
+                    id_vector = torch.cat([id_vector, mouth_feat], dim=-1)  # (B, dim_id+dim_mouth)
+
+        skip1 = self.first_layer(x)
+        skip2 = self.down1(skip1)
+        skip3 = self.down2(skip2)
+        if self.deep:
+            skip4 = self.down3(skip3)
+            x = self.down4(skip4)
+        else:
+            x = self.down3(skip3)
+        bot = []
+        bot.append(x)
+        features = []
+        for i in range(len(self.BottleNeck)):
+            x = self.BottleNeck[i](x, id_vector)
+            bot.append(x)
+
+        if self.deep:
+            x = self.up4(x)
+            features.append(x)
+        x = self.up3(x)
+        features.append(x)
+        x = self.up2(x)
+        features.append(x)
+        x = self.up1(x)
+        features.append(x)
+        x = self.last_layer(x)
+        # x = (x + 1) / 2
+
+        # return x, bot, features, dlatents
+        return x
+
+
+if __name__ == "__main__":
+    import thop
+
+    img = torch.randn(1, 3, 256, 256)
+    latent = torch.randn(1, 512)
+    net = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9,
+                                   mouth_net_param={"use": False})
+    flops, params = thop.profile(net, inputs=(latent, img, None, None), verbose=False)
+    print('#Params=%.2fM, GFLOPS=%.2f' % (params / 1e6, flops / 1e9))
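
Usage sketch (editor's addition): forward has two identity paths, selected by net_arc.

# Editor's sketch: Generator_Adain_Upsample's two identity paths.
import torch
import torch.nn.functional as F

G = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512,
                             n_blocks=9, mouth_net_param={"use": False}).eval()
target = torch.randn(1, 3, 256, 256)

# Path 1: a precomputed, L2-normalized id vector (net_arc=None).
id_vec = F.normalize(torch.randn(1, 512), dim=-1)
with torch.no_grad():
    y = G(id_vec, target, net_arc=None)
print(y.shape)  # torch.Size([1, 3, 256, 256])

# Path 2 (not run here): pass the source image together with an
# ArcFace-style embedder; forward resizes it to 112x112 internally.
# y = G(source_img, target, net_arc=arcface_backbone)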
modules/layers/simswap/pg_modules/blocks.py ADDED
@@ -0,0 +1,325 @@
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import spectral_norm
+
+
+### single layers
+
+
+def conv2d(*args, **kwargs):
+    return spectral_norm(nn.Conv2d(*args, **kwargs))
+
+
+def convTranspose2d(*args, **kwargs):
+    return spectral_norm(nn.ConvTranspose2d(*args, **kwargs))
+
+
+def embedding(*args, **kwargs):
+    return spectral_norm(nn.Embedding(*args, **kwargs))
+
+
+def linear(*args, **kwargs):
+    return spectral_norm(nn.Linear(*args, **kwargs))
+
+
+def NormLayer(c, mode='batch'):
+    if mode == 'group':
+        return nn.GroupNorm(c // 2, c)
+    elif mode == 'batch':
+        return nn.BatchNorm2d(c)
+
+
+### Activations
+
+
+class GLU(nn.Module):
+    def forward(self, x):
+        nc = x.size(1)
+        assert nc % 2 == 0, "channels don't divide 2!"
+        nc = nc // 2
+        return x[:, :nc] * torch.sigmoid(x[:, nc:])
+
+
+class Swish(nn.Module):
+    def forward(self, feat):
+        return feat * torch.sigmoid(feat)
+
+
+### Upblocks
+
+
+class InitLayer(nn.Module):
+    def __init__(self, nz, channel, sz=4):
+        super().__init__()
+
+        self.init = nn.Sequential(
+            convTranspose2d(nz, channel * 2, sz, 1, 0, bias=False),
+            NormLayer(channel * 2),
+            GLU(),
+        )
+
+    def forward(self, noise):
+        noise = noise.view(noise.shape[0], -1, 1, 1)
+        return self.init(noise)
+
+
+def UpBlockSmall(in_planes, out_planes):
+    block = nn.Sequential(
+        nn.Upsample(scale_factor=2, mode='nearest'),
+        conv2d(in_planes, out_planes * 2, 3, 1, 1, bias=False),
+        NormLayer(out_planes * 2), GLU())
+    return block
+
+
+class UpBlockSmallCond(nn.Module):
+    def __init__(self, in_planes, out_planes, z_dim):
+        super().__init__()
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.up = nn.Upsample(scale_factor=2, mode='nearest')
+        self.conv = conv2d(in_planes, out_planes * 2, 3, 1, 1, bias=False)
+
+        which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
+        self.bn = which_bn(2 * out_planes)
+        self.act = GLU()
+
+    def forward(self, x, c):
+        x = self.up(x)
+        x = self.conv(x)
+        x = self.bn(x, c)
+        x = self.act(x)
+        return x
+
+
+def UpBlockBig(in_planes, out_planes):
+    block = nn.Sequential(
+        nn.Upsample(scale_factor=2, mode='nearest'),
+        conv2d(in_planes, out_planes * 2, 3, 1, 1, bias=False),
+        NoiseInjection(),
+        NormLayer(out_planes * 2), GLU(),
+        conv2d(out_planes, out_planes * 2, 3, 1, 1, bias=False),
+        NoiseInjection(),
+        NormLayer(out_planes * 2), GLU()
+    )
+    return block
+
+
+class UpBlockBigCond(nn.Module):
+    def __init__(self, in_planes, out_planes, z_dim):
+        super().__init__()
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.up = nn.Upsample(scale_factor=2, mode='nearest')
+        self.conv1 = conv2d(in_planes, out_planes * 2, 3, 1, 1, bias=False)
+        self.conv2 = conv2d(out_planes, out_planes * 2, 3, 1, 1, bias=False)
+
+        which_bn = functools.partial(CCBN, which_linear=linear, input_size=z_dim)
+        self.bn1 = which_bn(2 * out_planes)
+        self.bn2 = which_bn(2 * out_planes)
+        self.act = GLU()
+        self.noise = NoiseInjection()
+
+    def forward(self, x, c):
+        # block 1
+        x = self.up(x)
+        x = self.conv1(x)
+        x = self.noise(x)
+        x = self.bn1(x, c)
+        x = self.act(x)
+
+        # block 2
+        x = self.conv2(x)
+        x = self.noise(x)
+        x = self.bn2(x, c)
+        x = self.act(x)
+
+        return x
+
+
+class SEBlock(nn.Module):
+    def __init__(self, ch_in, ch_out):
+        super().__init__()
+        self.main = nn.Sequential(
+            nn.AdaptiveAvgPool2d(4),
+            conv2d(ch_in, ch_out, 4, 1, 0, bias=False),
+            Swish(),
+            conv2d(ch_out, ch_out, 1, 1, 0, bias=False),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, feat_small, feat_big):
+        return feat_big * self.main(feat_small)
+
+
+### Downblocks
+
+
+class SeparableConv2d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, bias=False):
+        super(SeparableConv2d, self).__init__()
+        self.depthwise = conv2d(in_channels, in_channels, kernel_size=kernel_size,
+                                groups=in_channels, bias=bias, padding=1)
+        self.pointwise = conv2d(in_channels, out_channels,
+                                kernel_size=1, bias=bias)
+
+    def forward(self, x):
+        out = self.depthwise(x)
+        out = self.pointwise(out)
+        return out
+
+
+class DownBlock(nn.Module):
+    def __init__(self, in_planes, out_planes, separable=False):
+        super().__init__()
+        if not separable:
+            self.main = nn.Sequential(
+                conv2d(in_planes, out_planes, 4, 2, 1),
+                NormLayer(out_planes),
+                nn.LeakyReLU(0.2, inplace=True),
+            )
+        else:
+            self.main = nn.Sequential(
+                SeparableConv2d(in_planes, out_planes, 3),
+                NormLayer(out_planes),
+                nn.LeakyReLU(0.2, inplace=True),
+                nn.AvgPool2d(2, 2),
+            )
+
+    def forward(self, feat):
+        return self.main(feat)
+
+
+class DownBlockPatch(nn.Module):
+    def __init__(self, in_planes, out_planes, separable=False):
+        super().__init__()
+        self.main = nn.Sequential(
+            DownBlock(in_planes, out_planes, separable),
+            conv2d(out_planes, out_planes, 1, 1, 0, bias=False),
+            NormLayer(out_planes),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+
+    def forward(self, feat):
+        return self.main(feat)
+
+
+### CSM
+
+
+class ResidualConvUnit(nn.Module):
+    # `activation` and `bn` are accepted for interface compatibility but unused.
+    def __init__(self, cin, activation, bn):
+        super().__init__()
+        self.conv = nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=True)
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, x):
+        return self.skip_add.add(self.conv(x), x)
+
+
+class FeatureFusionBlock(nn.Module):
+    # `deconv`, `bn`, and `lowest` are accepted for interface compatibility but unused.
+    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, lowest=False):
+        super().__init__()
+
+        self.deconv = deconv
+        self.align_corners = align_corners
+
+        self.expand = expand
+        out_features = features
+        if self.expand:
+            out_features = features // 2
+
+        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, *xs):
+        output = xs[0]
+
+        if len(xs) == 2:
+            output = self.skip_add.add(output, xs[1])
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+        )
+
+        output = self.out_conv(output)
+
+        return output
+
+
+### Misc
+
+
+class NoiseInjection(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(1), requires_grad=True)
+
+    def forward(self, feat, noise=None):
+        if noise is None:
+            batch, _, height, width = feat.shape
+            noise = torch.randn(batch, 1, height, width).to(feat.device)
+
+        return feat + self.weight * noise
+
+
+class CCBN(nn.Module):
+    ''' conditional batchnorm '''
+    def __init__(self, output_size, input_size, which_linear, eps=1e-5, momentum=0.1):
+        super().__init__()
+        self.output_size, self.input_size = output_size, input_size
+
+        # Prepare gain and bias layers
+        self.gain = which_linear(input_size, output_size)
+        self.bias = which_linear(input_size, output_size)
+
+        # epsilon to avoid dividing by 0
+        self.eps = eps
+        # Momentum
+        self.momentum = momentum
+
+        self.register_buffer('stored_mean', torch.zeros(output_size))
+        self.register_buffer('stored_var', torch.ones(output_size))
+
+    def forward(self, x, y):
+        # Calculate class-conditional gains and biases
+        gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
+        bias = self.bias(y).view(y.size(0), -1, 1, 1)
+        out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
+                           self.training, 0.1, self.eps)
+        return out * gain + bias
+
+
+class Interpolate(nn.Module):
+    """Interpolation module."""
+
+    def __init__(self, size, mode='bilinear', align_corners=False):
+        """Init.
+        Args:
+            size (int or tuple): target output size
+            mode (str): interpolation mode
+        """
+        super(Interpolate, self).__init__()
+
+        self.interp = nn.functional.interpolate
+        self.size = size
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        """Forward pass.
+        Args:
+            x (tensor): input
+        Returns:
+            tensor: interpolated data
+        """
+
+        x = self.interp(
+            x,
+            size=self.size,
+            mode=self.mode,
+            align_corners=self.align_corners,
+        )
+
+        return x
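
Composition sketch (editor's addition): the FastGAN-style pieces above chain directly; the channel counts below are arbitrary.

# Editor's sketch: wiring InitLayer, UpBlockSmall, SEBlock, and DownBlock.
import torch

init = InitLayer(nz=256, channel=512)            # noise -> 512 x 4 x 4
up1, up2 = UpBlockSmall(512, 256), UpBlockSmall(256, 128)
se = SEBlock(512, 128)                           # gate big maps with small ones
down = DownBlock(128, 64)

z = torch.randn(2, 256)
f4 = init(z)               # (2, 512, 4, 4)
f8 = up1(f4)               # (2, 256, 8, 8)
f16 = se(f4, up2(f8))      # (2, 128, 16, 16), modulated by the 4x4 features
print(down(f16).shape)     # torch.Size([2, 64, 8, 8])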
modules/layers/simswap/pg_modules/diffaug.py ADDED
@@ -0,0 +1,76 @@
+# Differentiable Augmentation for Data-Efficient GAN Training
+# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
+# https://arxiv.org/pdf/2006.10738
+
+import torch
+import torch.nn.functional as F
+
+
+def DiffAugment(x, policy='', channels_first=True):
+    if policy:
+        if not channels_first:
+            x = x.permute(0, 3, 1, 2)
+        for p in policy.split(','):
+            for f in AUGMENT_FNS[p]:
+                x = f(x)
+        if not channels_first:
+            x = x.permute(0, 2, 3, 1)
+        x = x.contiguous()
+    return x
+
+
+def rand_brightness(x):
+    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
+    return x
+
+
+def rand_saturation(x):
+    x_mean = x.mean(dim=1, keepdim=True)
+    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
+    return x
+
+
+def rand_contrast(x):
+    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
+    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
+    return x
+
+
+def rand_translation(x, ratio=0.125):
+    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
+    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
+    grid_batch, grid_x, grid_y = torch.meshgrid(
+        torch.arange(x.size(0), dtype=torch.long, device=x.device),
+        torch.arange(x.size(2), dtype=torch.long, device=x.device),
+        torch.arange(x.size(3), dtype=torch.long, device=x.device),
+    )
+    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
+    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
+    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
+    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
+    return x
+
+
+def rand_cutout(x, ratio=0.2):
+    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
+    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
+    grid_batch, grid_x, grid_y = torch.meshgrid(
+        torch.arange(x.size(0), dtype=torch.long, device=x.device),
+        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
+        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
+    )
+    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
+    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
+    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
+    mask[grid_batch, grid_x, grid_y] = 0
+    x = x * mask.unsqueeze(1)
+    return x
+
+
+AUGMENT_FNS = {
+    'color': [rand_brightness, rand_saturation, rand_contrast],
+    'translation': [rand_translation],
+    'cutout': [rand_cutout],
+}
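
Usage sketch (editor's addition): the paper's recipe applies the same policy to real and generated batches inside the discriminator step.

# Editor's sketch: DiffAugment in a discriminator update.
import torch

policy = 'color,translation,cutout'
real = torch.rand(4, 3, 64, 64) * 2 - 1
fake = torch.rand(4, 3, 64, 64) * 2 - 1   # stand-in for G(z)

d_real = DiffAugment(real, policy=policy)
d_fake = DiffAugment(fake, policy=policy)
print(d_real.shape, d_fake.shape)         # shapes are unchanged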
modules/layers/simswap/pg_modules/projected_discriminator.py ADDED
@@ -0,0 +1,191 @@
+from functools import partial
+import numpy as np
+import torch
+import torch.nn as nn
+
+from modules.layers.simswap.pg_modules.blocks import DownBlock, DownBlockPatch, conv2d
+from modules.layers.simswap.pg_modules.projector import F_RandomProj
+from modules.layers.simswap.pg_modules.diffaug import DiffAugment
+
+
+class SingleDisc(nn.Module):
+    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False):
+        super().__init__()
+        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
+                        256: 32, 512: 16, 1024: 8}
+
+        # interpolate for start sizes that are not powers of two
+        if start_sz not in channel_dict.keys():
+            sizes = np.array(list(channel_dict.keys()))
+            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
+        self.start_sz = start_sz
+
+        # if given ndf, allocate all layers with the same ndf
+        if ndf is None:
+            nfc = channel_dict
+        else:
+            nfc = {k: ndf for k, v in channel_dict.items()}
+
+        # for feature map discriminators with nfc not in channel_dict
+        # this is the case for the pretrained backbone (midas.pretrained)
+        if nc is not None and head is None:
+            nfc[start_sz] = nc
+
+        layers = []
+
+        # Head if the initial input is the full modality
+        if head:
+            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
+                       nn.LeakyReLU(0.2, inplace=True)]
+
+        # Down Blocks
+        DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
+        while start_sz > end_sz:
+            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
+            start_sz = start_sz // 2
+
+        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
+        self.main = nn.Sequential(*layers)
+
+    def forward(self, x, c):
+        # `c` is unused here; it is kept so conditional and unconditional
+        # discriminators share a call signature.
+        return self.main(x)
+
+
+class SingleDiscCond(nn.Module):
+    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False, c_dim=1000, cmap_dim=64, embedding_dim=128):
+        super().__init__()
+        self.cmap_dim = cmap_dim
+
+        # midas channels
+        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
+                        256: 32, 512: 16, 1024: 8}
+
+        # interpolate for start sizes that are not powers of two
+        if start_sz not in channel_dict.keys():
+            sizes = np.array(list(channel_dict.keys()))
+            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
+        self.start_sz = start_sz
+
+        # if given ndf, allocate all layers with the same ndf
+        if ndf is None:
+            nfc = channel_dict
+        else:
+            nfc = {k: ndf for k, v in channel_dict.items()}
+
+        # for feature map discriminators with nfc not in channel_dict
+        # this is the case for the pretrained backbone (midas.pretrained)
+        if nc is not None and head is None:
+            nfc[start_sz] = nc
+
+        layers = []
+
+        # Head if the initial input is the full modality
+        if head:
+            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
+                       nn.LeakyReLU(0.2, inplace=True)]
+
+        # Down Blocks
+        DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
+        while start_sz > end_sz:
+            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
+            start_sz = start_sz // 2
+        self.main = nn.Sequential(*layers)
+
+        # additions for conditioning on class information
+        self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False)
+        self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim)
+        self.embed_proj = nn.Sequential(
+            nn.Linear(self.embed.embedding_dim, self.cmap_dim),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+
+    def forward(self, x, c):
+        h = self.main(x)
+        out = self.cls(h)
+
+        # conditioning via projection
+        cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1)
+        out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+
+        return out
+
+
+class MultiScaleD(nn.Module):
+    def __init__(
+        self,
+        channels,
+        resolutions,
+        num_discs=4,
+        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
+        cond=0,
+        separable=False,
+        patch=False,
+        **kwargs,
+    ):
+        super().__init__()
+
+        assert num_discs in [1, 2, 3, 4]
+
+        # the first disc is on the lowest level of the backbone
+        self.disc_in_channels = channels[:num_discs]
+        self.disc_in_res = resolutions[:num_discs]
+        Disc = SingleDiscCond if cond else SingleDisc
+
+        mini_discs = []
+        for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
+            start_sz = res if not patch else 16
+            mini_discs.append((str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)))
+        self.mini_discs = nn.ModuleDict(mini_discs)
+
+    def forward(self, features, c):
+        all_logits = []
+        for k, disc in self.mini_discs.items():
+            res = disc(features[k], c).view(features[k].size(0), -1)
+            all_logits.append(res)
+
+        all_logits = torch.cat(all_logits, dim=1)
+        return all_logits
+
+
+class ProjectedDiscriminator(torch.nn.Module):
+    def __init__(
+        self,
+        diffaug=True,
+        interp224=True,
+        backbone_kwargs={},
+        **kwargs
+    ):
+        super().__init__()
+        self.diffaug = diffaug
+        self.interp224 = interp224
+        self.feature_network = F_RandomProj(**backbone_kwargs)
+        self.discriminator = MultiScaleD(
+            channels=self.feature_network.CHANNELS,
+            resolutions=self.feature_network.RESOLUTIONS,
+            **backbone_kwargs,
+        )
+
+    def train(self, mode=True):
+        # the pretrained feature network stays frozen in eval mode
+        self.feature_network = self.feature_network.train(False)
+        self.discriminator = self.discriminator.train(mode)
+        return self
+
+    def eval(self):
+        return self.train(False)
+
+    def get_feature(self, x):
+        features = self.feature_network(x, get_features=True)
+        return features
+
+    def forward(self, x, c):
+        # if self.diffaug:
+        #     x = DiffAugment(x, policy='color,translation,cutout')
+
+        # if self.interp224:
+        #     x = F.interpolate(x, 224, mode='bilinear', align_corners=False)
+
+        features, backbone_features = self.feature_network(x)
+        logits = self.discriminator(features, c)
+
+        return logits, backbone_features
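
Usage sketch (editor's addition): MultiScaleD can be exercised standalone with fake backbone features; the channel/resolution values are illustrative. ProjectedDiscriminator itself additionally needs the EfficientNet-Lite0 checkpoint hard-coded in projector.py.

# Editor's sketch: feeding MultiScaleD a dict of fake pyramid features.
import torch

channels, resolutions = [32, 64, 128, 256], [64, 32, 16, 8]
D = MultiScaleD(channels=channels, resolutions=resolutions, num_discs=4)
feats = {str(i): torch.randn(2, c, r, r)
         for i, (c, r) in enumerate(zip(channels, resolutions))}
logits = D(feats, c=None)
print(logits.shape)  # (2, total patch logits concatenated across scales)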
modules/layers/simswap/pg_modules/projector.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import timm
4
+ from modules.layers.simswap.pg_modules.blocks import FeatureFusionBlock
5
+
6
+
7
+ def _make_scratch_ccm(scratch, in_channels, cout, expand=False):
+     # shapes
+     out_channels = [cout, cout*2, cout*4, cout*8] if expand else [cout]*4
+
+     scratch.layer0_ccm = nn.Conv2d(in_channels[0], out_channels[0], kernel_size=1, stride=1, padding=0, bias=True)
+     scratch.layer1_ccm = nn.Conv2d(in_channels[1], out_channels[1], kernel_size=1, stride=1, padding=0, bias=True)
+     scratch.layer2_ccm = nn.Conv2d(in_channels[2], out_channels[2], kernel_size=1, stride=1, padding=0, bias=True)
+     scratch.layer3_ccm = nn.Conv2d(in_channels[3], out_channels[3], kernel_size=1, stride=1, padding=0, bias=True)
+
+     scratch.CHANNELS = out_channels
+
+     return scratch
+
+
+ def _make_scratch_csm(scratch, in_channels, cout, expand):
+     scratch.layer3_csm = FeatureFusionBlock(in_channels[3], nn.ReLU(False), expand=expand, lowest=True)
+     scratch.layer2_csm = FeatureFusionBlock(in_channels[2], nn.ReLU(False), expand=expand)
+     scratch.layer1_csm = FeatureFusionBlock(in_channels[1], nn.ReLU(False), expand=expand)
+     scratch.layer0_csm = FeatureFusionBlock(in_channels[0], nn.ReLU(False))
+
+     # the last refinenet does not expand, to save channels at higher resolutions
+     scratch.CHANNELS = [cout, cout, cout*2, cout*4] if expand else [cout]*4
+
+     return scratch
+
+
+ def _make_efficientnet(model):
+     pretrained = nn.Module()
+     pretrained.layer0 = nn.Sequential(model.conv_stem, model.bn1, model.act1, *model.blocks[0:2])
+     pretrained.layer1 = nn.Sequential(*model.blocks[2:3])
+     pretrained.layer2 = nn.Sequential(*model.blocks[3:5])
+     pretrained.layer3 = nn.Sequential(*model.blocks[5:9])
+     return pretrained
+
+
+ def calc_channels(pretrained, inp_res=224):
+     channels = []
+     tmp = torch.zeros(1, 3, inp_res, inp_res)
+
+     # forward pass
+     tmp = pretrained.layer0(tmp)
+     channels.append(tmp.shape[1])
+     tmp = pretrained.layer1(tmp)
+     channels.append(tmp.shape[1])
+     tmp = pretrained.layer2(tmp)
+     channels.append(tmp.shape[1])
+     tmp = pretrained.layer3(tmp)
+     channels.append(tmp.shape[1])
+
+     return channels
+
+
+ def _make_projector(im_res, cout, proj_type, expand=False):
+     assert proj_type in [0, 1, 2], "Invalid projection type"
+
+     ### Build pretrained feature network
+     model = timm.create_model('tf_efficientnet_lite0', pretrained=False,
+                               checkpoint_path='/gavin/code/FaceSwapping/modules/third_party/efficientnet/'
+                                               'tf_efficientnet_lite0-0aa007d2.pth')
+     pretrained = _make_efficientnet(model)
+
+     # Determine the resolution of the feature maps; this is later used to calculate the number
+     # of down blocks in the discriminators. Interestingly, the best results are achieved
+     # by fixing this to 256, i.e., we use the same number of down blocks per discriminator
+     # independent of the dataset resolution.
+     im_res = 256
+     pretrained.RESOLUTIONS = [im_res//4, im_res//8, im_res//16, im_res//32]
+     pretrained.CHANNELS = calc_channels(pretrained)
+
+     if proj_type == 0: return pretrained, None
+
+     ### Build CCM
+     scratch = nn.Module()
+     scratch = _make_scratch_ccm(scratch, in_channels=pretrained.CHANNELS, cout=cout, expand=expand)
+     pretrained.CHANNELS = scratch.CHANNELS
+
+     if proj_type == 1: return pretrained, scratch
+
+     ### Build CSM
+     scratch = _make_scratch_csm(scratch, in_channels=scratch.CHANNELS, cout=cout, expand=expand)
+
+     # CSM upsamples x2, so the feature map resolution doubles
+     pretrained.RESOLUTIONS = [res*2 for res in pretrained.RESOLUTIONS]
+     pretrained.CHANNELS = scratch.CHANNELS
+
+     return pretrained, scratch
+
+
+ class F_RandomProj(nn.Module):
+     def __init__(
+         self,
+         im_res=256,
+         cout=64,
+         expand=True,
+         proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
+         **kwargs,
+     ):
+         super().__init__()
+         self.proj_type = proj_type
+         self.cout = cout
+         self.expand = expand
+
+         # build pretrained feature network and random decoder (scratch)
+         self.pretrained, self.scratch = _make_projector(im_res=im_res, cout=self.cout, proj_type=self.proj_type, expand=self.expand)
+         self.CHANNELS = self.pretrained.CHANNELS
+         self.RESOLUTIONS = self.pretrained.RESOLUTIONS
+
+     def forward(self, x, get_features=False):
+         # predict feature maps
+         out0 = self.pretrained.layer0(x)
+         out1 = self.pretrained.layer1(out0)
+         out2 = self.pretrained.layer2(out1)
+         out3 = self.pretrained.layer3(out2)
+
+         # start enumerating at the lowest layer (this is where we put the first discriminator)
+         backbone_features = {
+             '0': out0,
+             '1': out1,
+             '2': out2,
+             '3': out3,
+         }
+         if get_features:
+             return backbone_features
+
+         if self.proj_type == 0: return backbone_features
+
+         out0_channel_mixed = self.scratch.layer0_ccm(backbone_features['0'])
+         out1_channel_mixed = self.scratch.layer1_ccm(backbone_features['1'])
+         out2_channel_mixed = self.scratch.layer2_ccm(backbone_features['2'])
+         out3_channel_mixed = self.scratch.layer3_ccm(backbone_features['3'])
+
+         out = {
+             '0': out0_channel_mixed,
+             '1': out1_channel_mixed,
+             '2': out2_channel_mixed,
+             '3': out3_channel_mixed,
+         }
+
+         if self.proj_type == 1: return out
+
+         # from bottom to top
+         out3_scale_mixed = self.scratch.layer3_csm(out3_channel_mixed)
+         out2_scale_mixed = self.scratch.layer2_csm(out3_scale_mixed, out2_channel_mixed)
+         out1_scale_mixed = self.scratch.layer1_csm(out2_scale_mixed, out1_channel_mixed)
+         out0_scale_mixed = self.scratch.layer0_csm(out1_scale_mixed, out0_channel_mixed)
+
+         out = {
+             '0': out0_scale_mixed,
+             '1': out1_scale_mixed,
+             '2': out2_scale_mixed,
+             '3': out3_scale_mixed,
+         }
+
+         return out, backbone_features
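A minimal, hypothetical usage sketch of F_RandomProj (it assumes the EfficientNet-Lite0 checkpoint path above resolves): with proj_type=2 the forward pass returns both the cross-scale-mixed pyramid and the raw backbone features.

    proj = F_RandomProj(im_res=256, cout=64, expand=True, proj_type=2)
    x = torch.randn(1, 3, 256, 256)
    mixed, backbone = proj(x)
    for key in mixed:
        # four pyramid levels; channels follow proj.CHANNELS, sizes follow proj.RESOLUTIONS
        print(key, mixed[key].shape, backbone[key].shape)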
modules/layers/smoothswap/id_embedder.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from modules.layers.smoothswap.resnet import resnet50
+
+
+ class IdentityHead(nn.Module):
+     def __init__(self):
+         super(IdentityHead, self).__init__()
+         self.fc1 = nn.Sequential(
+             nn.Linear(512 * 4, 1024),
+             nn.BatchNorm1d(num_features=1024),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         )
+         self.fc2 = nn.Sequential(
+             nn.Linear(1024, 512),
+             nn.BatchNorm1d(num_features=512)
+         )
+
+         for m in self.modules():
+             if isinstance(m, (nn.BatchNorm2d,)):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.fc2(x)
+         x = F.normalize(x)
+         return x
+
+
+ class IdentityEmbedder(nn.Module):
+     def __init__(self):
+         super(IdentityEmbedder, self).__init__()
+
+         self.backbone = resnet50(pretrained=False)
+         self.head = IdentityHead()
+
+     def forward(self, x_src):
+         x_src = self.backbone(x_src)
+         x_src = self.head(x_src)
+         return x_src
+
+
+ if __name__ == '__main__':
+     img = torch.randn((11, 3, 256, 256)).cuda()
+     net = IdentityEmbedder().cuda()
+     out = net(img)
+     print(out.shape)
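Because IdentityHead ends with F.normalize, the embeddings are unit-norm, so identity similarity reduces to a dot product. A small illustrative check (CPU, random weights):

    emb = IdentityEmbedder().eval()
    with torch.no_grad():
        e1 = emb(torch.randn(2, 3, 256, 256))
        e2 = emb(torch.randn(2, 3, 256, 256))
    cos = (e1 * e2).sum(dim=1)  # equals F.cosine_similarity(e1, e2) for unit-norm vectors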
modules/layers/smoothswap/resnet.py ADDED
@@ -0,0 +1,359 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.hub import load_state_dict_from_url  # used by _resnet when pretrained=True
+
+
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+            'wide_resnet50_2', 'wide_resnet101_2']
+
+
+ model_urls = {
+     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+     'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+     'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+     'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+     'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+ }
+
+
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     """1x1 convolution"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                  base_width=64, dilation=1, norm_layer=None):
+         super(BasicBlock, self).__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         if groups != 1 or base_width != 64:
+             raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+         if dilation > 1:
+             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = norm_layer(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = norm_layer(planes)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
+     # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
+     # according to "Deep Residual Learning for Image Recognition" (https://arxiv.org/abs/1512.03385).
+     # This variant is also known as ResNet V1.5 and improves accuracy according to
+     # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                  base_width=64, dilation=1, norm_layer=None):
+         super(Bottleneck, self).__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         width = int(planes * (base_width / 64.)) * groups
+         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+         self.conv1 = conv1x1(inplanes, width)
+         self.bn1 = norm_layer(width)
+         self.conv2 = conv3x3(width, width, stride, groups, dilation)
+         self.bn2 = norm_layer(width)
+         self.conv3 = conv1x1(width, planes * self.expansion)
+         self.bn3 = norm_layer(planes * self.expansion)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class ResNet(nn.Module):
+
+     def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                  norm_layer=None):
+         super(ResNet, self).__init__()
+         if norm_layer is None:
+             norm_layer = nn.BatchNorm2d
+         self._norm_layer = norm_layer
+
+         self.inplanes = 64
+         self.dilation = 1
+         if replace_stride_with_dilation is None:
+             # each element in the tuple indicates if we should replace
+             # the 2x2 stride with a dilated convolution instead
+             replace_stride_with_dilation = [False, False, False]
+         if len(replace_stride_with_dilation) != 3:
+             raise ValueError("replace_stride_with_dilation should be None "
+                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+         self.groups = groups
+         self.base_width = width_per_group
+         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = norm_layer(self.inplanes)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                        dilate=replace_stride_with_dilation[0])
+         self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                        dilate=replace_stride_with_dilation[1])
+         self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                        dilate=replace_stride_with_dilation[2])
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+         ''' head '''
+         # op1. vanilla ResNet
+         # self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+         # op2. smooth-swap resnet
+         # the FC head is defined in id_embedder.py
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+         # Zero-initialize the last BN in each residual branch,
+         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+         if zero_init_residual:
+             for m in self.modules():
+                 if isinstance(m, Bottleneck):
+                     nn.init.constant_(m.bn3.weight, 0)
+                 elif isinstance(m, BasicBlock):
+                     nn.init.constant_(m.bn2.weight, 0)
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+         norm_layer = self._norm_layer
+         downsample = None
+         previous_dilation = self.dilation
+         if dilate:
+             self.dilation *= stride
+             stride = 1
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 norm_layer(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                             self.base_width, previous_dilation, norm_layer))
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes, groups=self.groups,
+                                 base_width=self.base_width, dilation=self.dilation,
+                                 norm_layer=norm_layer))
+
+         return nn.Sequential(*layers)
+
+     def _forward_impl(self, x):
+         # See note [TorchScript super()]
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+
+         x = self.avgpool(x)
+         x = torch.flatten(x, 1)
+
+         return x
+
+     def forward(self, x):
+         return self._forward_impl(x)
+
+
+ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
+     model = ResNet(block, layers, **kwargs)
+     if pretrained:
+         state_dict = load_state_dict_from_url(model_urls[arch],
+                                               progress=progress)
+         model.load_state_dict(state_dict)
+     return model
+
+
+ def resnet18(pretrained=False, progress=True, **kwargs):
+     r"""ResNet-18 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                    **kwargs)
+
+
+ def resnet34(pretrained=False, progress=True, **kwargs):
+     r"""ResNet-34 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                    **kwargs)
+
+
+ def resnet50(pretrained=False, progress=True, **kwargs):
+     r"""ResNet-50 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                    **kwargs)
+
+
+ def resnet101(pretrained=False, progress=True, **kwargs):
+     r"""ResNet-101 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                    **kwargs)
+
+
+ def resnet152(pretrained=False, progress=True, **kwargs):
+     r"""ResNet-152 model from
+     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                    **kwargs)
+
+
+ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
+     r"""ResNeXt-50 32x4d model from
+     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs['groups'] = 32
+     kwargs['width_per_group'] = 4
+     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                    pretrained, progress, **kwargs)
+
+
+ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
+     r"""ResNeXt-101 32x8d model from
+     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs['groups'] = 32
+     kwargs['width_per_group'] = 8
+     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                    pretrained, progress, **kwargs)
+
+
+ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+     r"""Wide ResNet-50-2 model from
+     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
+
+     The model is the same as ResNet except for the bottleneck number of channels,
+     which is twice as large in every block. The number of channels in the outer 1x1
+     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
+     channels, and in Wide ResNet-50-2 it has 2048-1024-2048.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs['width_per_group'] = 64 * 2
+     return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+                    pretrained, progress, **kwargs)
+
+
+ def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+     r"""Wide ResNet-101-2 model from
+     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
+
+     The model is the same as ResNet except for the bottleneck number of channels,
+     which is twice as large in every block. The number of channels in the outer 1x1
+     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
+     channels, and in Wide ResNet-50-2 it has 2048-1024-2048.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+         progress (bool): If True, displays a progress bar of the download to stderr
+     """
+     kwargs['width_per_group'] = 64 * 2
+     return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+                    pretrained, progress, **kwargs)
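With the classification fc commented out, this backbone ends at the flattened global average pool: 512 * Bottleneck.expansion = 2048 channels for resnet50, exactly the 512 * 4 input that IdentityHead's first Linear expects. A quick shape check (illustrative):

    net = resnet50(pretrained=False)
    feat = net(torch.randn(2, 3, 256, 256))
    print(feat.shape)  # torch.Size([2, 2048])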
modules/networks/faceshifter.py ADDED
@@ -0,0 +1,162 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import os
+ import kornia
+ import warnings
+
+ from modules.layers.faceshifter.layers import AEI_Net
+ from modules.layers.faceshifter.hear_layers import Hear_Net
+ from third_party.arcface import iresnet100, MouthNet
+
+ make_abs_path = lambda fn: os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), fn))
+
+
+ class FSGenerator(nn.Module):
+     def __init__(self,
+                  id_ckpt: str = None,
+                  id_dim: int = 512,
+                  mouth_net_param: dict = None,
+                  in_size: int = 256,
+                  finetune: bool = False,
+                  downup: bool = False,
+                  ):
+         super(FSGenerator, self).__init__()
+
+         ''' MouthNet '''
+         self.use_mouth_net = mouth_net_param.get('use')
+         self.mouth_feat_dim = 0
+         self.mouth_net = None
+         if self.use_mouth_net:
+             self.mouth_feat_dim = mouth_net_param.get('feature_dim')
+             self.mouth_crop_param = mouth_net_param.get('crop_param')
+             mouth_weight_path = make_abs_path(mouth_net_param.get('weight_path'))
+             self.mouth_net = MouthNet(
+                 bisenet=None,
+                 feature_dim=self.mouth_feat_dim,
+                 crop_param=self.mouth_crop_param
+             )
+             self.mouth_net.load_backbone(mouth_weight_path)
+             print("[FaceShifter Generator] MouthNet loaded from %s" % mouth_weight_path)
+             self.mouth_net.eval()
+             self.mouth_net.requires_grad_(False)
+
+         self.G = AEI_Net(c_id=id_dim + self.mouth_feat_dim, finetune=finetune, downup=downup)
+         self.iresnet = iresnet100()
+         if id_ckpt is not None:
+             self.iresnet.load_state_dict(torch.load(id_ckpt, "cpu"))
+         else:
+             warnings.warn("Face ID backbone [%s] not found!" % id_ckpt)
+             raise FileNotFoundError("Face ID backbone [%s] not found!" % id_ckpt)
+         self.iresnet.eval()
+         self.register_buffer(
+             name="trans_matrix",
+             tensor=torch.tensor(
+                 [
+                     [
+                         [1.07695457, -0.03625215, -1.56352194 * (in_size / 256)],
+                         [0.03625215, 1.07695457, -5.32134629 * (in_size / 256)],
+                     ]
+                 ],
+                 requires_grad=False,
+             ).float(),
+         )
+         self.in_size = in_size
+
+         self.iresnet.requires_grad_(False)
+
+     def forward(self, source, target, infer=False):
+         with torch.no_grad():
+             ''' 1. get id '''
+             if infer:
+                 resize_input = F.interpolate(source, size=112, mode="bilinear", align_corners=True)
+                 id_vector = F.normalize(self.iresnet(resize_input), dim=-1, p=2)
+             else:
+                 M = self.trans_matrix.repeat(source.size()[0], 1, 1)
+                 source = kornia.geometry.transform.warp_affine(source, M, (self.in_size, self.in_size))
+
+                 # import cv2
+                 # from tricks import Trick
+                 # cv2.imwrite('warped_source.png', Trick.tensor_to_arr(source)[0, :, :, ::-1])
+
+                 resize_input = F.interpolate(source, size=112, mode="bilinear", align_corners=True)
+                 id_vector = F.normalize(self.iresnet(resize_input), dim=-1, p=2)
+
+             ''' 2. get mouth feature '''
+             if self.use_mouth_net:
+                 w1, h1, w2, h2 = self.mouth_crop_param
+                 mouth_input = resize_input[:, :, h1:h2, w1:w2]  # 112 -> mouth crop
+                 mouth_feat = self.mouth_net(mouth_input)
+                 id_vector = torch.cat([id_vector, mouth_feat], dim=-1)  # (B, dim_id + dim_mouth)
+
+         x, att = self.G(target, id_vector)
+         return x, id_vector, att
+
+     def get_recon(self):
+         return self.G.get_recon_tensor()
+
+     def get_att(self, x):
+         return self.G.get_attr(x)
+
+
+ class FSHearNet(nn.Module):
+     def __init__(self, aei_path: str):
+         super(FSHearNet, self).__init__()
+         ''' Stage I. AEI_Net '''
+         self.aei = FSGenerator(
+             id_ckpt=make_abs_path("../../modules/third_party/arcface/weights/ms1mv3_arcface_r100_fp16/backbone.pth")
+         ).requires_grad_(False)
+         print('Loading pre-trained AEI-Net from %s...' % aei_path)
+         self._load_pretrained_aei(aei_path)
+         print('Loaded.')
+
+         ''' Stage II. HEAR_Net '''
+         self.hear = Hear_Net()
+
+     def _load_pretrained_aei(self, path: str):
+         if '.ckpt' in path:
+             from trainer.faceshifter.extract_ckpt import extract_generator
+             pth_folder = make_abs_path('../../trainer/faceshifter/extracted_ckpt')
+             pth_name = 'hear_tmp.pth'
+             assert '.pth' in pth_name
+             state_dict = extract_generator(load_path=path, path=os.path.join(pth_folder, pth_name))
+             self.aei.load_state_dict(state_dict, strict=False)
+             self.aei.eval()
+         elif '.pth' in path:
+             self.aei.load_state_dict(torch.load(path, "cpu"), strict=False)
+             self.aei.eval()
+         else:
+             raise FileNotFoundError('%s (.ckpt or .pth) not found.' % path)
+
+     def forward(self, source, target):
+         with torch.no_grad():
+             y_hat_st, _, _ = self.aei(source, target, infer=True)
+             y_hat_tt, _, _ = self.aei(target, target, infer=True)
+             delta_y_t = target - y_hat_tt
+             y_cat = torch.cat([y_hat_st, delta_y_t], dim=1)  # (B, 6, 256, 256)
+
+         y_st = self.hear(y_cat)
+
+         return y_st, y_hat_st  # both (B, 3, 256, 256)
+
+
+ if __name__ == '__main__':
+
+     source = torch.randn(8, 3, 512, 512)
+     target = torch.randn(8, 3, 512, 512)
+     net = FSGenerator(
+         id_ckpt="/apdcephfs_cq2/share_1290939/gavinyuan/code/FaceShifter/faceswap/faceswap/checkpoints/"
+                 "face_id/ms1mv3_arcface_r100_fp16_backbone.pth",
+         mouth_net_param={
+             'use': False
+         }
+     )
+     result, _, _ = net(source, target)
+     print('result:', result.shape)
+
+     # stage2 = FSHearNet(
+     #     aei_path=make_abs_path("../../trainer/faceshifter/out/faceshifter_vanilla/epoch=32-step=509999.ckpt")
+     # )
+     # final_out, _ = stage2(source, target)
+     # print('final out:', final_out.shape)
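FSGenerator reads mouth_net_param as a plain dict; when 'use' is true it also expects 'feature_dim', 'crop_param' (w1, h1, w2, h2 in the 112x112 id-crop coordinates), and 'weight_path'. The values below are illustrative placeholders, not defaults shipped with the repo:

    mouth_net_param = {
        'use': True,
        'feature_dim': 128,               # appended to the 512-d ArcFace id vector
        'crop_param': (28, 56, 84, 112),  # hypothetical mouth region (w1, h1, w2, h2)
        'weight_path': 'third_party/arcface/weights/mouth_net.pth',  # hypothetical path
    }
    net = FSGenerator(id_ckpt='path/to/backbone.pth', mouth_net_param=mouth_net_param)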
modules/networks/simswap.py ADDED
@@ -0,0 +1,230 @@
+ #!/usr/bin/env python3
+ # -*- coding:utf-8 -*-
+ #############################################################
+ # File: fs_model_fix_idnorm_donggp_saveoptim copy.py
+ # Created Date: Wednesday January 12th 2022
+ # Author: Chen Xuanhong
+ # Email: chenxuanhongzju@outlook.com
+ # Last Modified: Thursday, 21st April 2022 8:13:37 pm
+ # Modified By: Chen Xuanhong
+ # Copyright (c) 2022 Shanghai Jiao Tong University
+ #############################################################
+
+
+ import torch
+ import torch.nn as nn
+
+ from modules.layers.simswap.base_model import BaseModel
+ from modules.layers.simswap.fs_networks_fix import Generator_Adain_Upsample
+
+ from modules.layers.simswap.pg_modules.projected_discriminator import ProjectedDiscriminator
+
+
+ def compute_grad2(d_out, x_in):
+     batch_size = x_in.size(0)
+     grad_dout = torch.autograd.grad(
+         outputs=d_out.sum(), inputs=x_in,
+         create_graph=True, retain_graph=True, only_inputs=True
+     )[0]
+     grad_dout2 = grad_dout.pow(2)
+     assert (grad_dout2.size() == x_in.size())
+     reg = grad_dout2.view(batch_size, -1).sum(1)
+     return reg
+
+
+ class fsModel(BaseModel):
+     def name(self):
+         return 'fsModel'
+
+     def initialize(self, opt):
+         BaseModel.initialize(self, opt)
+         # if opt.resize_or_crop != 'none' or not opt.isTrain:  # when training at full res this causes OOM
+         self.isTrain = opt.isTrain
+
+         # Generator network
+         self.netG = Generator_Adain_Upsample(input_nc=3, output_nc=3, latent_size=512, n_blocks=9, deep=opt.Gdeep)
+         self.netG.cuda()
+
+         # Id network
+         from third_party.arcface import iresnet100
+         netArc_pth = "/apdcephfs_cq2/share_1290939/gavinyuan/code/FaceShifter/faceswap/faceswap/" \
+                      "checkpoints/face_id/ms1mv3_arcface_r100_fp16_backbone.pth"  # opt.Arc_path
+         self.netArc = iresnet100(pretrained=False, fp16=False)
+         self.netArc.load_state_dict(torch.load(netArc_pth, map_location="cpu"))
+         # netArc_checkpoint = opt.Arc_path
+         # netArc_checkpoint = torch.load(netArc_checkpoint, map_location=torch.device("cpu"))
+         # self.netArc = netArc_checkpoint['model'].module
+         self.netArc = self.netArc.cuda()
+         self.netArc.eval()
+         self.netArc.requires_grad_(False)
+         if not self.isTrain:
+             pretrained_path = opt.checkpoints_dir
+             self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
+             return
+         self.netD = ProjectedDiscriminator(diffaug=False, interp224=False, **{})
+         # self.netD.feature_network.requires_grad_(False)
+         self.netD.cuda()
+
+         if self.isTrain:
+             # define loss functions
+             self.criterionFeat = nn.L1Loss()
+             self.criterionRec = nn.L1Loss()
+
+             # initialize optimizers
+             # optimizer G
+             params = list(self.netG.parameters())
+             self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)
+
+             # optimizer D
+             params = list(self.netD.parameters())
+             self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.99), eps=1e-8)
+
+             # load networks
+             if opt.continue_train:
+                 pretrained_path = '' if not self.isTrain else opt.load_pretrain
+                 # print(pretrained_path)
+                 self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
+                 self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
+                 self.load_optim(self.optimizer_G, 'G', opt.which_epoch, pretrained_path)
+                 self.load_optim(self.optimizer_D, 'D', opt.which_epoch, pretrained_path)
+             torch.cuda.empty_cache()
+
+     def cosin_metric(self, x1, x2):
+         # return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))
+         return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))
+
+     def save(self, which_epoch):
+         self.save_network(self.netG, 'G', which_epoch)
+         self.save_network(self.netD, 'D', which_epoch)
+         self.save_optim(self.optimizer_G, 'G', which_epoch)
+         self.save_optim(self.optimizer_D, 'D', which_epoch)
+         '''if self.gen_features:
+             self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''
+
+     def update_fixed_params(self):
+         raise ValueError('Not used')
+         # after fixing the global generator for a number of iterations, also start finetuning it
+         params = list(self.netG.parameters())
+         if self.gen_features:
+             params += list(self.netE.parameters())
+         self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
+         if self.opt.verbose:
+             print('------------ Now also finetuning global generator -----------')
+
+     def update_learning_rate(self):
+         raise ValueError('Not used')
+         lrd = self.opt.lr / self.opt.niter_decay
+         lr = self.old_lr - lrd
+         for param_group in self.optimizer_D.param_groups:
+             param_group['lr'] = lr
+         for param_group in self.optimizer_G.param_groups:
+             param_group['lr'] = lr
+         if self.opt.verbose:
+             print('update learning rate: %f -> %f' % (self.old_lr, lr))
+         self.old_lr = lr
+
+
+ if __name__ == "__main__":
+     import os
+     import argparse
+
+     def str2bool(v):
+         return v.lower() in ('true')
+
+     class TrainOptions:
+         def __init__(self):
+             self.parser = argparse.ArgumentParser()
+             self.initialized = False
+
+         def initialize(self):
+             self.parser.add_argument('--name', type=str, default='simswap',
+                                      help='name of the experiment. It decides where to store samples and models')
+             self.parser.add_argument('--gpu_ids', default='0')
+             self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints',
+                                      help='models are saved here')
+             self.parser.add_argument('--isTrain', type=str2bool, default='True')
+
+             # input/output sizes
+             self.parser.add_argument('--batchSize', type=int, default=8, help='input batch size')
+
+             # for displays
+             self.parser.add_argument('--use_tensorboard', type=str2bool, default='False')
+
+             # for training
+             self.parser.add_argument('--dataset', type=str, default="/path/to/VGGFace2",
+                                      help='path to the face swapping dataset')
+             self.parser.add_argument('--continue_train', type=str2bool, default='False',
+                                      help='continue training: load the latest model')
+             self.parser.add_argument('--load_pretrain', type=str, default='./checkpoints/simswap224_test',
+                                      help='load the pretrained model from the specified location')
+             self.parser.add_argument('--which_epoch', type=str, default='10000',
+                                      help='which epoch to load? set to latest to use the latest cached model')
+             self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+             self.parser.add_argument('--niter', type=int, default=10000, help='# of iterations at the starting learning rate')
+             self.parser.add_argument('--niter_decay', type=int, default=10000,
+                                      help='# of iterations to linearly decay the learning rate to zero')
+             self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
+             self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam')
+             self.parser.add_argument('--Gdeep', type=str2bool, default='False')
+
+             # for discriminators
+             self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
+             self.parser.add_argument('--lambda_id', type=float, default=30.0, help='weight for id loss')
+             self.parser.add_argument('--lambda_rec', type=float, default=10.0, help='weight for reconstruction loss')
+
+             self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar',
+                                      help="path to the ArcFace identity checkpoint")
+             self.parser.add_argument("--total_step", type=int, default=1000000, help='total training steps')
+             self.parser.add_argument("--log_frep", type=int, default=200, help='frequency of printing log information')
+             self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequency of sampling')
+             self.parser.add_argument("--model_freq", type=int, default=10000, help='frequency of saving the model')
+
+             self.isTrain = True
+
+         def parse(self, save=True):
+             if not self.initialized:
+                 self.initialize()
+             self.opt = self.parser.parse_args()
+             self.opt.isTrain = self.isTrain  # train or test
+
+             args = vars(self.opt)
+
+             print('------------ Options -------------')
+             for k, v in sorted(args.items()):
+                 print('%s: %s' % (str(k), str(v)))
+             print('-------------- End ----------------')
+
+             # save to the disk
+             # if self.opt.isTrain:
+             #     expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
+             #     util.mkdirs(expr_dir)
+             #     if save and not self.opt.continue_train:
+             #         file_name = os.path.join(expr_dir, 'opt.txt')
+             #         with open(file_name, 'wt') as opt_file:
+             #             opt_file.write('------------ Options -------------\n')
+             #             for k, v in sorted(args.items()):
+             #                 opt_file.write('%s: %s\n' % (str(k), str(v)))
+             #             opt_file.write('-------------- End ----------------\n')
+             return self.opt
+
+     source = torch.randn(8, 3, 256, 256).cuda()
+     target = torch.randn(8, 3, 256, 256).cuda()
+
+     opt = TrainOptions().parse()
+     model = fsModel()
+     model.initialize(opt)
+
+     import torch.nn.functional as F
+     img_id_112 = F.interpolate(source, size=(112, 112), mode='bicubic')
+     latent_id = model.netArc(img_id_112)
+     latent_id = F.normalize(latent_id, p=2, dim=1)
+
+     img_fake = model.netG(target, latent_id)
+     gen_logits, _ = model.netD(img_fake.detach(), None)
+     loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()
+
+     real_logits, _ = model.netD(source, None)
+
+     print('img_fake:', img_fake.shape, 'real_logits:', real_logits.shape)
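The __main__ demo above computes only the fake half of a hinge GAN loss. For completeness, the matching real half and the generator term would look like this (a standard hinge formulation sketched with the same variable names; not code from the repo):

    loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()
    loss_D = loss_Dgen + loss_Dreal               # discriminator objective
    gen_logits_G, _ = model.netD(img_fake, None)  # no detach for the generator step
    loss_G = (-gen_logits_G).mean()               # generator objective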
third_party/arcface/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from third_party.arcface.iresnet import iresnet18, iresnet34, iresnet50, iresnet100
+ from third_party.arcface.mouth_net import MouthNet
third_party/arcface/dataloaderx.py ADDED
@@ -0,0 +1,67 @@
+ """
+ A copy from https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/dataset.py
+ """
+
+ import queue as Queue
+ import threading
+
+ import torch
+ from torch.utils.data import DataLoader
+
+
+ class BackgroundGenerator(threading.Thread):
+     def __init__(self, generator, local_rank, max_prefetch=6):
+         super(BackgroundGenerator, self).__init__()
+         self.queue = Queue.Queue(max_prefetch)
+         self.generator = generator
+         self.local_rank = local_rank
+         self.daemon = True
+         self.start()
+
+     def run(self):
+         torch.cuda.set_device(self.local_rank)
+         for item in self.generator:
+             self.queue.put(item)
+         self.queue.put(None)
+
+     def next(self):
+         next_item = self.queue.get()
+         if next_item is None:
+             raise StopIteration
+         return next_item
+
+     def __next__(self):
+         return self.next()
+
+     def __iter__(self):
+         return self
+
+
+ class DataLoaderX(DataLoader):
+     def __init__(self, local_rank, **kwargs):
+         super(DataLoaderX, self).__init__(**kwargs)
+         self.stream = torch.cuda.Stream(local_rank)
+         self.local_rank = local_rank
+
+     def __iter__(self):
+         self.iter = super(DataLoaderX, self).__iter__()
+         self.iter = BackgroundGenerator(self.iter, self.local_rank)
+         self.preload()
+         return self
+
+     def preload(self):
+         self.batch = next(self.iter, None)
+         if self.batch is None:
+             return None
+         with torch.cuda.stream(self.stream):
+             for k in range(len(self.batch)):
+                 self.batch[k] = self.batch[k].to(device=self.local_rank,
+                                                  non_blocking=True)
+
+     def __next__(self):
+         torch.cuda.current_stream().wait_stream(self.stream)
+         batch = self.batch
+         if batch is None:
+             raise StopIteration
+         self.preload()
+         return batch
@@ -0,0 +1,311 @@
 
+ import torch
+ from torch import nn
+
+ __all__ = ["iresnet18", "iresnet34", "iresnet50", "iresnet100", "iresnet200"]
+
+
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(
+         in_planes,
+         out_planes,
+         kernel_size=3,
+         stride=stride,
+         padding=dilation,
+         groups=groups,
+         bias=False,
+         dilation=dilation,
+     )
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     """1x1 convolution"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+ class IBasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(
+         self,
+         inplanes,
+         planes,
+         stride=1,
+         downsample=None,
+         groups=1,
+         base_width=64,
+         dilation=1,
+     ):
+         super(IBasicBlock, self).__init__()
+         if groups != 1 or base_width != 64:
+             raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+         if dilation > 1:
+             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+         self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
+         self.conv1 = conv3x3(inplanes, planes)
+         self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
+         self.prelu = nn.PReLU(planes)
+         self.conv2 = conv3x3(planes, planes, stride)
+         self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         identity = x
+         out = self.bn1(x)
+         out = self.conv1(out)
+         out = self.bn2(out)
+         out = self.prelu(out)
+         out = self.conv2(out)
+         out = self.bn3(out)
+         if self.downsample is not None:
+             identity = self.downsample(x)
+         out += identity
+         return out
+
+
+ class IResNet(nn.Module):
+     def __init__(
+         self,
+         block,
+         layers,
+         dropout=0,
+         num_features=512,
+         zero_init_residual=False,
+         groups=1,
+         width_per_group=64,
+         replace_stride_with_dilation=None,
+         fp16=False,
+         fc_scale=7 * 7,
+     ):
+         super(IResNet, self).__init__()
+         self.fp16 = fp16
+         self.inplanes = 64
+         self.dilation = 1
+         self.fc_scale = fc_scale
+         if replace_stride_with_dilation is None:
+             replace_stride_with_dilation = [False, False, False]
+         if len(replace_stride_with_dilation) != 3:
+             raise ValueError(
+                 "replace_stride_with_dilation should be None "
+                 "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
+             )
+         self.groups = groups
+         self.base_width = width_per_group
+         self.conv1 = nn.Conv2d(
+             3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
+         )
+         self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+         self.prelu = nn.PReLU(self.inplanes)
+         self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+         self.layer2 = self._make_layer(
+             block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
+         )
+         self.layer3 = self._make_layer(
+             block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
+         )
+         self.layer4 = self._make_layer(
+             block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
+         )
+         self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05)
+         self.dropout = nn.Dropout(p=dropout, inplace=True)
+         self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+         self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+         nn.init.constant_(self.features.weight, 1.0)
+         self.features.weight.requires_grad = False
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.normal_(m.weight, 0, 0.1)
+             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+         if zero_init_residual:
+             for m in self.modules():
+                 if isinstance(m, IBasicBlock):
+                     nn.init.constant_(m.bn2.weight, 0)
+
+     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+         downsample = None
+         previous_dilation = self.dilation
+         if dilate:
+             self.dilation *= stride
+             stride = 1
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 nn.BatchNorm2d(planes * block.expansion, eps=1e-05),
+             )
+         layers = []
+         layers.append(
+             block(
+                 self.inplanes,
+                 planes,
+                 stride,
+                 downsample,
+                 self.groups,
+                 self.base_width,
+                 previous_dilation,
+             )
+         )
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(
+                 block(
+                     self.inplanes,
+                     planes,
+                     groups=self.groups,
+                     base_width=self.base_width,
+                     dilation=self.dilation,
+                 )
+             )
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         with torch.cuda.amp.autocast(self.fp16):
+             x = self.conv1(x)
+             x = self.bn1(x)
+             x = self.prelu(x)
+             x = self.layer1(x)
+             x = self.layer2(x)
+             x = self.layer3(x)
+             x = self.layer4(x)
+             x = self.bn2(x)
+             # print(x.shape)
+             x = torch.flatten(x, 1)
+             x = self.dropout(x)
+         x = self.fc(x.float() if self.fp16 else x)
+         x = self.features(x)
+         return x
+
+
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+     model = IResNet(block, layers, **kwargs)
+     if pretrained:
+         model_dir = {
+             'iresnet18': './weights/r18-backbone.pth',
+             'iresnet34': './weights/r34-backbone.pth',
+             'iresnet50': './weights/r50-backbone.pth',
+             'iresnet100': './weights/r100-backbone.pth',
+         }
+         pre_trained_weights = torch.load(model_dir[arch], map_location=torch.device('cpu'))
+
+         tmp_dict = {}
+         for key in pre_trained_weights:
+             # if 'features' in key or 'fc' in key:
+             #     print('skip %s' % key)
+             #     continue
+             tmp_dict[key] = pre_trained_weights[key]
+
+         # add model layers that are missing from the checkpoint, so that
+         # load_state_dict(strict=False) below succeeds
+         model_dict = model.state_dict()
+         for key in model_dict:
+             if key not in tmp_dict:
+                 tmp_dict[key] = model_dict[key]
+
+         model.load_state_dict(tmp_dict, strict=False)
+         print("load pre-trained iresnet from %s" % model_dir[arch])
+
+     return model
+
+
+ def iresnet18(pretrained=False, progress=True, **kwargs):
+     return _iresnet(
+         "iresnet18", IBasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs
+     )
+
+
+ def iresnet34(pretrained=False, progress=True, **kwargs):
+     return _iresnet(
+         "iresnet34", IBasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs
+     )
+
+
+ def iresnet50(pretrained=False, progress=True, **kwargs):
+     return _iresnet(
+         "iresnet50", IBasicBlock, [3, 4, 14, 3], pretrained, progress, **kwargs
+     )
+
+
+ def iresnet100(pretrained=False, progress=True, **kwargs):
+     return _iresnet(
+         "iresnet100", IBasicBlock, [3, 13, 30, 3], pretrained, progress, **kwargs
+     )
+
+
+ def iresnet200(pretrained=False, progress=True, **kwargs):
+     return _iresnet(
+         "iresnet200", IBasicBlock, [6, 26, 60, 6], pretrained, progress, **kwargs
+     )
+
+
+ @torch.no_grad()
+ def identification(folder: str = './images', target_idx: int = 0):
+     import os
+     from PIL import Image
+     import torch
+     import torchvision.transforms as transforms
+     import torch.nn.functional as F
+     import kornia
+     import numpy as np
+
+     os.makedirs('crop', exist_ok=True)
+     img_list = os.listdir(folder)
+     img_list.sort()
+     n = len(img_list)
+     trans = transforms.Compose([
+         transforms.Resize(256),
+         transforms.CenterCrop(224),
+         transforms.ToTensor(),
+         # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     ])
+     trans_matrix = torch.tensor(
+         [[[1.07695457, -0.03625215, -1.56352194],
+           [0.03625215, 1.07695457, -5.32134629]]],
+         requires_grad=False).float().cuda()
+
+     fid_model = iresnet50(pretrained=True).cuda().eval()
+
+     def save_tensor_to_img(tensor: torch.Tensor, path: str, scale=255):
+         tensor = tensor.permute(0, 2, 3, 1)[0]  # in [0,1]
+         tensor = tensor.clamp(0, 1)
+         tensor = tensor * scale
+         tensor_np = tensor.cpu().numpy().astype(np.uint8)
+         if tensor_np.shape[-1] == 1:  # channel dim
+             tensor_np = tensor_np.repeat(3, axis=-1)
+         tensor_img = Image.fromarray(tensor_np)
+         tensor_img.save(path)
+
+     feats = torch.zeros((n, 512), dtype=torch.float32).cuda()
+     for idx, img_path in enumerate(img_list):
+         img_pil = Image.open(os.path.join(folder, img_path)).convert('RGB')
+         img_tensor = trans(img_pil).unsqueeze(0).cuda()
+
+         # img_tensor = kornia.geometry.transform.warp_affine(img_tensor, trans_matrix, (256, 256))
+         save_tensor_to_img(img_tensor / 2 + 0.5, path=os.path.join('./crop', img_path))
+         img_tensor = F.interpolate(img_tensor, size=112, mode="bilinear", align_corners=True)  # to 112
+
+         feat = fid_model(img_tensor)
+         feats[idx] = feat
+
+     target_feat = feats[target_idx].unsqueeze(0)
+     cosine_sim = F.cosine_similarity(target_feat, feats, 1)
+     print(cosine_sim.shape)
+
+     print('====== similarity with %s ======' % img_list[target_idx])
+     for idx in range(n):
+         print('[%d] %s = %.2f' % (idx, img_list[idx], float(cosine_sim[idx].cpu())))
+
+
+ if __name__ == '__main__':
+     import argparse
+
+     parser = argparse.ArgumentParser(description="arcface")
+     parser.add_argument("-i", "--target_idx", type=int, default=0)
+     args = parser.parse_args()
+
+     identification(target_idx=args.target_idx)
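A note on fc_scale = 7 * 7: the stem convolution keeps stride 1 and each of the four stages halves the resolution, so the usual 112x112 ArcFace crop reaches layer4 at 112 / 2^4 = 7x7; other input sizes need a matching fc_scale. A quick check (illustrative):

    net = iresnet100(pretrained=False).eval()
    out = net(torch.randn(2, 3, 112, 112))
    print(out.shape)  # torch.Size([2, 512])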
third_party/arcface/load_dataset.py ADDED
@@ -0,0 +1,202 @@
+ import os
+ import numbers
+
+ import torch
+ import mxnet as mx
+ from PIL import Image
+ from torch.utils import data
+ from torchvision import transforms
+
+ import numpy as np
+
+
+ """ Original mxnet dataset
+ """
+ class MXFaceDataset(data.Dataset):
+     def __init__(self, root_dir, crop_param=(0, 0, 112, 112)):
+         super(MXFaceDataset, self).__init__()
+         self.transform = transforms.Compose([
+             # transforms.ToPILImage(),
+             transforms.RandomHorizontalFlip(),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+         ])
+         self.root_dir = root_dir
+         self.crop_param = crop_param
+         path_imgrec = os.path.join(root_dir, 'train.rec')
+         path_imgidx = os.path.join(root_dir, 'train.idx')
+         self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
+         s = self.imgrec.read_idx(0)
+         header, _ = mx.recordio.unpack(s)
+         if header.flag > 0:
+             self.header0 = (int(header.label[0]), int(header.label[1]))
+             self.imgidx = np.array(range(1, int(header.label[0])))
+         else:
+             self.imgidx = np.array(list(self.imgrec.keys))
+
+     def __getitem__(self, index):
+         idx = self.imgidx[index]
+         s = self.imgrec.read_idx(idx)
+         header, img = mx.recordio.unpack(s)
+         label = header.label
+         if not isinstance(label, numbers.Number):
+             label = label[0]
+         label = torch.tensor(label, dtype=torch.long)
+         sample = mx.image.imdecode(img).asnumpy()
+         if self.transform is not None:
+             sample: Image.Image = transforms.ToPILImage()(sample)
+             sample = sample.crop(self.crop_param)
+             sample = self.transform(sample)
+         return sample, label
+
+     def __len__(self):
+         return len(self.imgidx)
+
+
+ """ MXNet binary dataset reader.
+ Refer to https://github.com/deepinsight/insightface.
+ """
+ import pickle
+ from typing import List
+ from mxnet import ndarray as nd
+ class ReadMXNet(object):
+     def __init__(self, val_targets, rec_prefix, image_size=(112, 112)):
+         self.ver_list: List[object] = []
+         self.ver_name_list: List[str] = []
+         self.rec_prefix = rec_prefix
+         self.val_targets = val_targets
+
+     def init_dataset(self, val_targets, data_dir, image_size):
+         for name in val_targets:
+             path = os.path.join(data_dir, name + ".bin")
+             if os.path.exists(path):
+                 data_set = self.load_bin(path, image_size)
+                 self.ver_list.append(data_set)
+                 self.ver_name_list.append(name)
+
+     def load_bin(self, path, image_size):
+         try:
+             with open(path, 'rb') as f:
+                 bins, issame_list = pickle.load(f)  # py2
+         except UnicodeDecodeError as e:
+             with open(path, 'rb') as f:
+                 bins, issame_list = pickle.load(f, encoding='bytes')  # py3
+         data_list = []
+         # for flip in [0, 1]:
+         #     data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
+         #     data_list.append(data)
+         for idx in range(len(issame_list) * 2):
+             _bin = bins[idx]
+             img = mx.image.imdecode(_bin)
+             if img.shape[1] != image_size[0]:
+                 img = mx.image.resize_short(img, image_size[0])
+             img = nd.transpose(img, axes=(2, 0, 1))  # (C, H, W)
+
+             img = nd.transpose(img, axes=(1, 2, 0))  # back to (H, W, C)
+             fig = Image.fromarray(img.asnumpy(), mode='RGB')
+             data_list.append(fig)
+             # data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
+             if idx % 1000 == 0:
+                 print('loading bin', idx)
+
+             # # save img to '/home/yuange/dataset/LFW/rgb-arcface'
+             # save_name = 'ind_' + str(idx) + '.bmp'
+             # save_name = os.path.join('/home/yuange/dataset/LFW/rgb-arcface', save_name)
+             # fig.save(save_name)
+
+         print('load finished', len(data_list))
+         return data_list, issame_list
+
+
+ """
+ Evaluation Benchmark
+ """
+ class EvalDataset(data.Dataset):
+     def __init__(self,
+                  target: str = 'lfw',
+                  rec_folder: str = '',
+                  transform=None,
+                  crop_param=(0, 0, 112, 112)
+                  ):
+         print("=> Pre-loading images ...")
+         self.target = target
+         self.rec_folder = rec_folder
+         mx_reader = ReadMXNet(target, rec_folder)
+         path = os.path.join(rec_folder, target + ".bin")
+         all_img, issame_list = mx_reader.load_bin(path, (112, 112))
+         self.all_img = all_img
+         self.issame_list = []
+         for i in range(len(issame_list)):
+             flag = 0 if issame_list[i] else 1  # 0: is same
+             self.issame_list.append(flag)
+
+         self.transform = transform
+         if self.transform is None:
+             self.transform = transforms.Compose([
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+             ])
+         self.crop_param = crop_param
+
+     def __getitem__(self, index):
+         img1 = self.all_img[index * 2]
+         img2 = self.all_img[index * 2 + 1]
+         same = self.issame_list[index]
+
+         save_index = 11
+         if index == save_index:
+             img1.save('img1_ori.jpg')
+             img2.save('img2_ori.jpg')
+
+         img1 = img1.crop(self.crop_param)
+         img2 = img2.crop(self.crop_param)
+         if index == save_index:
+             img1.save('img1_crop.jpg')
+             img2.save('img2_crop.jpg')
+
+         img1 = self.transform(img1)
+         img2 = self.transform(img2)
+
+         return img1, img2, same
+
+     def __len__(self):
+         return len(self.issame_list)
+
+
+ if __name__ == '__main__':
+
+     import time
+
+     np.random.seed(1)
+     torch.manual_seed(1)
+     torch.cuda.manual_seed(1)
+     torch.cuda.manual_seed_all(1)
+     mx.random.seed(1)
+
+     is_gray = False
+
+     # NOTE: FaceByRandOccMask is not defined in this file; it comes from elsewhere in the repo.
+     train_set = FaceByRandOccMask(
+         root_dir='/tmp/train_tmp/casia',
+         local_rank=0,
+         use_norm=True,
+         is_gray=is_gray,
+     )
+     start = time.time()
+     for idx in range(100):
+         face, mask, label = train_set.__getitem__(idx)
+         if idx < 15:
+             face = ((face + 1) * 128).numpy().astype(np.uint8)
+             face = np.transpose(face, (1, 2, 0))
+             if is_gray:
+                 face = Image.fromarray(face[:, :, 0], mode='L')
+             else:
+                 face = Image.fromarray(face, mode='RGB')
+             face.save('face_{}.jpg'.format(idx))
+     print('time cost: %d ms' % (int((time.time() - start) * 1000)))
+ print('time cost: %d ms' % (int((time.time() - start) * 1000)))
third_party/arcface/margin_loss.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from torch.nn import Parameter
+
+import numpy as np
+
+__all__ = ['Softmax', 'AMCosFace', 'AMArcFace', ]
+
+
+MIN_NUM_PATCHES = 16
+
+
+""" All losses can run in 'torch.distributed.DistributedDataParallel'.
+"""
+
+class Softmax(nn.Module):
+    r"""Implementation of Softmax (normal classification head):
+    Args:
+        in_features: dimension (d_in) of input feature (B, d_in)
+        out_features: dimension (d_out) of output feature (B, d_out)
+        device_id: the ID of GPU where the model will be trained by data parallel (or DP). (not used)
+            if device_id=None, it will be trained on model parallel (or DDP). (recommended!)
+    """
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 device_id,
+                 ):
+        super(Softmax, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.device_id = device_id
+
+        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
+        self.bias = Parameter(torch.FloatTensor(out_features))
+        nn.init.xavier_uniform_(self.weight)
+        nn.init.zeros_(self.bias)
+
+    def forward(self, embedding, label):
+        """
+        :param embedding: learned face representation
+        :param label:
+            - label >= 0: ground truth identity
+            - label = -1: invalid identity for this GPU (refer to 'PartialFC')
+            + Example: label = torch.tensor([-1, 4, -1, 5, 3, -1])
+        :return:
+        """
+        if self.device_id is None:
+            """ Regular linear layer.
+            """
+            out = F.linear(embedding, self.weight, self.bias)
+        else:
+            raise ValueError('DataParallel is not implemented yet.')
+            # NOTE: everything below is unreachable legacy DataParallel code.
+            x = embedding
+            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
+            sub_biases = torch.chunk(self.bias, len(self.device_id), dim=0)
+            temp_x = x.cuda(self.device_id[0])
+            weight = sub_weights[0].cuda(self.device_id[0])
+            bias = sub_biases[0].cuda(self.device_id[0])
+            out = F.linear(temp_x, weight, bias)
+            for i in range(1, len(self.device_id)):
+                temp_x = x.cuda(self.device_id[i])
+                weight = sub_weights[i].cuda(self.device_id[i])
+                bias = sub_biases[i].cuda(self.device_id[i])
+                out = torch.cat((out, F.linear(temp_x, weight, bias).cuda(self.device_id[0])), dim=1)
+        return out
+
+
+""" Not Used """
+class ArcFace(nn.Module):
+    r"""Implementation of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
+    Args:
+        in_features: size of each input sample
+        out_features: size of each output sample
+        device_id: the ID of GPU where the model will be trained by model parallel.
+            if device_id=None, it will be trained on CPU without model parallel.
+        s: norm of input feature
+        m: margin
+        cos(theta + m)
+    """
+
+    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.50, easy_margin=False):
+        super(ArcFace, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.device_id = device_id
+
+        self.s = s
+        self.m = m
+        print('ArcFace, s=%.1f, m=%.2f' % (s, m))
+
+        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
+        nn.init.xavier_uniform_(self.weight)
+
+        self.easy_margin = easy_margin
+        self.cos_m = np.cos(m)
+        self.sin_m = np.sin(m)
+        self.th = np.cos(np.pi - m)
+        self.mm = np.sin(np.pi - m) * m
+
+    def forward(self, input, label):
+        # --------------------------- cos(theta) & phi(theta) ---------------------------
+        if self.device_id is None:
+            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
+        else:
+            x = input
+            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
+            temp_x = x.cuda(self.device_id[0])
+            weight = sub_weights[0].cuda(self.device_id[0])
+            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
+            for i in range(1, len(self.device_id)):
+                temp_x = x.cuda(self.device_id[i])
+                weight = sub_weights[i].cuda(self.device_id[i])
+                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])),
+                                   dim=1)
+        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
+        phi = cosine * self.cos_m - sine * self.sin_m
+        if self.easy_margin:
+            phi = torch.where(cosine > 0, phi, cosine)
+        else:
+            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
+        # --------------------------- convert label to one-hot ---------------------------
+        one_hot = torch.zeros(cosine.size())
+        if self.device_id is not None:
+            one_hot = one_hot.cuda(self.device_id[0])
+        else:
+            one_hot = one_hot.cuda()
+        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
+        # ------------- torch.where(out_i = {x_i if condition_i else y_i}) -------------
+        output = (one_hot * phi) + (
+                (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
+        output *= self.s
+
+        return output
+
+
+""" Not Used """
+class CosFace(nn.Module):
+    r"""Implementation of CosFace (https://arxiv.org/pdf/1801.09414.pdf):
+    Args:
+        in_features: size of each input sample
+        out_features: size of each output sample
+        device_id: the ID of GPU where the model will be trained by model parallel.
+            if device_id=None, it will be trained on CPU without model parallel.
+        s: norm of input feature
+        m: margin
+        cos(theta) - m
+    """
+
+    def __init__(self, in_features, out_features, device_id, s=64.0, m=0.4):
+        super(CosFace, self).__init__()
+        print('CosFace, s=%.1f, m=%.2f' % (s, m))
+        self.in_features = in_features
+        self.out_features = out_features
+        self.device_id = device_id
+        self.s = s
+        self.m = m
+        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
+        nn.init.xavier_uniform_(self.weight)
+
+    def forward(self, input, label):
+        # --------------------------- cos(theta) & phi(theta) ---------------------------
+        if self.device_id is None:
+            cosine = F.linear(F.normalize(input), F.normalize(self.weight))
+        else:
+            x = input
+            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
+            temp_x = x.cuda(self.device_id[0])
+            weight = sub_weights[0].cuda(self.device_id[0])
+            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
+            for i in range(1, len(self.device_id)):
+                temp_x = x.cuda(self.device_id[i])
+                weight = sub_weights[i].cuda(self.device_id[i])
+                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x), F.normalize(weight)).cuda(self.device_id[0])),
+                                   dim=1)
+        phi = cosine - self.m
+        # --------------------------- convert label to one-hot ---------------------------
+        one_hot = torch.zeros(cosine.size()).cuda()
+        if self.device_id is not None:
+            one_hot = one_hot.cuda(self.device_id[0])
+            # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot
+            one_hot.scatter_(1, label.cuda(self.device_id[0]).view(-1, 1).long(), 1)
+        else:
+            one_hot.scatter_(1, label.view(-1, 1).long(), 1)
+        # ------------- torch.where(out_i = {x_i if condition_i else y_i}) -------------
+        output = (one_hot * phi) + (
+                (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
+        output *= self.s
+
+        return output
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' \
+               + 'in_features = ' + str(self.in_features) \
+               + ', out_features = ' + str(self.out_features) \
+               + ', s = ' + str(self.s) \
+               + ', m = ' + str(self.m) + ')'
+
+
+class AMCosFace(nn.Module):
+    r"""Implementation of Adaptive Margin CosFace:
+        cos(theta) - m + k * (theta - a)
+    When k is 0, AMCosFace degenerates into CosFace.
+    Args:
+        in_features: dimension (d_in) of input feature (B, d_in)
+        out_features: dimension (d_out) of output feature (B, d_out)
+        device_id: the ID of GPU where the model will be trained by data parallel (or DP). (not used)
+            if device_id=None, it will be trained on model parallel (or DDP). (recommended!)
+        s: norm of input feature
+        m: margin
+        a: AM Loss
+        k: AM Loss
+    """
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 device_id,
+                 s: float = 64.0,
+                 m: float = 0.4,
+                 a: float = 1.2,
+                 k: float = 0.1,
+                 ):
+        super(AMCosFace, self).__init__()
+        print('AMCosFace, s=%.1f, m=%.2f, a=%.2f, k=%.2f' % (s, m, a, k))
+        self.in_features = in_features
+        self.out_features = out_features
+        self.device_id = device_id
+
+        self.s = s
+        self.m = m
+        self.a = a
+        self.k = k
+
+        """ Weight Matrix W (d_out, d_in) """
+        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
+        nn.init.xavier_uniform_(self.weight)
+
+    def forward(self, embedding, label):
+        """
+        :param embedding: learned face representation
+        :param label:
+            - label >= 0: ground truth identity
+            - label = -1: invalid identity for this GPU (refer to 'PartialFC')
+            + Example: label = torch.tensor([-1, 4, -1, 5, 3, -1])
+        :return:
+        """
+        if self.device_id is None:
+            """ - embedding: shape is (B, d_in)
+                - weight: shape is (d_out, d_in)
+                - cosine: shape is (B, d_out)
+                + F.normalize is very important here.
+            """
+            cosine = F.linear(F.normalize(embedding), F.normalize(self.weight))  # y = xA^T + b
+        else:
+            raise ValueError('DataParallel is not implemented yet.')
+            # NOTE: everything below is unreachable legacy DataParallel code.
+            x = embedding
+            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
+            temp_x = x.cuda(self.device_id[0])
+            weight = sub_weights[0].cuda(self.device_id[0])
+            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
+            for i in range(1, len(self.device_id)):
+                temp_x = x.cuda(self.device_id[i])
+                weight = sub_weights[i].cuda(self.device_id[i])
+                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x),
+                                                     F.normalize(weight)).cuda(self.device_id[0])),
+                                   dim=1)
+
+        """ - index: the index of valid identity in label, shape is (d_valid, )
+            + torch.where() returns a tuple indicating the index of each dimension
+            + Example: index = torch.tensor([1, 3, 4])
+        """
+        index = torch.where(label != -1)[0]
+
+        """ - m_hot: one-hot tensor of margin m_2, shape is (d_valid, d_out)
+            + torch.Tensor.scatter_(dim, index, source) is usually used to generate a one-hot tensor
+            + Example: label = torch.tensor([-1, 4, -1, 5, 3, -1])
+                       index = torch.tensor([1, 3, 4])  # d_valid = index.shape[0] = 3
+                       m_hot = torch.tensor([[0, 0, 0, 0, m, 0],
+                                             [0, 0, 0, 0, 0, m],
+                                             [0, 0, 0, m, 0, 0],
+                                             ])
+        """
+        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
+        m_hot.scatter_(1, label[index, None], self.m)
+
+        """ logit(theta) = cos(theta) - m_2 + k * (theta - a)
+            - theta = cosine.acos_()
+            + Example: m_hot = torch.tensor([[0, 0, 0, 0, m-k(theta[0,4]-a), 0],
+                                             [0, 0, 0, 0, 0, m-k(theta[1,5]-a)],
+                                             [0, 0, 0, m-k(theta[2,3]-a), 0, 0],
+                                             ])
+        """
+        a = self.a
+        k = self.k
+        m_hot[range(0, index.size()[0]), label[index]] -= k * (cosine[index, label[index]].acos_() - a)
+        cosine[index] -= m_hot
+
+        """ Because we have used F.normalize, we should rescale the logit term by s.
+        """
+        output = cosine * self.s
+
+        return output
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' \
+               + 'in_features = ' + str(self.in_features) \
+               + ', out_features = ' + str(self.out_features) \
+               + ', s = ' + str(self.s) \
+               + ', m = ' + str(self.m) \
+               + ', a = ' + str(self.a) \
+               + ', k = ' + str(self.k) \
+               + ')'
+
+
+class AMArcFace(nn.Module):
+    r"""Implementation of Adaptive Margin ArcFace:
+        cos(theta + m - k * (theta - a))
+    When k is 0, AMArcFace degenerates into ArcFace.
+    Args:
+        in_features: dimension (d_in) of input feature (B, d_in)
+        out_features: dimension (d_out) of output feature (B, d_out)
+        device_id: the ID of GPU where the model will be trained by data parallel (or DP). (not used)
+            if device_id=None, it will be trained on model parallel (or DDP). (recommended!)
+        s: norm of input feature
+        m: margin
+        a: AM Loss
+        k: AM Loss
+    """
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 device_id,
+                 s: float = 64.0,
+                 m: float = 0.5,
+                 a: float = 1.2,
+                 k: float = 0.1,
+                 ):
+        super(AMArcFace, self).__init__()
+        print('AMArcFace, s=%.1f, m=%.2f, a=%.2f, k=%.2f' % (s, m, a, k))
+        self.in_features = in_features
+        self.out_features = out_features
+        self.device_id = device_id
+
+        self.s = s
+        self.m = m
+        self.a = a
+        self.k = k
+
+        """ Weight Matrix W (d_out, d_in) """
+        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
+        nn.init.xavier_uniform_(self.weight)
+
+    def forward(self, embedding, label):
+        """
+        :param embedding: learned face representation
+        :param label:
+            - label >= 0: ground truth identity
+            - label = -1: invalid identity for this GPU (refer to 'PartialFC')
+            + Example: label = torch.tensor([-1, 4, -1, 5, 3, -1])
+        :return:
+        """
+        if self.device_id is None:
+            """ - embedding: shape is (B, d_in)
+                - weight: shape is (d_out, d_in)
+                - cosine: shape is (B, d_out)
+                + F.normalize is very important here.
+            """
+            cosine = F.linear(F.normalize(embedding), F.normalize(self.weight))  # y = xA^T + b
+        else:
+            raise ValueError('DataParallel is not implemented yet.')
+            # NOTE: everything below is unreachable legacy DataParallel code.
+            x = embedding
+            sub_weights = torch.chunk(self.weight, len(self.device_id), dim=0)
+            temp_x = x.cuda(self.device_id[0])
+            weight = sub_weights[0].cuda(self.device_id[0])
+            cosine = F.linear(F.normalize(temp_x), F.normalize(weight))
+            for i in range(1, len(self.device_id)):
+                temp_x = x.cuda(self.device_id[i])
+                weight = sub_weights[i].cuda(self.device_id[i])
+                cosine = torch.cat((cosine, F.linear(F.normalize(temp_x),
+                                                     F.normalize(weight)).cuda(self.device_id[0])),
+                                   dim=1)
+
+        """ - index: the index of valid identity in label, shape is (d_valid, )
+            + torch.where() returns a tuple indicating the index of each dimension
+            + Example: index = torch.tensor([1, 3, 4])
+        """
+        index = torch.where(label != -1)[0]
+
+        """ - m_hot: one-hot tensor of margin m_2, shape is (d_valid, d_out)
+            + torch.Tensor.scatter_(dim, index, source) is usually used to generate a one-hot tensor
+            + Example: label = torch.tensor([-1, 4, -1, 5, 3, -1])
+                       index = torch.tensor([1, 3, 4])  # d_valid = index.shape[0] = 3
+                       m_hot = torch.tensor([[0, 0, 0, 0, m, 0],
+                                             [0, 0, 0, 0, 0, m],
+                                             [0, 0, 0, m, 0, 0],
+                                             ])
+        """
+        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
+        m_hot.scatter_(1, label[index, None], self.m)
+
+        """ logit(theta) = cos(theta + m_2 - k * (theta - a))
+            - theta = cosine.acos_()
+            + Example: m_hot = torch.tensor([[0, 0, 0, 0, m-k(theta[0,4]-a), 0],
+                                             [0, 0, 0, 0, 0, m-k(theta[1,5]-a)],
+                                             [0, 0, 0, m-k(theta[2,3]-a), 0, 0],
+                                             ])
+        """
+        a = self.a
+        k = self.k
+        m_hot[range(0, index.size()[0]), label[index]] -= k * (cosine[index, label[index]].acos_() - a)
+
+        cosine.acos_()
+        cosine[index] += m_hot
+        cosine.cos_().mul_(self.s)
+        return cosine
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' \
+               + 'in_features = ' + str(self.in_features) \
+               + ', out_features = ' + str(self.out_features) \
+               + ', s = ' + str(self.s) \
+               + ', m = ' + str(self.m) \
+               + ', a = ' + str(self.a) \
+               + ', k = ' + str(self.k) \
+               + ')'
+
+
+if __name__ == '__main__':
+    cosine = torch.randn(6, 8) / 100
+    cosine[0][2] = 0.3
+    cosine[1][4] = 0.4
+    cosine[2][6] = 0.5
+    cosine[3][5] = 0.6
+    cosine[4][3] = 0.7
+    cosine[5][0] = 0.8
+    label = torch.tensor([-1, 4, -1, 5, 3, -1])
+
+    # layer = AMCosFace(in_features=8,
+    #                   out_features=8,
+    #                   device_id=None,
+    #                   m=0.35, s=1.0,
+    #                   a=1.2, k=0.1)
+
+    # layer = Softmax(in_features=8,
+    #                 out_features=8,
+    #                 device_id=None)
+
+    layer = AMArcFace(in_features=8,
+                      out_features=8,
+                      device_id=None,
+                      m=0.5, s=1.0,
+                      a=1.2, k=0.1)
+
+    logit = layer(cosine, label)
+    logit = F.softmax(logit, dim=-1)
+
+    from utils.vis_tensor import plot_tensor
+    plot_tensor((cosine, logit),
+                ('embedding', 'logit'),
+                'AMArc.jpg')
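To make the two adaptive margins concrete, here is a scalar sanity check of the target-logit formulas implemented above (a sketch using only the standard library; the numbers are made up):

import math

# AMCosFace target logit: cos(theta) - m + k * (theta - a).
# With k = 0 this reduces to CosFace's cos(theta) - m.
m, a, k = 0.4, 1.2, 0.1
cos_theta = 0.7
theta = math.acos(cos_theta)              # ~0.795 rad

logit_cos = cos_theta - m + k * (theta - a)   # ~0.7 - 0.4 + 0.1 * (-0.405) ~ 0.259

# AMArcFace instead moves the margin inside the angle:
# cos(theta + m - k * (theta - a)), reducing to ArcFace when k = 0.
logit_arc = math.cos(theta + 0.5 - k * (theta - a))

Since theta < a here, the adaptive term k * (theta - a) is negative, so easy samples (small theta) receive a slightly larger effective margin than plain CosFace/ArcFace would apply.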
third_party/arcface/mouth_net.py ADDED
@@ -0,0 +1,117 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+from third_party.arcface.iresnet import iresnet50, iresnet100
+
+
+class MouthNet(nn.Module):
+    def __init__(self,
+                 bisenet: nn.Module,
+                 feature_dim: int = 64,
+                 crop_param: tuple = (0, 0, 112, 112),
+                 iresnet_pretrained: bool = False,
+                 ):
+        super(MouthNet, self).__init__()
+
+        crop_size = (crop_param[3] - crop_param[1], crop_param[2] - crop_param[0])  # (H, W)
+        fc_scale = int(math.ceil(crop_size[0] / 112 * 7) * math.ceil(crop_size[1] / 112 * 7))
+
+        self.bisenet = bisenet
+        self.backbone = iresnet50(
+            pretrained=iresnet_pretrained,
+            num_features=feature_dim,
+            fp16=False,
+            fc_scale=fc_scale,
+        )
+
+        self.register_buffer(
+            name="vgg_mean",
+            tensor=torch.tensor([[[0.485]], [[0.456]], [[0.406]]], requires_grad=False),
+        )
+        self.register_buffer(
+            name="vgg_std",
+            tensor=torch.tensor([[[0.229]], [[0.224]], [[0.225]]], requires_grad=False),
+        )
+
+    def forward(self, x):
+        # with torch.no_grad():
+        #     x_mouth_mask = self.get_any_mask(x, par=[11, 12, 13], normalized=True)  # (B,1,H,W), in [0,1], 1: chosen
+        x_mouth_mask = 1  # parsing-based masking is disabled; the input crop alone selects the mouth
+        x_mouth = x * x_mouth_mask  # (B,3,112,112)
+        mouth_feature = self.backbone(x_mouth)
+        return mouth_feature
+
+    def get_any_mask(self, img, par, normalized=False):
+        # [0 'background', 1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye',
+        #  6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 11 'mouth', 12 'u_lip',
+        #  13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
+        ori_size = img.size()[-1]
+        with torch.no_grad():
+            img = F.interpolate(img, size=512, mode="nearest")
+            if not normalized:
+                img = img * 0.5 + 0.5
+                img = img.sub(self.vgg_mean.detach()).div(self.vgg_std.detach())
+            out = self.bisenet(img)[0]
+            parsing = out.softmax(1).argmax(1)
+        mask = torch.zeros_like(parsing)
+        for p in par:
+            mask = mask + ((parsing == p).float())
+        mask = mask.unsqueeze(1)
+        mask = F.interpolate(mask, size=ori_size, mode="bilinear", align_corners=True)
+        return mask
+
+    def save_backbone(self, path: str):
+        torch.save(self.backbone.state_dict(), path)
+
+    def load_backbone(self, path: str):
+        self.backbone.load_state_dict(torch.load(path))
+
+
+if __name__ == "__main__":
+    from third_party.bisenet.bisenet import BiSeNet
+
+    bisenet = BiSeNet(19)
+    bisenet.load_state_dict(
+        torch.load(
+            "/gavin/datasets/hanbang/79999_iter.pth",
+            map_location="cpu",
+        )
+    )
+    bisenet.eval()
+    bisenet.requires_grad_(False)
+
+    crop_param = (28, 56, 84, 112)
+
+    import numpy as np
+    img = np.random.randn(112, 112, 3) * 255
+    from PIL import Image
+    img = Image.fromarray(img.astype(np.uint8))
+    img = img.crop(crop_param)
+
+    from torchvision import transforms
+    trans = transforms.ToTensor()
+    img = trans(img).unsqueeze(0)
+    img = img.repeat(3, 1, 1, 1)
+    print(img.shape)
+
+    net = MouthNet(
+        bisenet=bisenet,
+        feature_dim=64,
+        crop_param=crop_param
+    )
+    mouth_feat = net(img)
+    print(mouth_feat.shape)
+
+    import thop
+
+    crop_size = (crop_param[3] - crop_param[1], crop_param[2] - crop_param[0])  # (H, W)
+    fc_scale = int(math.ceil(crop_size[0] / 112 * 7) * math.ceil(crop_size[1] / 112 * 7))
+    backbone = iresnet100(
+        pretrained=False,
+        num_features=64,
+        fp16=False,
+        # fc_scale=fc_scale,
+    )
+    flops, params = thop.profile(backbone, inputs=(torch.randn(1, 3, 112, 112),), verbose=False)
+    print('#Params=%.2fM, GFLOPS=%.2f' % (params / 1e6, flops / 1e9))
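A quick check of the `fc_scale` arithmetic in `MouthNet.__init__` (the helper name `fc_scale_for` is ours, for illustration only): the iresnet backbone reduces a 112x112 input to a 7x7 final feature map, so a (w1, h1, w2, h2) crop shrinks the flattened FC input proportionally in each spatial dimension.

import math

def fc_scale_for(crop_param):
    # Mirrors the two lines in MouthNet.__init__: each spatial dim of the
    # 7x7 final feature map scales by crop_size / 112, rounded up.
    h = crop_param[3] - crop_param[1]
    w = crop_param[2] - crop_param[0]
    return int(math.ceil(h / 112 * 7) * math.ceil(w / 112 * 7))

print(fc_scale_for((0, 0, 112, 112)))   # 49 -> full face, 7 * 7
print(fc_scale_for((28, 56, 84, 112)))  # 16 -> 56x56 mouth crop, 4 * 4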
third_party/arcface/mouth_net_eval.py ADDED
@@ -0,0 +1,69 @@
+import argparse
+import pytorch_lightning as pl
+import numpy as np
+import torch
+
+from third_party.arcface.mouth_net_pl import MouthNetPL
+from third_party.arcface.mouth_net import MouthNet
+
+
+class MouthTest(object):
+    def __init__(self):
+        self.dataset_len = 400
+
+        self.fixer_crop_param = (28, 56, 84, 112)
+        self.fixer_casia_model = MouthNet(
+            bisenet=None,
+            feature_dim=128,
+            crop_param=self.fixer_crop_param
+        ).cuda()
+        fixer_path = "/gavin/code/FaceSwapping/modules/third_party/arcface/weights/fixer_net_casia_28_56_84_112.pth"
+        self.fixer_casia_model.load_backbone(fixer_path)
+        self.fixer_casia_model.eval()
+        self.fixer_t = np.zeros((self.dataset_len, 128), dtype=np.float32)
+        self.fixer_s = np.zeros_like(self.fixer_t, dtype=np.float32)  # each embedding repeats 10 times in ffplus
+        self.fixer_r = np.zeros_like(self.fixer_t, dtype=np.float32)
+        print('Fixer model loaded.')
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args()
+    args.val_targets = []
+    args.rec_folder = "/gavin/datasets/msml/ms1m-retinaface"
+
+    fixer_net = MouthNetPL.load_from_checkpoint(
+        "/apdcephfs/share_1290939/gavinyuan/out/fixernet_casia/epoch=22-step=10999-v1.ckpt",
+        map_location='cpu', strict=False,
+        num_classes=10572,
+        batch_size=128,
+        dim_feature=128,
+        rec_folder=args.rec_folder,
+        header_type="AMCosFace",
+        crop=(28, 56, 84, 112),
+    )
+
+    lower_net_1 = MouthNetPL.load_from_checkpoint(
+        "/apdcephfs/share_1290939/gavinyuan/out/mouth_net_1/epoch=24-step=242999.ckpt",
+        map_location='cpu', strict=False,
+        num_classes=93431,
+        batch_size=128,
+        dim_feature=128,
+        rec_folder=args.rec_folder,
+        header_type="AMArcFace",
+        crop=(28, 56, 84, 112),
+    )
+
+    # test_net = fixer_net
+    test_net = lower_net_1
+    trainer = pl.Trainer(
+        logger=False,
+        gpus=1,
+        distributed_backend='dp',
+        benchmark=True,
+    )
+    trainer.test(test_net)
+
+    # print('Fixer model loading...')
+    # m_test = MouthTest()
third_party/arcface/mouth_net_pl.py ADDED
@@ -0,0 +1,358 @@
+import os.path
+
+import torch
+import torchvision
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import pytorch_lightning as pl
+
+import numpy as np
+import sklearn
+from sklearn.metrics import roc_curve, auc
+from scipy.spatial.distance import cdist
+
+from third_party.arcface.mouth_net import MouthNet
+from third_party.arcface.margin_loss import Softmax, AMArcFace, AMCosFace
+from third_party.arcface.load_dataset import MXFaceDataset, EvalDataset
+from third_party.bisenet.bisenet import BiSeNet
+
+
+class MouthNetPL(pl.LightningModule):
+    def __init__(
+        self,
+        num_classes: int,
+        batch_size: int = 256,
+        dim_feature: int = 128,
+        header_type: str = 'AMArcFace',
+        header_params: tuple = (64.0, 0.5, 0.0, 0.0),  # (s, m, a, k)
+        rec_folder: str = "/gavin/datasets/msml/ms1m-retinaface",
+        learning_rate: float = 0.1,
+        crop: tuple = (0, 0, 112, 112),  # (w1, h1, w2, h2)
+    ):
+        super(MouthNetPL, self).__init__()
+
+        # self.img_size = (112, 112)
+
+        ''' mouth feature extractor '''
+        # NOTE: the parsing network is loaded but not handed to MouthNet
+        # (bisenet=None below); masking is disabled in MouthNet.forward.
+        bisenet = BiSeNet(19)
+        bisenet.load_state_dict(
+            torch.load(
+                "/gavin/datasets/hanbang/79999_iter.pth",
+                map_location="cpu",
+            )
+        )
+        bisenet.eval()
+        bisenet.requires_grad_(False)
+        self.mouth_net = MouthNet(
+            bisenet=None,
+            feature_dim=dim_feature,
+            crop_param=crop,
+            iresnet_pretrained=False,
+        )
+
+        ''' head & loss '''
+        self.automatic_optimization = False
+        self.dim_feature = dim_feature
+        self.num_classes = num_classes
+        self._prepare_header(header_type, header_params)
+        self.cls_criterion = torch.nn.CrossEntropyLoss()
+        self.learning_rate = learning_rate
+
+        ''' dataset '''
+        assert os.path.exists(rec_folder)
+        self.rec_folder = rec_folder
+        self.batch_size = batch_size
+        self.crop_param = crop
+
+        ''' validation '''
+
+    def _prepare_header(self, head_type, header_params):
+        dim_in = self.dim_feature
+        dim_out = self.num_classes
+
+        """ Get hyper-params of header """
+        s, m, a, k = header_params
+
+        """ Choose the header """
+        if 'Softmax' in head_type:
+            self.classification = Softmax(dim_in, dim_out, device_id=None)
+        elif 'AMCosFace' in head_type:
+            self.classification = AMCosFace(dim_in, dim_out,
+                                            device_id=None,
+                                            s=s, m=m,
+                                            a=a, k=k,
+                                            )
+        elif 'AMArcFace' in head_type:
+            self.classification = AMArcFace(dim_in, dim_out,
+                                            device_id=None,
+                                            s=s, m=m,
+                                            a=a, k=k,
+                                            )
+        else:
+            raise ValueError('Header type error!')
+
+    def forward(self, x, label=None):
+        feat = self.mouth_net(x)
+        if self.training:
+            assert label is not None
+            cls = self.classification(feat, label)
+            return feat, cls
+        else:
+            return feat
+
+    def training_step(self, batch, batch_idx):
+        opt = self.optimizers(use_pl_optimizer=True)
+        img, label = batch
+
+        mouth_feat, final_cls = self(img, label)
+
+        cls_loss = self.cls_criterion(final_cls, label)
+
+        opt.zero_grad()
+        self.manual_backward(cls_loss)
+        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=5, norm_type=2)
+        opt.step()
+
+        ''' loss logging '''
+        self.logging_dict({"cls_loss": cls_loss}, prefix="train / ")
+        self.logging_lr()
+        if batch_idx % 50 == 0 and self.local_rank == 0:
+            print('loss=', cls_loss)
+
+        return cls_loss
+
+    def training_epoch_end(self, outputs):
+        sch = self.lr_schedulers()
+        sch.step()
+
+        lr = -1
+        opts = self.trainer.optimizers
+        for opt in opts:
+            for param_group in opt.param_groups:
+                lr = param_group["lr"]
+                break
+        print('learning rate changed to %.6f' % lr)
+
+    # def validation_step(self, batch, batch_idx):
+    #     return self.test_step(batch, batch_idx)
+    #
+    # def validation_step_end(self, outputs):
+    #     return self.test_step_end(outputs)
+    #
+    # def validation_epoch_end(self, outputs):
+    #     return self.test_step_end(outputs)
+
+    @staticmethod
+    def save_tensor(tensor: torch.Tensor, path: str, b_idx: int = 0):
+        tensor = (tensor + 1.) * 127.5
+        img = tensor.permute(0, 2, 3, 1)[b_idx].cpu().numpy()
+        from PIL import Image
+        img_pil = Image.fromarray(img.astype(np.uint8))
+        img_pil.save(path)
+
+    def test_step(self, batch, batch_idx):
+        img1, img2, same = batch
+        feat1 = self.mouth_net(img1)
+        feat2 = self.mouth_net(img2)
+        return feat1, feat2, same
+
+    def test_step_end(self, outputs):
+        feat1, feat2, same = outputs
+        feat1 = feat1.cpu().numpy()
+        feat2 = feat2.cpu().numpy()
+        same = same.cpu().numpy()
+
+        feat1 = sklearn.preprocessing.normalize(feat1)
+        feat2 = sklearn.preprocessing.normalize(feat2)
+
+        predict_label = []
+        num = feat1.shape[0]
+        for i in range(num):
+            dis_cos = cdist(feat1[i, None], feat2[i, None], metric='cosine')
+            predict_label.append(dis_cos[0, 0])
+        predict_label = np.array(predict_label)
+
+        return {
+            "pred": predict_label,
+            "gt": same,
+        }
+
+    def test_epoch_end(self, outputs):
+        print(outputs)
+        pred, same = None, None
+        for batch_output in outputs:
+            if pred is None and same is None:
+                pred = batch_output["pred"]
+                same = batch_output["gt"]
+            else:
+                pred = np.concatenate([pred, batch_output["pred"]])
+                same = np.concatenate([same, batch_output["gt"]])
+        print(pred.shape, same.shape)
+
+        fpr, tpr, threshold = roc_curve(same, pred)
+        acc = tpr[np.argmin(np.abs(tpr - (1 - fpr)))]  # choose the threshold where TPR = 1 - FPR
+        print("=> verification finished, acc=%.4f" % (acc))
+
+        ''' save pth '''
+        pth_path = "./weights/fixer_net_casia_%s.pth" % ('_'.join((str(x) for x in self.crop_param)))
+        self.mouth_net.save_backbone(pth_path)
+        print("=> model save to %s" % pth_path)
+        mouth_net = MouthNet(
+            bisenet=None,
+            feature_dim=self.dim_feature,
+            crop_param=self.crop_param
+        )
+        mouth_net.load_backbone(pth_path)
+        print("=> MouthNet pth checked")
+
+        return acc
+
+    def logging_dict(self, log_dict, prefix=None):
+        for key, val in log_dict.items():
+            if prefix is not None:
+                key = prefix + key
+            self.log(key, val)
+
+    def logging_lr(self):
+        opts = self.trainer.optimizers
+        for idx, opt in enumerate(opts):
+            lr = None
+            for param_group in opt.param_groups:
+                lr = param_group["lr"]
+                break
+            self.log(f"lr_{idx}", lr)
+
+    def configure_optimizers(self):
+        params = list(self.parameters())
+        learning_rate = self.learning_rate / 512 * self.batch_size * torch.cuda.device_count()
+        optimizer = torch.optim.SGD(params, lr=learning_rate,
+                                    momentum=0.9, weight_decay=5e-4)
+        print('lr is set as %.5f due to the global batch_size %d' % (learning_rate,
+              self.batch_size * torch.cuda.device_count()))
+
+        def lr_step_func(epoch):
+            return ((epoch + 1) / (4 + 1)) ** 2 if epoch < 0 else 0.1 ** len(
+                [m for m in [11, 17, 22] if m - 1 <= epoch])  # 0.1, 0.01, 0.001, 0.0001
+        scheduler = torch.optim.lr_scheduler.LambdaLR(
+            optimizer=optimizer, lr_lambda=lr_step_func)
+
+        return [optimizer], [scheduler]
+
+    def train_dataloader(self):
+        dataset = MXFaceDataset(
+            root_dir=self.rec_folder,
+            crop_param=self.crop_param,
+        )
+        train_loader = DataLoader(
+            dataset, self.batch_size, num_workers=24, shuffle=True, drop_last=True
+        )
+        return train_loader
+
+    def val_dataloader(self):
+        return self.test_dataloader()
+
+    def test_dataloader(self):
+        dataset = EvalDataset(
+            rec_folder=self.rec_folder,
+            target='lfw',
+            crop_param=self.crop_param
+        )
+        test_loader = DataLoader(
+            dataset, 20, num_workers=12, shuffle=False, drop_last=False
+        )
+        return test_loader
+
+
+def start_train():
+    import os
+    import argparse
+    import torch
+    import pytorch_lightning as pl
+    from pytorch_lightning.callbacks import ModelCheckpoint
+    import wandb
+    from pytorch_lightning.loggers import WandbLogger
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-g",
+        "--gpus",
+        type=str,
+        default=None,
+        help="Number of gpus to use (e.g. '0,1,2,3'). Will use all if not given.",
+    )
+    parser.add_argument("-n", "--name", type=str, required=True, help="Name of the run.")
+    parser.add_argument("-pj", "--project", type=str, default="mouthnet", help="Name of the project.")
+
+    parser.add_argument("-rp", "--resume_checkpoint_path",
+                        type=str, default=None, help="path of checkpoint for resuming", )
+    parser.add_argument("-p", "--saving_folder",
+                        type=str, default="/apdcephfs/share_1290939/gavinyuan/out", help="saving folder", )
+    parser.add_argument("--wandb_resume",
+                        type=str, default=None, help="resume wandb logging from the input id", )
+
+    parser.add_argument("--header_type", type=str, default="AMArcFace", help="loss type.")
+
+    parser.add_argument("-bs", "--batch_size", type=int, default=128, help="bs.")
+    parser.add_argument("-fs", "--fast_dev_run", type=bool, default=False, help="pytorch.lightning fast_dev_run")
+    args = parser.parse_args()
+    args.val_targets = []
+    # args.rec_folder = "/gavin/datasets/msml/ms1m-retinaface"
+    # num_classes = 93431
+    args.rec_folder = "/gavin/datasets/msml/casia"
+    num_classes = 10572
+
+    save_path = os.path.join(args.saving_folder, args.name)
+    os.makedirs(save_path, exist_ok=True)
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=save_path,
+        monitor="train / cls_loss",
+        save_top_k=10,
+        verbose=True,
+        every_n_train_steps=200,
+    )
+
+    torch.cuda.empty_cache()
+    mouth_net = MouthNetPL(
+        num_classes=num_classes,
+        batch_size=args.batch_size,
+        dim_feature=128,
+        rec_folder=args.rec_folder,
+        header_type=args.header_type,
+        crop=(28, 56, 84, 112)
+    )
+
+    if args.wandb_resume is None:
+        resume = "allow"
+        wandb_id = wandb.util.generate_id()
+    else:
+        resume = True
+        wandb_id = args.wandb_resume
+    logger = WandbLogger(
+        project=args.project,
+        entity="gavinyuan",
+        name=args.name,
+        resume=resume,
+        id=wandb_id,
+    )
+
+    trainer = pl.Trainer(
+        gpus=-1 if args.gpus is None else torch.cuda.device_count(),
+        callbacks=[checkpoint_callback],
+        logger=logger,
+        weights_save_path=save_path,
+        resume_from_checkpoint=args.resume_checkpoint_path,
+        gradient_clip_val=0,
+        max_epochs=25,
+        num_sanity_val_steps=1,
+        fast_dev_run=args.fast_dev_run,
+        val_check_interval=50,
+        progress_bar_refresh_rate=1,
+        distributed_backend="ddp",
+        benchmark=True,
+    )
+    trainer.fit(mouth_net)
+
+
+if __name__ == "__main__":
+    start_train()
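A sketch of the learning-rate schedule that `configure_optimizers` builds (the helper `lr_at` is hypothetical; the 0.1 base rate, the linear batch-size scaling, and the [11, 17, 22] milestones are taken from the code above):

def lr_at(epoch, base_lr=0.1, batch_size=128, n_gpus=1):
    # Linear scaling rule: base_lr / 512 * global batch size.
    lr = base_lr / 512 * batch_size * n_gpus
    # 10x decay at each milestone (the m - 1 <= epoch test in lr_step_func).
    decay = 0.1 ** len([m for m in [11, 17, 22] if m - 1 <= epoch])
    return lr * decay

print(lr_at(0))    # 0.025
print(lr_at(10))   # 0.0025  (first decay fires at epoch 10)
print(lr_at(24))   # 2.5e-05 (all three decays applied)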
third_party/arcface/resnet.py ADDED
@@ -0,0 +1,2 @@
+import torch
+from torchvision.models import resnet50
third_party/arcface/utils_callbacks.py ADDED
@@ -0,0 +1,141 @@
+import logging
+import os
+import time
+from typing import List
+
+import torch
+
+from third_party.arcface import verification
+
+
+class AverageMeter(object):
+    """ Computes and stores the average and current value
+    """
+    def __init__(self):
+        self.val = None
+        self.avg = None
+        self.sum = None
+        self.count = None
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class CallBackVerification(object):
+    def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112),
+                 is_gray=False):
+        self.frequent: int = frequent
+        self.rank: int = rank
+        self.highest_acc: float = 0.0
+        self.highest_acc_list: List[float] = [0.0] * len(val_targets)
+        self.ver_list: List[object] = []
+        self.ver_name_list: List[str] = []
+        if self.rank == 0:
+            self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
+        self.is_gray = is_gray
+
+    def ver_test(self, backbone: torch.nn.Module, global_step: int):
+        results = []
+        for i in range(len(self.ver_list)):
+            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
+                self.ver_list[i], backbone, 10, 10,
+                is_gray=self.is_gray)
+            # logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
+            # logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
+            print('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
+            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
+            if acc2 > self.highest_acc_list[i]:
+                self.highest_acc_list[i] = acc2
+            # logging.info(
+            #     '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
+            print(
+                '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
+            results.append(acc2)
+
+    def init_dataset(self, val_targets, data_dir, image_size):
+        for name in val_targets:
+            path = os.path.join(data_dir, name + ".bin")
+            if os.path.exists(path):
+                data_set = verification.load_bin(path, image_size)
+                self.ver_list.append(data_set)
+                self.ver_name_list.append(name)
+
+    def __call__(self, num_update, backbone: torch.nn.Module):
+        if self.rank == 0 and num_update > 0 and num_update % self.frequent == 0:
+            backbone.eval()
+            self.ver_test(backbone, num_update)
+            backbone.train()
+
+
+class CallBackLogging(object):
+    def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
+        self.frequent: int = frequent
+        self.rank: int = rank
+        self.time_start = time.time()
+        self.total_step: int = total_step
+        self.batch_size: int = batch_size
+        self.world_size: int = world_size
+        self.writer = writer
+
+        self.init = False
+        self.tic = 0
+
+    def __call__(self, global_step, loss: AverageMeter, epoch: int, fp16: bool, grad_scaler: torch.cuda.amp.GradScaler):
+        if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
+            if self.init:
+                try:
+                    speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
+                    speed_total = speed * self.world_size
+                except ZeroDivisionError:
+                    speed_total = float('inf')
+
+                time_now = (time.time() - self.time_start) / 3600
+                time_total = time_now / ((global_step + 1) / self.total_step)
+                time_for_end = time_total - time_now
+                if self.writer is not None:
+                    self.writer.add_scalar('time_for_end', time_for_end, global_step)
+                    self.writer.add_scalar('loss', loss.avg, global_step)
+                if fp16:
+                    msg = "Speed %.2f samples/sec Loss %.4f Epoch: %d Global Step: %d " \
+                          "Fp16 Grad Scale: %2.f Required: %1.f hours" % (
+                              speed_total, loss.avg, epoch, global_step, grad_scaler.get_scale(), time_for_end
+                          )
+                else:
+                    msg = "Speed %.2f samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %1.f hours" % (
+                        speed_total, loss.avg, epoch, global_step, time_for_end
+                    )
+                logging.info(msg)
+                loss.reset()
+                self.tic = time.time()
+            else:
+                self.init = True
+                self.tic = time.time()
+
+
+class CallBackModelCheckpoint(object):
+    def __init__(self, rank, output="./"):
+        self.rank: int = rank
+        self.output: str = output
+
+    def __call__(self,
+                 global_step,
+                 backbone: torch.nn.Module,
+                 partial_fc=None,
+                 awloss=None, ):
+        print('CallBackModelCheckpoint...')
+        if global_step > 100 and self.rank == 0:
+            torch.save(backbone.module.state_dict(), os.path.join(self.output, "backbone.pth"))
+        if global_step > 100 and partial_fc is not None:
+            partial_fc.save_params()
+        if global_step > 100 and awloss is not None:
+            torch.save(awloss.state_dict(), os.path.join(self.output, "awloss.pth"))
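A minimal usage sketch for `AverageMeter` (the numbers are made up): `update(val, n)` accumulates a sum weighted by batch size, so `avg` is the per-sample mean rather than a mean of batch means.

meter = AverageMeter()
meter.update(0.50, n=128)   # batch 1: mean loss 0.50 over 128 samples
meter.update(0.30, n=64)    # batch 2: mean loss 0.30 over 64 samples
print(meter.val)            # 0.30 (latest value)
print(meter.avg)            # (0.5 * 128 + 0.3 * 64) / 192 ~ 0.4333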
third_party/arcface/verification.py ADDED
@@ -0,0 +1,440 @@
+"""Helper for evaluation on the Labeled Faces in the Wild dataset
+"""
+
+# MIT License
+#
+# Copyright (c) 2016 David Sandberg
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import datetime
+import os
+import pickle
+
+import mxnet as mx
+import numpy as np
+import sklearn
+import torch
+from mxnet import ndarray as nd
+from scipy import interpolate
+from sklearn.decomposition import PCA
+from sklearn.model_selection import KFold
+
+
+class LFold:
+    def __init__(self, n_splits=2, shuffle=False):
+        self.n_splits = n_splits
+        if self.n_splits > 1:
+            self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
+
+    def split(self, indices):
+        if self.n_splits > 1:
+            return self.k_fold.split(indices)
+        else:
+            return [(indices, indices)]
+
+
+def calculate_roc(thresholds,
+                  embeddings1,
+                  embeddings2,
+                  actual_issame,
+                  nrof_folds=10,
+                  pca=0):
+    assert (embeddings1.shape[0] == embeddings2.shape[0])
+    assert (embeddings1.shape[1] == embeddings2.shape[1])
+    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+    nrof_thresholds = len(thresholds)
+    k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+    tprs = np.zeros((nrof_folds, nrof_thresholds))
+    fprs = np.zeros((nrof_folds, nrof_thresholds))
+    accuracy = np.zeros((nrof_folds))
+    indices = np.arange(nrof_pairs)
+
+    if pca == 0:
+        diff = np.subtract(embeddings1, embeddings2)
+        dist = np.sum(np.square(diff), 1)
+        print('dist', dist.min(), dist.max())
+
+    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+        if pca > 0:
+            print('doing pca on', fold_idx)
+            embed1_train = embeddings1[train_set]
+            embed2_train = embeddings2[train_set]
+            _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
+            pca_model = PCA(n_components=pca)
+            pca_model.fit(_embed_train)
+            embed1 = pca_model.transform(embeddings1)
+            embed2 = pca_model.transform(embeddings2)
+            embed1 = sklearn.preprocessing.normalize(embed1)
+            embed2 = sklearn.preprocessing.normalize(embed2)
+            diff = np.subtract(embed1, embed2)
+            dist = np.sum(np.square(diff), 1)
+
+        # Find the best threshold for the fold
+        acc_train = np.zeros((nrof_thresholds))
+        for threshold_idx, threshold in enumerate(thresholds):
+            _, _, acc_train[threshold_idx] = calculate_accuracy(
+                threshold, dist[train_set], actual_issame[train_set])
+        best_threshold_index = np.argmax(acc_train)
+        for threshold_idx, threshold in enumerate(thresholds):
+            tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
+                threshold, dist[test_set],
+                actual_issame[test_set])
+        _, _, accuracy[fold_idx] = calculate_accuracy(
+            thresholds[best_threshold_index], dist[test_set],
+            actual_issame[test_set])
+
+    tpr = np.mean(tprs, 0)
+    fpr = np.mean(fprs, 0)
+    return tpr, fpr, accuracy
+
+
+def calculate_accuracy(threshold, dist, actual_issame):
+    predict_issame = np.less(dist, threshold)
+    tp = np.sum(np.logical_and(predict_issame, actual_issame))
+    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
+    tn = np.sum(
+        np.logical_and(np.logical_not(predict_issame),
+                       np.logical_not(actual_issame)))
+    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
+
+    tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
+    fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
+    acc = float(tp + tn) / dist.size
+    return tpr, fpr, acc
+
+
+def calculate_val(thresholds,
+                  embeddings1,
+                  embeddings2,
+                  actual_issame,
+                  far_target,
+                  nrof_folds=10):
+    assert (embeddings1.shape[0] == embeddings2.shape[0])
+    assert (embeddings1.shape[1] == embeddings2.shape[1])
+    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+    nrof_thresholds = len(thresholds)
+    k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+    val = np.zeros(nrof_folds)
+    far = np.zeros(nrof_folds)
+
+    diff = np.subtract(embeddings1, embeddings2)
+    dist = np.sum(np.square(diff), 1)
+    indices = np.arange(nrof_pairs)
+
+    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+
+        # Find the threshold that gives FAR = far_target
+        far_train = np.zeros(nrof_thresholds)
+        for threshold_idx, threshold in enumerate(thresholds):
+            _, far_train[threshold_idx] = calculate_val_far(
+                threshold, dist[train_set], actual_issame[train_set])
+        if np.max(far_train) >= far_target:
+            f = interpolate.interp1d(far_train, thresholds, kind='slinear')
+            threshold = f(far_target)
+        else:
+            threshold = 0.0
+
+        val[fold_idx], far[fold_idx] = calculate_val_far(
+            threshold, dist[test_set], actual_issame[test_set])
+
+    val_mean = np.mean(val)
+    far_mean = np.mean(far)
+    val_std = np.std(val)
+    return val_mean, val_std, far_mean
+
+
+def calculate_val_far(threshold, dist, actual_issame):
+    predict_issame = np.less(dist, threshold)
+    true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
+    false_accept = np.sum(
+        np.logical_and(predict_issame, np.logical_not(actual_issame)))
+    n_same = np.sum(actual_issame)
+    n_diff = np.sum(np.logical_not(actual_issame))
+    # print(true_accept, false_accept)
+    # print(actual_issame)
+    # print(n_same, n_diff)
+    val = float(true_accept) / float(n_same)
+    far = float(false_accept) / float(n_diff)
+    return val, far
+
+
+def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
+    # Calculate evaluation metrics
+    thresholds = np.arange(0, 4, 0.01)
+    embeddings1 = embeddings[0::2]
+    embeddings2 = embeddings[1::2]
+    tpr, fpr, accuracy = calculate_roc(thresholds,
+                                       embeddings1,
+                                       embeddings2,
+                                       np.asarray(actual_issame),
+                                       nrof_folds=nrof_folds,
+                                       pca=pca)
+    thresholds = np.arange(0, 4, 0.001)
+    val, val_std, far = calculate_val(thresholds,
+                                      embeddings1,
+                                      embeddings2,
+                                      np.asarray(actual_issame),
+                                      1e-3,
+                                      nrof_folds=nrof_folds)
+    return tpr, fpr, accuracy, val, val_std, far
+
+
+@torch.no_grad()
+def load_bin(path, image_size):
+    try:
+        with open(path, 'rb') as f:
+            bins, issame_list = pickle.load(f)  # py2
+    except UnicodeDecodeError:
+        with open(path, 'rb') as f:
+            bins, issame_list = pickle.load(f, encoding='bytes')  # py3
+    data_list = []
+    for flip in [0, 1]:
+        data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
+        data_list.append(data)
+    for idx in range(len(issame_list) * 2):
+        _bin = bins[idx]
+        img = mx.image.imdecode(_bin)
+        if img.shape[1] != image_size[0]:
+            img = mx.image.resize_short(img, image_size[0])
+        img = nd.transpose(img, axes=(2, 0, 1))  # (C, H, W)
+        for flip in [0, 1]:
+            if flip == 1:
+                img = mx.ndarray.flip(data=img, axis=2)
+            data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
+        if idx % 1000 == 0:
+            print('loading bin', idx)
+
+        # # save img to '/home/yuange/dataset/LFW/rgb-arcface'
+        # img = nd.transpose(img, axes=(1, 2, 0))  # (H, W, C)
+        # save_name = 'ind_' + str(idx) + '.bmp'
+        # import os
+        # save_name = os.path.join('/home/yuange/dataset/LFW/rgb-arcface', save_name)
+        # import PIL.Image as Image
+        # fig = Image.fromarray(img.asnumpy(), mode='RGB')
+        # fig.save(save_name)
+
+    print('load finished', data_list[0].shape)
+    return data_list, issame_list
+
+
+@torch.no_grad()
+def test(data_set, backbone, batch_size, nfolds=10,
+         is_gray=False, ):
+    print('testing verification..')
+    data_list = data_set[0]
+    issame_list = data_set[1]
+    embeddings_list = []
+    time_consumed = 0.0
+    for i in range(len(data_list)):
+        data = data_list[i]  # (12000, 3, 112, 112)
+
+        print(data.shape)
+        if is_gray:
+            data = (0.2989 * data[:, 0] +
+                    0.5870 * data[:, 1] +
+                    0.1140 * data[:, 2]) / 3
+            data = data[:, None, :, :]
+            print(data.shape)
+
+        embeddings = None
+        ba = 0
+        while ba < data.shape[0]:
+            bb = min(ba + batch_size, data.shape[0])
+            count = bb - ba
+            _data = data[bb - batch_size: bb]
+            time0 = datetime.datetime.now()
+
+            if not is_gray:
+                img = ((_data / 255) - 0.5) / 0.5
+            else:
+                img = _data / 255
+
+            # mouth_net returns a feature whether in training or testing
+            feature = backbone(img.cuda(0))
+            net_out: torch.Tensor = feature
+
+            _embeddings = net_out.detach().cpu().numpy()
+            time_now = datetime.datetime.now()
+            diff = time_now - time0
+            time_consumed += diff.total_seconds()
+            if embeddings is None:
+                embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+            embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+            ba = bb
+        embeddings_list.append(embeddings)
+
+    print('emb_list', len(embeddings_list), embeddings_list[0].size, embeddings_list[1].size)
+    _xnorm = 0.0
+    _xnorm_cnt = 0
+    for embed in embeddings_list:
+        for i in range(embed.shape[0]):
+            _em = embed[i]
+            _norm = np.linalg.norm(_em)
+            _xnorm += _norm
+            _xnorm_cnt += 1
+    _xnorm /= _xnorm_cnt
+
+    embeddings = embeddings_list[0].copy()
+    embeddings = sklearn.preprocessing.normalize(embeddings)
+    acc1 = 0.0
+    std1 = 0.0
+    embeddings = embeddings_list[0] + embeddings_list[1]  # fuse original and flipped views
+    embeddings = sklearn.preprocessing.normalize(embeddings)
+    print('embeddings.shape', embeddings.shape)
+    print('infer time', time_consumed)
+    _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
+    acc2, std2 = np.mean(accuracy), np.std(accuracy)
+    return acc1, std1, acc2, std2, _xnorm, embeddings_list
+
+
+def dumpR(data_set,
+          backbone,
+          batch_size,
+          name='',
+          data_extra=None,
+          label_shape=None):
+    print('dump verification embedding..')
+    data_list = data_set[0]
+    issame_list = data_set[1]
+    embeddings_list = []
+    time_consumed = 0.0
+    for i in range(len(data_list)):
+        data = data_list[i]
+        embeddings = None
+        ba = 0
+        while ba < data.shape[0]:
+            bb = min(ba + batch_size, data.shape[0])
+            count = bb - ba
+
+            _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
+            time0 = datetime.datetime.now()
+            # NOTE: `_label`, `_data_extra` and `model` below are leftovers from the
+            # original mxnet implementation and are not defined in this file.
+            if data_extra is None:
+                db = mx.io.DataBatch(data=(_data,), label=(_label,))
+            else:
+                db = mx.io.DataBatch(data=(_data, _data_extra),
+                                     label=(_label,))
+            model.forward(db, is_train=False)
+            net_out = model.get_outputs()
+            _embeddings = net_out[0].asnumpy()
+            time_now = datetime.datetime.now()
+            diff = time_now - time0
+            time_consumed += diff.total_seconds()
+            if embeddings is None:
+                embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+            embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+            ba = bb
+        embeddings_list.append(embeddings)
+    embeddings = embeddings_list[0] + embeddings_list[1]
+    embeddings = sklearn.preprocessing.normalize(embeddings)
+    actual_issame = np.asarray(issame_list)
+    outname = os.path.join('temp.bin')
+    with open(outname, 'wb') as f:
+        pickle.dump((embeddings, issame_list),
+                    f,
+                    protocol=pickle.HIGHEST_PROTOCOL)
+
+
+# if __name__ == '__main__':
+#
+#     parser = argparse.ArgumentParser(description='do verification')
+#     # general
+#     parser.add_argument('--data-dir', default='', help='')
+#     parser.add_argument('--model',
+#                         default='../model/softmax,50',
+#                         help='path to load model.')
+#     parser.add_argument('--target',
+#                         default='lfw,cfp_ff,cfp_fp,agedb_30',
+#                         help='test targets.')
+#     parser.add_argument('--gpu', default=0, type=int, help='gpu id')
+#     parser.add_argument('--batch-size', default=32, type=int, help='')
+#     parser.add_argument('--max', default='', type=str, help='')
+#     parser.add_argument('--mode', default=0, type=int, help='')
+#     parser.add_argument('--nfolds', default=10, type=int, help='')
+#     args = parser.parse_args()
+#     image_size = [112, 112]
+#     print('image_size', image_size)
+#     ctx = mx.gpu(args.gpu)
+#     nets = []
+#     vec = args.model.split(',')
+#     prefix = args.model.split(',')[0]
+#     epochs = []
+#     if len(vec) == 1:
+#         pdir = os.path.dirname(prefix)
+#         for fname in os.listdir(pdir):
+#             if not fname.endswith('.params'):
+#                 continue
+#             _file = os.path.join(pdir, fname)
+#             if _file.startswith(prefix):
+#                 epoch = int(fname.split('.')[0].split('-')[1])
+#                 epochs.append(epoch)
+#         epochs = sorted(epochs, reverse=True)
+#         if len(args.max) > 0:
+#             _max = [int(x) for x in args.max.split(',')]
+#             assert len(_max) == 2
+#             if len(epochs) > _max[1]:
+#                 epochs = epochs[_max[0]:_max[1]]
+#
+#     else:
+#         epochs = [int(x) for x in vec[1].split('|')]
+#     print('model number', len(epochs))
+#     time0 = datetime.datetime.now()
+#     for epoch in epochs:
+#         print('loading', prefix, epoch)
+#         sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+#         # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
+#         all_layers = sym.get_internals()
+#         sym = all_layers['fc1_output']
+#         model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+#         # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
+#         model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
+#                                           image_size[1]))])
+#         model.set_params(arg_params, aux_params)
+#         nets.append(model)
+#     time_now = datetime.datetime.now()
+#     diff = time_now - time0
+#     print('model loading time', diff.total_seconds())
+#
+#     ver_list = []
+#     ver_name_list = []
+#     for name in args.target.split(','):
+#         path = os.path.join(args.data_dir, name + ".bin")
+#         if os.path.exists(path):
+#             print('loading.. ', name)
+#             data_set = load_bin(path, image_size)
+#             ver_list.append(data_set)
+#             ver_name_list.append(name)
+#
+#     if args.mode == 0:
+#         for i in range(len(ver_list)):
+#             results = []
+#             for model in nets:
+#                 acc1, std1, acc2, std2, xnorm, embeddings_list = test(
+#                     ver_list[i], model, args.batch_size, args.nfolds)
+#                 print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
+#                 print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
+#                 print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
+#                 results.append(acc2)
+#             print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
+#     elif args.mode == 1:
+#         raise ValueError
+#     else:
+#         model = nets[0]
+#         dumpR(ver_list[0], model, args.batch_size, args.target)
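To see the metric helpers in isolation, a toy check of `calculate_accuracy` (made-up squared distances; a pair counts as "same" when its distance falls below the threshold):

import numpy as np

dist = np.array([0.3, 0.9, 1.5, 0.4])           # squared L2 distances of 4 pairs
actual_issame = np.array([True, True, False, False])

# At threshold 1.0, pairs 0, 1, and 3 are predicted "same":
tpr, fpr, acc = calculate_accuracy(1.0, dist, actual_issame)
print(tpr, fpr, acc)                             # 1.0 0.5 0.75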