FaceAdapter committed on
Commit
703c10a
1 Parent(s): 242e411

Update third_party/insightface_backbone_conv.py

Browse files
Files changed (1) hide show
  1. third_party/insightface_backbone_conv.py +236 -236
third_party/insightface_backbone_conv.py CHANGED
@@ -1,237 +1,237 @@
1
- import os
2
- import torch
3
- from torch import nn
4
-
5
- __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200', 'getarcface']
6
-
7
-
8
- def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
9
- """3x3 convolution with padding"""
10
- return nn.Conv2d(in_planes,
11
- out_planes,
12
- kernel_size=3,
13
- stride=stride,
14
- padding=dilation,
15
- groups=groups,
16
- bias=False,
17
- dilation=dilation)
18
-
19
-
20
- def conv1x1(in_planes, out_planes, stride=1):
21
- """1x1 convolution"""
22
- return nn.Conv2d(in_planes,
23
- out_planes,
24
- kernel_size=1,
25
- stride=stride,
26
- bias=False)
27
-
28
-
29
- class IBasicBlock(nn.Module):
30
- expansion = 1
31
- def __init__(self, inplanes, planes, stride=1, downsample=None,
32
- groups=1, base_width=64, dilation=1):
33
- super(IBasicBlock, self).__init__()
34
- if groups != 1 or base_width != 64:
35
- raise ValueError('BasicBlock only supports groups=1 and base_width=64')
36
- if dilation > 1:
37
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
38
- self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
39
- self.conv1 = conv3x3(inplanes, planes)
40
- self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
41
- self.prelu = nn.PReLU(planes)
42
- self.conv2 = conv3x3(planes, planes, stride)
43
- self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
44
- self.downsample = downsample
45
- self.stride = stride
46
-
47
- def forward(self, x):
48
- identity = x
49
- out = self.bn1(x)
50
- out = self.conv1(out)
51
- out = self.bn2(out)
52
- out = self.prelu(out)
53
- out = self.conv2(out)
54
- out = self.bn3(out)
55
- if self.downsample is not None:
56
- identity = self.downsample(x)
57
- out += identity
58
- return out
59
-
60
-
61
- class IResNet(nn.Module):
62
- fc_scale = 7 * 7
63
- def __init__(self,
64
- block, layers, dropout=0, num_features=512, zero_init_residual=False,
65
- groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
66
- super(IResNet, self).__init__()
67
- self.fp16 = fp16
68
- self.inplanes = 64
69
- self.dilation = 1
70
- if replace_stride_with_dilation is None:
71
- replace_stride_with_dilation = [False, False, False]
72
- if len(replace_stride_with_dilation) != 3:
73
- raise ValueError("replace_stride_with_dilation should be None "
74
- "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
75
- self.groups = groups
76
- self.base_width = width_per_group
77
- self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
78
- self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
79
- self.prelu = nn.PReLU(self.inplanes)
80
- self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
81
- self.layer2 = self._make_layer(block,
82
- 128,
83
- layers[1],
84
- stride=2,
85
- dilate=replace_stride_with_dilation[0])
86
- self.layer3 = self._make_layer(block,
87
- 256,
88
- layers[2],
89
- stride=2,
90
- dilate=replace_stride_with_dilation[1])
91
- self.layer4 = self._make_layer(block,
92
- 512,
93
- layers[3],
94
- stride=2,
95
- dilate=replace_stride_with_dilation[2])
96
- self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
97
- self.dropout = nn.Dropout(p=dropout, inplace=True)
98
- self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
99
- self.features = nn.BatchNorm1d(num_features, eps=1e-05)
100
- nn.init.constant_(self.features.weight, 1.0)
101
- self.features.weight.requires_grad = False
102
-
103
- for m in self.modules():
104
- if isinstance(m, nn.Conv2d):
105
- nn.init.normal_(m.weight, 0, 0.1)
106
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
107
- nn.init.constant_(m.weight, 1)
108
- nn.init.constant_(m.bias, 0)
109
-
110
- if zero_init_residual:
111
- for m in self.modules():
112
- if isinstance(m, IBasicBlock):
113
- nn.init.constant_(m.bn2.weight, 0)
114
-
115
- def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
116
- downsample = None
117
- previous_dilation = self.dilation
118
- if dilate:
119
- self.dilation *= stride
120
- stride = 1
121
- if stride != 1 or self.inplanes != planes * block.expansion:
122
- downsample = nn.Sequential(
123
- conv1x1(self.inplanes, planes * block.expansion, stride),
124
- nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
125
- )
126
- layers = []
127
- layers.append(
128
- block(self.inplanes, planes, stride, downsample, self.groups,
129
- self.base_width, previous_dilation))
130
- self.inplanes = planes * block.expansion
131
- for _ in range(1, blocks):
132
- layers.append(
133
- block(self.inplanes,
134
- planes,
135
- groups=self.groups,
136
- base_width=self.base_width,
137
- dilation=self.dilation))
138
-
139
- return nn.Sequential(*layers)
140
-
141
- def forward(self, x, return_id512=False):
142
-
143
- bz = x.shape[0]
144
- # with torch.cuda.amp.autocast(self.fp16):
145
- x = self.conv1(x)
146
- x = self.bn1(x)
147
- x = self.prelu(x)
148
- x = self.layer1(x)
149
- x = self.layer2(x)
150
- x = self.layer3(x)
151
- x = self.layer4(x)
152
- if not return_id512:
153
- return x.view(bz,512,-1).permute(0,2,1).contiguous()
154
- else:
155
- x = self.bn2(x)
156
- x = torch.flatten(x, 1)
157
- # x = self.dropout(x)
158
- # x = self.fc(x.float() if self.fp16 else x)
159
- x = self.fc(x)
160
- x = self.features(x)
161
- return x
162
-
163
-
164
-
165
- def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
166
- model = IResNet(block, layers, **kwargs)
167
- if pretrained:
168
- raise ValueError()
169
- return model
170
-
171
-
172
- def iresnet18(pretrained=False, progress=True, **kwargs):
173
- return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
174
- progress, **kwargs)
175
-
176
-
177
- def iresnet34(pretrained=False, progress=True, **kwargs):
178
- return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
179
- progress, **kwargs)
180
-
181
-
182
- def iresnet50(pretrained=False, progress=True, **kwargs):
183
- return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
184
- progress, **kwargs)
185
-
186
-
187
- def iresnet100(pretrained=False, progress=True, **kwargs):
188
- return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
189
- progress, **kwargs)
190
-
191
-
192
- def iresnet200(pretrained=False, progress=True, **kwargs):
193
- return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
194
- progress, **kwargs)
195
-
196
-
197
- def getarcface(pretrained=None):
198
- model = iresnet100().eval()
199
- for param in model.parameters():
200
- param.requires_grad=False
201
-
202
- if pretrained is not None and os.path.exists(pretrained):
203
- info = model.load_state_dict(torch.load(pretrained))
204
- print(info)
205
- return model
206
-
207
-
208
- if __name__=='__main__':
209
- ckpt = 'pretrained/insightface_glint360k.pth'
210
- arcface = iresnet100().eval()
211
- info = arcface.load_state_dict(torch.load(ckpt))
212
- print(info)
213
-
214
- id = arcface(torch.randn(1,3,128,128))
215
- print(id.shape)
216
-
217
- # import cv2
218
- # import numpy as np
219
- # im1_crop256 = cv2.imread('happy.jpg')
220
- # im2_crop256 = cv2.imread('angry.jpg')
221
-
222
- # im1_crop112 = cv2.resize(im1_crop256, (128,128))[0:112,8:120,:]
223
- # im2_crop112 = cv2.resize(im2_crop256, (128,128))[0:112,8:120,:]
224
-
225
- # cv2.imwrite('1_112.jpg', im1_crop112)
226
- # cv2.imwrite('2_112.jpg', im2_crop112)
227
-
228
- # # [-1,1] rgb
229
- # im1_crop112_tensor = torch.from_numpy(im1_crop112[:,:,[2,1,0]].transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)/127.5-1
230
- # im2_crop112_tensor = torch.from_numpy(im2_crop112[:,:,[2,1,0]].transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)/127.5-1
231
-
232
- # im1_id = arcface(im1_crop112_tensor)
233
- # im2_id = arcface(im2_crop112_tensor)
234
-
235
- # loss_cos = torch.mean(1-torch.cosine_similarity(im1_id, im2_id, dim=1))
236
-
237
  # print(loss_cos)
 
