Spanicin committed on
Commit ad9d237 · verified · 1 Parent(s): 6c8c772

Update videoretalking/third_part/GPEN/face_detect/facemodels/retinaface.py

videoretalking/third_part/GPEN/face_detect/facemodels/retinaface.py CHANGED
@@ -1,127 +1,127 @@
-import torch
-import torch.nn as nn
-import torchvision.models.detection.backbone_utils as backbone_utils
-import torchvision.models._utils as _utils
-import torch.nn.functional as F
-from collections import OrderedDict
-
-from face_detect.facemodels.net import MobileNetV1 as MobileNetV1
-from face_detect.facemodels.net import FPN as FPN
-from face_detect.facemodels.net import SSH as SSH
-
-
-
-class ClassHead(nn.Module):
-    def __init__(self,inchannels=512,num_anchors=3):
-        super(ClassHead,self).__init__()
-        self.num_anchors = num_anchors
-        self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*2,kernel_size=(1,1),stride=1,padding=0)
-
-    def forward(self,x):
-        out = self.conv1x1(x)
-        out = out.permute(0,2,3,1).contiguous()
-
-        return out.view(out.shape[0], -1, 2)
-
-class BboxHead(nn.Module):
-    def __init__(self,inchannels=512,num_anchors=3):
-        super(BboxHead,self).__init__()
-        self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0)
-
-    def forward(self,x):
-        out = self.conv1x1(x)
-        out = out.permute(0,2,3,1).contiguous()
-
-        return out.view(out.shape[0], -1, 4)
-
-class LandmarkHead(nn.Module):
-    def __init__(self,inchannels=512,num_anchors=3):
-        super(LandmarkHead,self).__init__()
-        self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0)
-
-    def forward(self,x):
-        out = self.conv1x1(x)
-        out = out.permute(0,2,3,1).contiguous()
-
-        return out.view(out.shape[0], -1, 10)
-
-class RetinaFace(nn.Module):
-    def __init__(self, cfg = None, phase = 'train'):
-        """
-        :param cfg: Network related settings.
-        :param phase: train or test.
-        """
-        super(RetinaFace,self).__init__()
-        self.phase = phase
-        backbone = None
-        if cfg['name'] == 'mobilenet0.25':
-            backbone = MobileNetV1()
-            if cfg['pretrain']:
-                checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
-                from collections import OrderedDict
-                new_state_dict = OrderedDict()
-                for k, v in checkpoint['state_dict'].items():
-                    name = k[7:]  # remove module.
-                    new_state_dict[name] = v
-                # load params
-                backbone.load_state_dict(new_state_dict)
-        elif cfg['name'] == 'Resnet50':
-            import torchvision.models as models
-            backbone = models.resnet50(pretrained=cfg['pretrain'])
-
-        self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
-        in_channels_stage2 = cfg['in_channel']
-        in_channels_list = [
-            in_channels_stage2 * 2,
-            in_channels_stage2 * 4,
-            in_channels_stage2 * 8,
-        ]
-        out_channels = cfg['out_channel']
-        self.fpn = FPN(in_channels_list,out_channels)
-        self.ssh1 = SSH(out_channels, out_channels)
-        self.ssh2 = SSH(out_channels, out_channels)
-        self.ssh3 = SSH(out_channels, out_channels)
-
-        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
-        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
-        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
-
-    def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
-        classhead = nn.ModuleList()
-        for i in range(fpn_num):
-            classhead.append(ClassHead(inchannels,anchor_num))
-        return classhead
-
-    def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
-        bboxhead = nn.ModuleList()
-        for i in range(fpn_num):
-            bboxhead.append(BboxHead(inchannels,anchor_num))
-        return bboxhead
-
-    def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
-        landmarkhead = nn.ModuleList()
-        for i in range(fpn_num):
-            landmarkhead.append(LandmarkHead(inchannels,anchor_num))
-        return landmarkhead
-
-    def forward(self,inputs):
-        out = self.body(inputs)
-
-        # FPN
-        fpn = self.fpn(out)
-
-        # SSH
-        feature1 = self.ssh1(fpn[0])
-        feature2 = self.ssh2(fpn[1])
-        feature3 = self.ssh3(fpn[2])
-        features = [feature1, feature2, feature3]
-
-        bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
-        classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1)
-        ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)
-
-        if self.phase == 'train':
-            output = (bbox_regressions, classifications, ldm_regressions)
-        else:
-            output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
+import torch
+import torch.nn as nn
+import torchvision.models.detection.backbone_utils as backbone_utils
+import torchvision.models._utils as _utils
+import torch.nn.functional as F
+from collections import OrderedDict
+
+from videoretalking.third_part.GPEN.face_detect.facemodels.net import MobileNetV1 as MobileNetV1
+from videoretalking.third_part.GPEN.face_detect.facemodels.net import FPN as FPN
+from videoretalking.third_part.GPEN.face_detect.facemodels.net import SSH as SSH
+
+
+
+class ClassHead(nn.Module):
+    def __init__(self,inchannels=512,num_anchors=3):
+        super(ClassHead,self).__init__()
+        self.num_anchors = num_anchors
+        self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*2,kernel_size=(1,1),stride=1,padding=0)
+
+    def forward(self,x):
+        out = self.conv1x1(x)
+        out = out.permute(0,2,3,1).contiguous()
+
+        return out.view(out.shape[0], -1, 2)
+
+class BboxHead(nn.Module):
+    def __init__(self,inchannels=512,num_anchors=3):
+        super(BboxHead,self).__init__()
+        self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0)
+
+    def forward(self,x):
+        out = self.conv1x1(x)
+        out = out.permute(0,2,3,1).contiguous()
+
+        return out.view(out.shape[0], -1, 4)
+
+class LandmarkHead(nn.Module):
+    def __init__(self,inchannels=512,num_anchors=3):
+        super(LandmarkHead,self).__init__()
+        self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0)
+
+    def forward(self,x):
+        out = self.conv1x1(x)
+        out = out.permute(0,2,3,1).contiguous()
+
+        return out.view(out.shape[0], -1, 10)
+
+class RetinaFace(nn.Module):
+    def __init__(self, cfg = None, phase = 'train'):
+        """
+        :param cfg: Network related settings.
+        :param phase: train or test.
+        """
+        super(RetinaFace,self).__init__()
+        self.phase = phase
+        backbone = None
+        if cfg['name'] == 'mobilenet0.25':
+            backbone = MobileNetV1()
+            if cfg['pretrain']:
+                checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu'))
+                from collections import OrderedDict
+                new_state_dict = OrderedDict()
+                for k, v in checkpoint['state_dict'].items():
+                    name = k[7:]  # remove module.
+                    new_state_dict[name] = v
+                # load params
+                backbone.load_state_dict(new_state_dict)
+        elif cfg['name'] == 'Resnet50':
+            import torchvision.models as models
+            backbone = models.resnet50(pretrained=cfg['pretrain'])
+
+        self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
+        in_channels_stage2 = cfg['in_channel']
+        in_channels_list = [
+            in_channels_stage2 * 2,
+            in_channels_stage2 * 4,
+            in_channels_stage2 * 8,
+        ]
+        out_channels = cfg['out_channel']
+        self.fpn = FPN(in_channels_list,out_channels)
+        self.ssh1 = SSH(out_channels, out_channels)
+        self.ssh2 = SSH(out_channels, out_channels)
+        self.ssh3 = SSH(out_channels, out_channels)
+
+        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
+        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
+        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
+
+    def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
+        classhead = nn.ModuleList()
+        for i in range(fpn_num):
+            classhead.append(ClassHead(inchannels,anchor_num))
+        return classhead
+
+    def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
+        bboxhead = nn.ModuleList()
+        for i in range(fpn_num):
+            bboxhead.append(BboxHead(inchannels,anchor_num))
+        return bboxhead
+
+    def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
+        landmarkhead = nn.ModuleList()
+        for i in range(fpn_num):
+            landmarkhead.append(LandmarkHead(inchannels,anchor_num))
+        return landmarkhead
+
+    def forward(self,inputs):
+        out = self.body(inputs)
+
+        # FPN
+        fpn = self.fpn(out)
+
+        # SSH
+        feature1 = self.ssh1(fpn[0])
+        feature2 = self.ssh2(fpn[1])
+        feature3 = self.ssh3(fpn[2])
+        features = [feature1, feature2, feature3]
+
+        bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
+        classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1)
+        ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)
+
+        if self.phase == 'train':
+            output = (bbox_regressions, classifications, ldm_regressions)
+        else:
+            output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
         return output
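
For reference, a minimal usage sketch of the module after this change, assuming the repository root (the directory containing videoretalking/) is on sys.path, which is what the new absolute imports require. The cfg values below mirror the usual mobilenet0.25 settings from the upstream RetinaFace code and are illustrative assumptions, not values taken from this commit.

    # Illustrative sketch only; cfg keys come from how RetinaFace.__init__ reads them.
    import torch
    from videoretalking.third_part.GPEN.face_detect.facemodels.retinaface import RetinaFace

    # Hypothetical mobilenet0.25-style config; the real values live in the
    # project's face-detection config, not in this file.
    cfg_mnet = {
        'name': 'mobilenet0.25',
        'pretrain': False,  # skip loading ./weights/mobilenetV1X0.25_pretrain.tar
        'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
        'in_channel': 32,
        'out_channel': 64,
    }

    net = RetinaFace(cfg=cfg_mnet, phase='test')
    net.eval()

    with torch.no_grad():
        dummy = torch.randn(1, 3, 480, 640)     # NCHW image tensor
        bboxes, scores, landmarks = net(dummy)  # box regressions, softmaxed scores, landmarks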