1
+ import os
2
+ import torch
3
+ from torch import nn
4
+
5
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200', 'getarcface']
6
+
7
+
8
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    The padding is set equal to ``dilation``, so with ``stride=1`` the
    spatial size of the input is preserved.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of channel groups.
        dilation: kernel dilation (also used as the padding).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
18
+
19
+
20
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return pointwise
27
+
28
+
29
class IBasicBlock(nn.Module):
    """Residual block: BN -> 3x3 conv -> BN -> PReLU -> 3x3 conv -> BN, plus shortcut.

    The first convolution always uses stride 1; ``stride`` is applied by the
    second convolution. When ``downsample`` is given it is applied to the block
    input to produce the shortcut, otherwise the input itself is the shortcut.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the second 3x3 convolution.
        downsample: optional module mapping the input to the shortcut tensor.
        groups: must be 1.
        base_width: must be 64.
        dilation: must not exceed 1.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super().__init__()
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Sub-modules are registered in the same order as the data flows.
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return the residual branch output added to the shortcut."""
        out = x
        for stage in (self.bn1, self.conv1, self.bn2,
                      self.prelu, self.conv2, self.bn3):
            out = stage(out)
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out
59
+
60
+
61
class IResNet(nn.Module):
    """ResNet-style backbone: 3x3 stem, four residual stages, BN + FC + BN1d head.

    The stem keeps the input resolution (stride 1); each of the four stages
    uses stride 2 unless it is replaced by dilation, so the default network
    downsamples by 16. The fully connected head assumes a ``fc_scale``-element
    (7x7) final feature map, e.g. a 112x112 input with the default strides.

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(inplanes, planes, stride, downsample, groups, base_width, dilation)``.
        layers: number of blocks in each of the four stages.
        dropout: probability for ``self.dropout`` (created but not applied in
            ``forward``).
        num_features: size of the embedding produced by the head.
        zero_init_residual: if true, zero the ``bn2`` weight of every
            ``IBasicBlock``.
        groups: forwarded to every block.
        width_per_group: forwarded to every block as ``base_width``.
        replace_stride_with_dilation: ``None`` or three booleans, one each for
            stages 2-4, selecting dilation instead of stride.
        fp16: stored on the instance; not used by ``forward``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
    """

    # Number of spatial positions the FC layer expects per channel (7x7 map).
    fc_scale = 7 * 7

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super().__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)

        # Residual stages; only stages 2-4 may trade stride for dilation.
        use_dilation = replace_stride_with_dilation
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=use_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=use_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=use_dilation[2])

        # Embedding head.
        head_channels = 512 * block.expansion
        self.bn2 = nn.BatchNorm2d(head_channels, eps=1e-05)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(head_channels * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        # The final BN scale is pinned to 1 and excluded from training.
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        # Re-initialise every conv and 2d norm layer created above.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, 0, 0.1)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        if zero_init_residual:
            # NOTE(review): this zeroes bn2, which is not the last BN of
            # IBasicBlock (bn3 is) - confirm this is intended.
            for module in self.modules():
                if isinstance(module, IBasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``.

        Only the first block receives ``stride`` and the projection shortcut;
        the remaining blocks keep resolution and channel count.
        """
        out_channels = planes * block.expansion
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for an increased dilation.
            self.dilation *= stride
            stride = 1

        shortcut = None
        if stride != 1 or self.inplanes != out_channels:
            # A 1x1 projection matches the shortcut's shape to the block output.
            shortcut = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels, eps=1e-05),
            )

        stage = [block(self.inplanes, planes, stride, shortcut, self.groups,
                       self.base_width, previous_dilation)]
        self.inplanes = out_channels
        stage.extend(
            block(self.inplanes,
                  planes,
                  groups=self.groups,
                  base_width=self.base_width,
                  dilation=self.dilation)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*stage)

    def forward(self, x, return_id512=False):
        """Run the backbone.

        Args:
            x: image batch with 3 channels.
            return_id512: selects the output, see below.

        Returns:
            If ``return_id512`` is false, the ``layer4`` feature map reshaped
            to ``(batch, positions, 512)``. Otherwise the output of the
            BN -> flatten -> FC -> BN1d head, shaped ``(batch, num_features)``.
        """
        bz = x.shape[0]
        for stage in (self.conv1, self.bn1, self.prelu,
                      self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        if return_id512:
            # self.dropout is intentionally not applied here.
            x = torch.flatten(self.bn2(x), 1)
            x = self.fc(x)
            return self.features(x)

        # Spatial positions become a token axis: (B, 512, H, W) -> (B, H*W, 512).
        return x.view(bz, 512, -1).permute(0, 2, 1).contiguous()
162
+
163
+
164
+
165
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
166
+ model = IResNet(block, layers, **kwargs)
167
+ if pretrained:
168
+ raise ValueError()
169
+ return model
170
+
171
+
172
def iresnet18(pretrained=False, progress=True, **kwargs):
    """Build an ``IResNet`` of ``IBasicBlock`` stages with depths [2, 2, 2, 2].

    ``pretrained`` and ``progress`` are handed to ``_iresnet``; extra keyword
    arguments reach the ``IResNet`` constructor.
    """
    stage_depths = [2, 2, 2, 2]
    return _iresnet('iresnet18', IBasicBlock, stage_depths,
                    pretrained, progress, **kwargs)
175
+
176
+
177
def iresnet34(pretrained=False, progress=True, **kwargs):
    """Build an ``IResNet`` of ``IBasicBlock`` stages with depths [3, 4, 6, 3].

    ``pretrained`` and ``progress`` are handed to ``_iresnet``; extra keyword
    arguments reach the ``IResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return _iresnet('iresnet34', IBasicBlock, stage_depths,
                    pretrained, progress, **kwargs)
180
+
181
+
182
def iresnet50(pretrained=False, progress=True, **kwargs):
    """Build an ``IResNet`` of ``IBasicBlock`` stages with depths [3, 4, 14, 3].

    ``pretrained`` and ``progress`` are handed to ``_iresnet``; extra keyword
    arguments reach the ``IResNet`` constructor.
    """
    stage_depths = [3, 4, 14, 3]
    return _iresnet('iresnet50', IBasicBlock, stage_depths,
                    pretrained, progress, **kwargs)
185
+
186
+
187
def iresnet100(pretrained=False, progress=True, **kwargs):
    """Build an ``IResNet`` of ``IBasicBlock`` stages with depths [3, 13, 30, 3].

    ``pretrained`` and ``progress`` are handed to ``_iresnet``; extra keyword
    arguments reach the ``IResNet`` constructor.
    """
    stage_depths = [3, 13, 30, 3]
    return _iresnet('iresnet100', IBasicBlock, stage_depths,
                    pretrained, progress, **kwargs)
190
+
191
+
192
def iresnet200(pretrained=False, progress=True, **kwargs):
    """Build an ``IResNet`` of ``IBasicBlock`` stages with depths [6, 26, 60, 6].

    ``pretrained`` and ``progress`` are handed to ``_iresnet``; extra keyword
    arguments reach the ``IResNet`` constructor.
    """
    stage_depths = [6, 26, 60, 6]
    return _iresnet('iresnet200', IBasicBlock, stage_depths,
                    pretrained, progress, **kwargs)
195
+
196
+
197
def getarcface(pretrained=None):
    """Return a frozen, eval-mode ``iresnet100``, optionally loading a checkpoint.

    Args:
        pretrained: path to a state-dict checkpoint, or ``None`` to keep the
            random initialisation.

    Returns:
        The model in eval mode with ``requires_grad`` disabled on every
        parameter.

    Warns:
        UserWarning: if ``pretrained`` is given but the file does not exist;
            the model is still returned, with random weights.
    """
    import warnings

    model = iresnet100().eval()
    for param in model.parameters():
        param.requires_grad = False

    if pretrained is None:
        return model

    if os.path.exists(pretrained):
        # Map every tensor to CPU so checkpoints saved from a GPU also load on
        # CPU-only hosts. NOTE: torch.load unpickles the file - only load
        # checkpoints from a trusted source.
        state_dict = torch.load(pretrained, map_location='cpu')
        info = model.load_state_dict(state_dict)
        print(info)
    else:
        # Previously this case was silent, handing back an untrained network.
        warnings.warn(
            "getarcface: checkpoint '{}' not found; returning randomly "
            "initialised weights".format(pretrained))
    return model
206
+
207
+
208
if __name__=='__main__':
    # Smoke test: load the checkpoint into iresnet100 and print the shape of
    # the features produced for one random 128x128 image.
    ckpt = 'pretrained/insightface_glint360k.pth'
    arcface = iresnet100().eval()
    # Load onto CPU, as getarcface does, so a checkpoint saved from a GPU can
    # also be read on a CPU-only host.
    info = arcface.load_state_dict(torch.load(ckpt, map_location='cpu'))
    print(info)

    # Named `feats` rather than `id` to avoid shadowing the builtin; no
    # gradients are needed for a pure inference check.
    with torch.no_grad():
        feats = arcface(torch.randn(1,3,128,128))
    print(feats.shape)
216
+
217
+ # import cv2
218
+ # import numpy as np
219
+ # im1_crop256 = cv2.imread('happy.jpg')
220
+ # im2_crop256 = cv2.imread('angry.jpg')
221
+
222
+ # im1_crop112 = cv2.resize(im1_crop256, (128,128))[0:112,8:120,:]
223
+ # im2_crop112 = cv2.resize(im2_crop256, (128,128))[0:112,8:120,:]
224
+
225
+ # cv2.imwrite('1_112.jpg', im1_crop112)
226
+ # cv2.imwrite('2_112.jpg', im2_crop112)
227
+
228
+ # # [-1,1] rgb
229
+ # im1_crop112_tensor = torch.from_numpy(im1_crop112[:,:,[2,1,0]].transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)/127.5-1
230
+ # im2_crop112_tensor = torch.from_numpy(im2_crop112[:,:,[2,1,0]].transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)/127.5-1
231
+
232
+ # im1_id = arcface(im1_crop112_tensor)
233
+ # im2_id = arcface(im2_crop112_tensor)
234
+
235
+ # loss_cos = torch.mean(1-torch.cosine_similarity(im1_id, im2_id, dim=1))
236
+
237
  # print(loss_cos)