Enderfga committed on
Commit
7652882
1 Parent(s): ff76503

Add application file

README.md CHANGED
@@ -1,13 +1,125 @@
- ---
- title: MtCNN Sysu
- emoji: 📈
- colorFrom: gray
- colorTo: pink
- sdk: gradio
- sdk_version: 3.12.0
- app_file: app.py
- pinned: false
- license: openrail
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Joint Face Detection and Alignment using Multi-task Cascaded Convolutional Networks
+
+ This repo contains the code, data and trained models for the paper [Joint Face Detection and Alignment using Multi-task Cascaded Convolutional Networks](https://arxiv.org/ftp/arxiv/papers/1604/1604.02878.pdf).
+
+ ## Overview
+
+ MTCNN is a popular face detection algorithm that cascades three neural networks (PNet, RNet and ONet) to detect faces in images. It handles a wide range of lighting and pose conditions and can detect multiple faces and their landmarks in a single image.
+
+ We implement MTCNN with PyTorch, a popular deep learning framework that provides the tools for building and training neural networks.
+
+ ![](https://img.enderfga.cn/img/image-20221208152130975.png)
+
+ ![](https://img.enderfga.cn/img/image-20221208152231511.png)
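+
+ For a quick start, the snippet below is a minimal detection sketch assembled from the helpers this repo ships in `utils/detect.py` and `utils/vision.py`; it mirrors the image branch of `test.py`. The checkpoint paths and thresholds are just the defaults used elsewhere in this repo, not the only valid settings.
+
+ ```python
+ import cv2
+ from utils.detect import create_mtcnn_net, MtcnnDetector
+ from utils.vision import vis_face
+
+ # load the three cascaded networks from the bundled checkpoints
+ pnet, rnet, onet = create_mtcnn_net(p_model_path='./model_store/pnet_epoch_20.pt',
+                                     r_model_path='./model_store/rnet_epoch_20.pt',
+                                     o_model_path='./model_store/onet_epoch_20.pt',
+                                     use_cuda=False)
+ detector = MtcnnDetector(pnet=pnet, rnet=rnet, onet=onet,
+                          min_face_size=3, threshold=[0.1, 0.1, 0.1])
+
+ img = cv2.imread('./img/mid.png')           # BGR image used for detection
+ rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB copy used for visualization
+ # detect_face returns the PNet, RNet and ONet boxes plus the facial landmarks
+ p_boxes, r_boxes, boxes, landmarks = detector.detect_face(img)
+ vis_face(rgb, boxes, landmarks, 3, 'result.jpg')  # draw and save the final ONet result
+ ```
+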
+ ## Description of files
+ ```shell
+ ├── README.md             # this document
+ ├── get_data.py           # generate the training data for the network given by "--net"
+ ├── img                   # mid.png is used for testing the visualization; the other images are the corresponding results
+ │   ├── mid.png
+ │   ├── onet.png
+ │   ├── pnet.png
+ │   ├── rnet.png
+ │   ├── result.png
+ │   └── result.jpg
+ ├── model_store           # our pre-trained models
+ │   ├── onet_epoch_20.pt
+ │   ├── pnet_epoch_20.pt
+ │   └── rnet_epoch_20.pt
+ ├── requirements.txt      # pinned environment versions
+ ├── test.py               # visualize the result of the network given by "--net"
+ ├── test.sh               # test mid.png and save the visualization of all three networks
+ ├── train.out             # our complete training log for this experiment
+ ├── train.py              # train the network given by "--net"
+ ├── train.sh              # generate data and train all three networks from start to finish
+ └── utils                 # common tool functions and modules
+     ├── config.py
+     ├── dataloader.py
+     ├── detect.py
+     ├── models.py
+     ├── tool.py
+     └── vision.py
+ ```
+
+ ## Requirements
+
+ * numpy==1.21.4
+ * matplotlib==3.5.0
+ * opencv-python==4.4.0.42
+ * torch==1.13.0+cu116
+
+ ## How to Install
+
+ ```shell
+ conda create -n env python=3.8 -y
+ conda activate env
+ ```
+ ```shell
+ pip install -r requirements.txt
+ ```
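+
+ Note: the pinned `+cu116` wheels are not hosted on PyPI. If the plain `pip install -r requirements.txt` cannot resolve them, installing the PyTorch packages from the official cu116 wheel index first usually works; the command below is a suggested workaround, not part of the original instructions:
+
+ ```shell
+ pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0+cu116 \
+     --extra-index-url https://download.pytorch.org/whl/cu116
+ ```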
+
+ ## Preprocessing
+
+ - Download the [WIDER_FACE](http://shuoyang1213.me/WIDERFACE/) face detection data and extract it into `./data_set/face_detection`.
+ - Download the [CNN_FacePoint](http://mmlab.ie.cuhk.edu.hk/archive/CNN_FacePoint.htm) face detection and landmark data and extract it into `./data_set/face_landmark` (the expected layout is sketched below).
+
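+ The scripts read the datasets from hard-coded paths, so `get_data.py` expects roughly the layout below (taken from the constants at the top of `get_data.py`; the subfolders inside each archive follow the official releases):
+
+ ```shell
+ data_set/
+ ├── wider_face_train_bbx_gt.txt      # WIDER FACE training annotations
+ ├── face_detection/
+ │   └── WIDER_train/
+ │       └── images/                  # WIDER FACE training images
+ └── face_landmark/
+     └── CNN_FacePoint/
+         └── train/
+             └── trainImageList.txt   # landmark annotations (image folders alongside it)
+ ```
+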
+ ### Preprocessed Data
+
+ ```shell
+ # Before training PNet
+ python get_data.py --net=pnet
+ # Before training RNet, please use your trained model path
+ python get_data.py --net=rnet --pnet_path=./model_store/pnet_epoch_20.pt
+ # Before training ONet, please use your trained model path
+ python get_data.py --net=onet --pnet_path=./model_store/pnet_epoch_20.pt --rnet_path=./model_store/rnet_epoch_20.pt
+ ```
+
+ ## How to Run
+
+ ### Train
+
+ ```shell
+ python train.py --net=pnet/rnet/onet  # specify the network to train
+ bash train.sh                         # alternatively, use the script to train all three networks in order
+ ```
+
+ The checkpoints are saved under `./model_store/`.
+
+ #### Finetuning from an existing checkpoint
+
+ ```shell
+ python train.py --net=pnet/rnet/onet --load=[model path]
+ ```
+
+ The model path should be a checkpoint file inside `./model_store/`, e.g. `--load=./model_store/pnet_epoch_20.pt`.
+
+ ### Evaluate
+
+ #### Use the script to test all three networks in order
+
+ ```shell
+ bash test.sh
+ ```
+
+ #### To detect a single image
+
+ ```shell
+ python test.py --net=pnet/rnet/onet --path=test.jpg
+ ```
+
+ #### To detect a video stream from a camera
+
+ ```shell
+ python test.py --input_mode=0
+ ```
+
+ The annotated stream is also written to `out.mp4`; press `q` to quit.
+
+ #### The result of "--net=pnet"
+
+ ![](https://img.enderfga.cn/img/20221208160900.png)
+
+ #### The result of "--net=rnet"
+
+ ![](https://img.enderfga.cn/img/image-20221208155022083.png)
+
+ #### The result of "--net=onet"
+
+ ![](https://img.enderfga.cn/img/image-20221208155044451.png)
app.py ADDED
@@ -0,0 +1,39 @@
+ import gradio as gr
+ import cv2
+ from utils.detect import create_mtcnn_net, MtcnnDetector
+ from utils.vision import vis_face
+ import argparse
+
+
+ MIN_FACE_SIZE = 3
+
+ def parse_args():
+     # the CLI defaults double as the configuration of the Space
+     parser = argparse.ArgumentParser(description='Test MTCNN',
+                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+     parser.add_argument('--net', default='onet', help='which net to show', type=str)
+     parser.add_argument('--pnet_path', default="./model_store/pnet_epoch_20.pt", help='path to pnet model', type=str)
+     parser.add_argument('--rnet_path', default="./model_store/rnet_epoch_20.pt", help='path to rnet model', type=str)
+     parser.add_argument('--onet_path', default="./model_store/onet_epoch_20.pt", help='path to onet model', type=str)
+     parser.add_argument('--path', default="./img/mid.png", help='path to image', type=str)
+     parser.add_argument('--min_face_size', default=MIN_FACE_SIZE, help='min face size', type=int)
+     parser.add_argument('--use_cuda', default=False, help='use cuda', type=bool)
+     parser.add_argument('--thresh', default='[0.1, 0.1, 0.1]', help='detection thresholds for pnet/rnet/onet', type=str)
+     parser.add_argument('--save_name', default="result.jpg", help='save name', type=str)
+     parser.add_argument('--input_mode', default=1, help='image or video', type=int)
+     args = parser.parse_args()
+     return args
+
+ def greet(name):
+     # `name` is the path of the uploaded image (gr.Image(type="filepath"))
+     args = parse_args()
+     # parse the '[0.1, 0.1, 0.1]' string into a list of per-stage thresholds
+     thresh = [float(i) for i in (args.thresh).split('[')[1].split(']')[0].split(',')]
+     # build the cascaded PNet/RNet/ONet detector from the pre-trained checkpoints
+     pnet, rnet, onet = create_mtcnn_net(p_model_path=args.pnet_path, r_model_path=args.rnet_path, o_model_path=args.onet_path, use_cuda=args.use_cuda)
+     mtcnn_detector = MtcnnDetector(pnet=pnet, rnet=rnet, onet=onet, min_face_size=args.min_face_size, threshold=thresh)
+     img = cv2.imread(name)
+     img_bg = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB copy for visualization
+     p_bboxs, r_bboxs, bboxs, landmarks = mtcnn_detector.detect_face(img)
+     save_name = args.save_name
+     # draw the final ONet boxes and landmarks and return the visualization
+     return vis_face(img_bg, bboxs, landmarks, MIN_FACE_SIZE, save_name)
+
+ iface = gr.Interface(fn=greet,
+                      inputs=gr.Image(type="filepath"),
+                      outputs="image")
+ iface.launch()
get_data.py ADDED
@@ -0,0 +1,852 @@
1
+ import sys
2
+ import numpy as np
3
+ import cv2
4
+ import os
5
+ from utils.tool import IoU,convert_to_square
6
+ import numpy.random as npr
7
+ import argparse
8
+ from utils.detect import MtcnnDetector, create_mtcnn_net
9
+ from utils.dataloader import ImageDB,TestImageLoader
10
+ import time
11
+ from six.moves import cPickle
12
+ import utils.config as config
13
+ import utils.vision as vision
14
+ sys.path.append(os.getcwd())
15
+
16
+
17
+ txt_from_path = './data_set/wider_face_train_bbx_gt.txt'
18
+ anno_file = os.path.join(config.ANNO_STORE_DIR, 'anno_train.txt')
19
+ # anno_file = './anno_store/anno_train.txt'
20
+
21
+ prefix = ''
22
+ use_cuda = True
23
+ im_dir = "./data_set/face_detection/WIDER_train/images/"
24
+ traindata_store = './data_set/train/'
25
+ prefix_path = "./data_set/face_detection/WIDER_train/images/"
26
+ annotation_file = './anno_store/anno_train.txt'
27
+ prefix_path_lm = ''
28
+ annotation_file_lm = "./data_set/face_landmark/CNN_FacePoint/train/trainImageList.txt"
29
+ # ----------------------------------------------------other----------------------------------------------
30
+ pos_save_dir = "./data_set/train/12/positive"
31
+ part_save_dir = "./data_set/train/12/part"
32
+ neg_save_dir = './data_set/train/12/negative'
33
+ pnet_postive_file = os.path.join(config.ANNO_STORE_DIR, 'pos_12.txt')
34
+ pnet_part_file = os.path.join(config.ANNO_STORE_DIR, 'part_12.txt')
35
+ pnet_neg_file = os.path.join(config.ANNO_STORE_DIR, 'neg_12.txt')
36
+ imglist_filename_pnet = os.path.join(config.ANNO_STORE_DIR, 'imglist_anno_12.txt')
37
+ # ----------------------------------------------------PNet----------------------------------------------
38
+ rnet_postive_file = os.path.join(config.ANNO_STORE_DIR, 'pos_24.txt')
39
+ rnet_part_file = os.path.join(config.ANNO_STORE_DIR, 'part_24.txt')
40
+ rnet_neg_file = os.path.join(config.ANNO_STORE_DIR, 'neg_24.txt')
41
+ rnet_landmark_file = os.path.join(config.ANNO_STORE_DIR, 'landmark_24.txt')
42
+ imglist_filename_rnet = os.path.join(config.ANNO_STORE_DIR, 'imglist_anno_24.txt')
43
+ # ----------------------------------------------------RNet----------------------------------------------
44
+ onet_postive_file = os.path.join(config.ANNO_STORE_DIR, 'pos_48.txt')
45
+ onet_part_file = os.path.join(config.ANNO_STORE_DIR, 'part_48.txt')
46
+ onet_neg_file = os.path.join(config.ANNO_STORE_DIR, 'neg_48.txt')
47
+ onet_landmark_file = os.path.join(config.ANNO_STORE_DIR, 'landmark_48.txt')
48
+ imglist_filename_onet = os.path.join(config.ANNO_STORE_DIR, 'imglist_anno_48.txt')
49
+ # ----------------------------------------------------ONet----------------------------------------------
50
+
51
+
52
+
53
+ def assemble_data(output_file, anno_file_list=[]):
54
+
55
+ # assemble the pos, neg and part annotations into one file
56
+ size = 12
57
+
58
+ if len(anno_file_list)==0:
59
+ return 0
60
+
61
+ if os.path.exists(output_file):
62
+ os.remove(output_file)
63
+
64
+ for anno_file in anno_file_list:
65
+ with open(anno_file, 'r') as f:
66
+ print(anno_file)
67
+ anno_lines = f.readlines()
68
+
69
+ base_num = 250000
70
+
71
+ if len(anno_lines) > base_num * 3:
72
+ idx_keep = npr.choice(len(anno_lines), size=base_num * 3, replace=True)
73
+ elif len(anno_lines) > 100000:
74
+ idx_keep = npr.choice(len(anno_lines), size=len(anno_lines), replace=True)
75
+ else:
76
+ idx_keep = np.arange(len(anno_lines))
77
+ np.random.shuffle(idx_keep)
78
+ chose_count = 0
79
+ with open(output_file, 'a+') as f:
80
+ for idx in idx_keep:
81
+ # write labels of pos, neg and part images
82
+ f.write(anno_lines[idx])
83
+ chose_count+=1
84
+
85
+ return chose_count
86
+ def wider_face(txt_from_path, txt_to_path):
87
+ line_from_count = 0
88
+ with open(txt_from_path, 'r') as f:
89
+ annotations = f.readlines()
90
+ with open(txt_to_path, 'w+') as f:
91
+ while line_from_count < len(annotations):
92
+ if annotations[line_from_count][2]=='-':
93
+ img_name = annotations[line_from_count][:-1]
94
+ line_from_count += 1 # change line to read the number
95
+ bbox_count = int(annotations[line_from_count]) # num of bboxes
96
+ line_from_count += 1 # move to the next line to read the bbox coordinates
97
+ for _ in range(bbox_count):
98
+ bbox = list(map(int,annotations[line_from_count].split()[:4])) # give a loop to append all the boxes
99
+ bbox = [bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]] # make x1, y1, w, h --> x1, y1, x2, y2
100
+ bbox = list(map(str,bbox))
101
+ img_name += (' '+' '.join(bbox))
102
+ line_from_count+=1
103
+ f.write(img_name +'\n')
104
+ else: # not an image file-name line; skip it
105
+ line_from_count+=1
106
+
107
+ # ----------------------------------------------------origin----------------------------------------------
108
+ def get_Pnet_data():
109
+ if not os.path.exists(pos_save_dir):
110
+ os.makedirs(pos_save_dir)
111
+ if not os.path.exists(part_save_dir):
112
+ os.makedirs(part_save_dir)
113
+ if not os.path.exists(neg_save_dir):
114
+ os.makedirs(neg_save_dir)
115
+ f1 = open(os.path.join('./anno_store', 'pos_12.txt'), 'w')
116
+ f2 = open(os.path.join('./anno_store', 'neg_12.txt'), 'w')
117
+ f3 = open(os.path.join('./anno_store', 'part_12.txt'), 'w')
118
+ with open(anno_file, 'r') as f:
119
+ annotations = f.readlines()
120
+ num = len(annotations)
121
+ print("%d pics in total" % num)
122
+ p_idx = 0 # positive
123
+ n_idx = 0 # negative
124
+ d_idx = 0 # dont care
125
+ idx = 0
126
+ box_idx = 0
127
+ for annotation in annotations:
128
+ annotation = annotation.strip().split(' ')
129
+ # annotation[0] is the image file name
130
+ im_path = os.path.join(im_dir, annotation[0])
131
+ # print(im_path)
132
+ # print(os.path.exists(im_path))
133
+ bbox = list(map(float, annotation[1:]))
134
+ # annotation[1:] holds the face boxes: 4 values per face (the two corner points)
135
+ boxes = np.array(bbox, dtype=np.int32).reshape(-1, 4)
136
+ # the -1 dimension of the reshape equals the number of faces
137
+ if boxes.shape[0]==0:
138
+ continue
139
+ # skip this image if it contains no faces
140
+ img = cv2.imread(im_path)
141
+ # print(img.shape)
142
+ # exit()
143
+ # image counter
144
+ idx += 1
145
+ if idx % 100 == 0:
146
+ print("%s images done, pos: %s part: %s neg: %s" % (idx, p_idx, d_idx, n_idx))
147
+
148
+ # images have three channels (H, W, C)
149
+ height, width, channel = img.shape
150
+
151
+ neg_num = 0
152
+
153
+ # sample 50 random crops per image as negative candidates
154
+ while neg_num < 50:
155
+ size = np.random.randint(12, min(width, height) / 2)
156
+ nx = np.random.randint(0, width - size)
157
+ ny = np.random.randint(0, height - size)
158
+ crop_box = np.array([nx, ny, nx + size, ny + size])
159
+
160
+ Iou = IoU(crop_box, boxes) # IoU = intersection area / union area; larger means more overlap
161
+
162
+ cropped_im = img[ny: ny + size, nx: nx + size, :] # crop the patch, then resize it to 12x12
163
+ resized_im = cv2.resize(cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR)
164
+
165
+ if np.max(Iou) < 0.3:
166
+ # Iou with all gts must below 0.3
167
+ save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx)
168
+ f2.write(save_file + ' 0\n')
169
+ cv2.imwrite(save_file, resized_im)
170
+ n_idx += 1
171
+ neg_num += 1
172
+
173
+ for box in boxes:
174
+ # box (x_left, y_top, x_right, y_bottom)
175
+ x1, y1, x2, y2 = box
176
+ # w = x2 - x1 + 1
177
+ # h = y2 - y1 + 1
178
+ w = x2 - x1 + 1
179
+ h = y2 - y1 + 1
180
+
181
+ # ignore small faces
182
+ # in case the ground truth boxes of small faces are not accurate
183
+ if max(w, h) < 40 or x1 < 0 or y1 < 0:
184
+ continue
185
+ if w < 12 or h < 12:
186
+ continue
187
+
188
+ # generate negative examples that have overlap with gt
189
+ for i in range(5):
190
+ size = np.random.randint(12, min(width, height) / 2)
191
+
192
+ # delta_x and delta_y are offsets of (x1, y1)
193
+ delta_x = np.random.randint(max(-size, -x1), w)
194
+ delta_y = np.random.randint(max(-size, -y1), h)
195
+ nx1 = max(0, x1 + delta_x)
196
+ ny1 = max(0, y1 + delta_y)
197
+
198
+ if nx1 + size > width or ny1 + size > height:
199
+ continue
200
+ crop_box = np.array([nx1, ny1, nx1 + size, ny1 + size])
201
+ Iou = IoU(crop_box, boxes)
202
+
203
+ cropped_im = img[ny1: ny1 + size, nx1: nx1 + size, :]
204
+ resized_im = cv2.resize(cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR)
205
+
206
+ if np.max(Iou) < 0.3:
207
+ # Iou with all gts must below 0.3
208
+ save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx)
209
+ f2.write(save_file + ' 0\n')
210
+ cv2.imwrite(save_file, resized_im)
211
+ n_idx += 1
212
+
213
+ # generate positive examples and part faces
214
+ for i in range(20):
215
+ size = np.random.randint(int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h)))
216
+
217
+ # delta here is the offset of box center
218
+ delta_x = np.random.randint(-w * 0.2, w * 0.2)
219
+ delta_y = np.random.randint(-h * 0.2, h * 0.2)
220
+
221
+ nx1 = max(x1 + w / 2 + delta_x - size / 2, 0)
222
+ ny1 = max(y1 + h / 2 + delta_y - size / 2, 0)
223
+ nx2 = nx1 + size
224
+ ny2 = ny1 + size
225
+
226
+ if nx2 > width or ny2 > height:
227
+ continue
228
+ crop_box = np.array([nx1, ny1, nx2, ny2])
229
+
230
+ offset_x1 = (x1 - nx1) / float(size)
231
+ offset_y1 = (y1 - ny1) / float(size)
232
+ offset_x2 = (x2 - nx2) / float(size)
233
+ offset_y2 = (y2 - ny2) / float(size)
234
+
235
+ cropped_im = img[int(ny1): int(ny2), int(nx1): int(nx2), :]
236
+ resized_im = cv2.resize(cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR)
237
+
238
+ box_ = box.reshape(1, -1)
239
+ if IoU(crop_box, box_) >= 0.65:
240
+ save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx)
241
+ f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % (offset_x1, offset_y1, offset_x2, offset_y2))
242
+ cv2.imwrite(save_file, resized_im)
243
+ p_idx += 1
244
+ elif IoU(crop_box, box_) >= 0.4:
245
+ save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx)
246
+ f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % (offset_x1, offset_y1, offset_x2, offset_y2))
247
+ cv2.imwrite(save_file, resized_im)
248
+ d_idx += 1
249
+ box_idx += 1
250
+ #print("%s images done, pos: %s part: %s neg: %s" % (idx, p_idx, d_idx, n_idx))
251
+
252
+ f1.close()
253
+ f2.close()
254
+ f3.close()
255
+
256
+
257
+ def assembel_Pnet_data():
258
+ anno_list = []
259
+
260
+ anno_list.append(pnet_postive_file)
261
+ anno_list.append(pnet_part_file)
262
+ anno_list.append(pnet_neg_file)
263
+ # anno_list.append(pnet_landmark_file)
264
+ chose_count = assemble_data(imglist_filename_pnet ,anno_list)
265
+ print("PNet train annotation result file path:%s" % imglist_filename_pnet)
266
+
267
+ # -----------------------------------------------------------------------------------------------------------------------------------------------#
268
+
269
+ def gen_rnet_data(data_dir, anno_file, pnet_model_file, prefix_path='', use_cuda=True, vis=False):
270
+
271
+ """
272
+ :param data_dir: train data
273
+ :param anno_file:
274
+ :param pnet_model_file:
275
+ :param prefix_path:
276
+ :param use_cuda:
277
+ :param vis:
278
+ :return:
279
+ """
280
+
281
+ # load trained pnet model
282
+
283
+ pnet, _, _ = create_mtcnn_net(p_model_path = pnet_model_file, use_cuda = use_cuda)
284
+ mtcnn_detector = MtcnnDetector(pnet = pnet, min_face_size = 12)
285
+
286
+ # load original_anno_file, length = 12880
287
+ imagedb = ImageDB(anno_file, mode = "test", prefix_path = prefix_path)
288
+ imdb = imagedb.load_imdb()
289
+ image_reader = TestImageLoader(imdb, 1, False)
290
+
291
+ all_boxes = list()
292
+ batch_idx = 0
293
+
294
+ print('size:%d' %image_reader.size)
295
+ for databatch in image_reader:
296
+ if batch_idx % 100 == 0:
297
+ print ("%d images done" % batch_idx)
298
+ im = databatch
299
+ t = time.time()
300
+
301
+ # obtain boxes and aligned boxes
302
+ boxes, boxes_align = mtcnn_detector.detect_pnet(im=im)
303
+ if boxes_align is None:
304
+ all_boxes.append(np.array([]))
305
+ batch_idx += 1
306
+ continue
307
+ if vis:
308
+ rgb_im = cv2.cvtColor(np.asarray(im), cv2.COLOR_BGR2RGB)
309
+ vision.vis_two(rgb_im, boxes, boxes_align)
310
+
311
+ t1 = time.time() - t
312
+ print('cost time ',t1)
313
+ t = time.time()
314
+ all_boxes.append(boxes_align)
315
+ batch_idx += 1
316
+ # if batch_idx == 100:
317
+ # break
318
+ # print("shape of all boxes {0}".format(all_boxes))
319
+ # time.sleep(5)
320
+
321
+ # save_path = model_store_path()
322
+ # './model_store'
323
+ save_path = './model_store'
324
+
325
+ if not os.path.exists(save_path):
326
+ os.mkdir(save_path)
327
+
328
+ save_file = os.path.join(save_path, "detections_%d.pkl" % int(time.time()))
329
+ with open(save_file, 'wb') as f:
330
+ cPickle.dump(all_boxes, f, cPickle.HIGHEST_PROTOCOL)
331
+
332
+ # save_file = './model_store/detections_1588751332.pkl'
333
+ gen_rnet_sample_data(data_dir, anno_file, save_file, prefix_path)
334
+
335
+
336
+
337
+ def gen_rnet_sample_data(data_dir, anno_file, det_boxs_file, prefix_path):
338
+
339
+ """
340
+ :param data_dir:
341
+ :param anno_file: original annotations file of wider face data
342
+ :param det_boxs_file: detection boxes file
343
+ :param prefix_path:
344
+ :return:
345
+ """
346
+
347
+ neg_save_dir = os.path.join(data_dir, "24/negative")
348
+ pos_save_dir = os.path.join(data_dir, "24/positive")
349
+ part_save_dir = os.path.join(data_dir, "24/part")
350
+
351
+
352
+ for dir_path in [neg_save_dir, pos_save_dir, part_save_dir]:
353
+ # print(dir_path)
354
+ if not os.path.exists(dir_path):
355
+ os.makedirs(dir_path)
356
+
357
+
358
+ # load ground truth from annotation file
359
+ # format of each line: image/path [x1,y1,x2,y2] for each gt_box in this image
360
+
361
+ with open(anno_file, 'r') as f:
362
+ annotations = f.readlines()
363
+
364
+ image_size = 24
365
+ net = "rnet"
366
+
367
+ im_idx_list = list()
368
+ gt_boxes_list = list()
369
+ num_of_images = len(annotations)
370
+ print ("processing %d images in total" % num_of_images)
371
+
372
+ for annotation in annotations:
373
+ annotation = annotation.strip().split(' ')
374
+ im_idx = os.path.join(prefix_path, annotation[0])
375
+ # im_idx = annotation[0]
376
+
377
+ boxes = list(map(float, annotation[1:]))
378
+ boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
379
+ im_idx_list.append(im_idx)
380
+ gt_boxes_list.append(boxes)
381
+
382
+
383
+ # './anno_store'
384
+ save_path = './anno_store'
385
+ if not os.path.exists(save_path):
386
+ os.makedirs(save_path)
387
+
388
+ f1 = open(os.path.join(save_path, 'pos_%d.txt' % image_size), 'w')
389
+ f2 = open(os.path.join(save_path, 'neg_%d.txt' % image_size), 'w')
390
+ f3 = open(os.path.join(save_path, 'part_%d.txt' % image_size), 'w')
391
+
392
+ # print(det_boxs_file)
393
+ det_handle = open(det_boxs_file, 'rb')
394
+
395
+ det_boxes = cPickle.load(det_handle)
396
+
397
+ # an image contain many boxes stored in an array
398
+ print(len(det_boxes), num_of_images)
399
+ # assert len(det_boxes) == num_of_images, "incorrect detections or ground truths"
400
+
401
+ # index of neg, pos and part face, used as their image names
402
+ n_idx = 0
403
+ p_idx = 0
404
+ d_idx = 0
405
+ image_done = 0
406
+ for im_idx, dets, gts in zip(im_idx_list, det_boxes, gt_boxes_list):
407
+
408
+ # if (im_idx+1) == 100:
409
+ # break
410
+
411
+ gts = np.array(gts, dtype=np.float32).reshape(-1, 4)
412
+ if gts.shape[0]==0:
413
+ continue
414
+ if image_done % 100 == 0:
415
+ print("%d images done" % image_done)
416
+ image_done += 1
417
+
418
+ if dets.shape[0] == 0:
419
+ continue
420
+ img = cv2.imread(im_idx)
421
+ # change to square
422
+ dets = convert_to_square(dets)
423
+ dets[:, 0:4] = np.round(dets[:, 0:4])
424
+ neg_num = 0
425
+ for box in dets:
426
+ x_left, y_top, x_right, y_bottom, _ = box.astype(int)
427
+ width = x_right - x_left + 1
428
+ height = y_bottom - y_top + 1
429
+
430
+ # ignore box that is too small or beyond image border
431
+ if width < 20 or x_left < 0 or y_top < 0 or x_right > img.shape[1] - 1 or y_bottom > img.shape[0] - 1:
432
+ continue
433
+
434
+ # compute intersection over union(IoU) between current box and all gt boxes
435
+ Iou = IoU(box, gts)
436
+ cropped_im = img[y_top:y_bottom + 1, x_left:x_right + 1, :]
437
+ resized_im = cv2.resize(cropped_im, (image_size, image_size),
438
+ interpolation=cv2.INTER_LINEAR)
439
+
440
+ # save negative images and write label
441
+ # Iou with all gts must below 0.3
442
+ if np.max(Iou) < 0.3 and neg_num < 60:
443
+ # save the examples
444
+ save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx)
445
+ # print(save_file)
446
+ f2.write(save_file + ' 0\n')
447
+ cv2.imwrite(save_file, resized_im)
448
+ n_idx += 1
449
+ neg_num += 1
450
+ else:
451
+ # find gt_box with the highest iou
452
+ idx = np.argmax(Iou)
453
+ assigned_gt = gts[idx]
454
+ x1, y1, x2, y2 = assigned_gt
455
+
456
+ # compute bbox reg label
457
+ offset_x1 = (x1 - x_left) / float(width)
458
+ offset_y1 = (y1 - y_top) / float(height)
459
+ offset_x2 = (x2 - x_right) / float(width)
460
+ offset_y2 = (y2 - y_bottom) / float(height)
461
+
462
+ # save positive and part-face images and write labels
463
+ if np.max(Iou) >= 0.65:
464
+ save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx)
465
+ f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % (
466
+ offset_x1, offset_y1, offset_x2, offset_y2))
467
+ cv2.imwrite(save_file, resized_im)
468
+ p_idx += 1
469
+
470
+ elif np.max(Iou) >= 0.4:
471
+ save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx)
472
+ f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % (
473
+ offset_x1, offset_y1, offset_x2, offset_y2))
474
+ cv2.imwrite(save_file, resized_im)
475
+ d_idx += 1
476
+ f1.close()
477
+ f2.close()
478
+ f3.close()
479
+
480
+ def model_store_path():
481
+ return os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))+"/model_store"
482
+
483
+ def get_Rnet_data(pnet_model):
484
+ gen_rnet_data(traindata_store, annotation_file, pnet_model_file = pnet_model, prefix_path = prefix_path, use_cuda = True)
485
+
486
+
487
+ def assembel_Rnet_data():
488
+ anno_list = []
489
+
490
+ anno_list.append(rnet_postive_file)
491
+ anno_list.append(rnet_part_file)
492
+ anno_list.append(rnet_neg_file)
493
+ # anno_list.append(pnet_landmark_file)
494
+
495
+ chose_count = assemble_data(imglist_filename_rnet ,anno_list)
496
+ print("RNet train annotation result file path:%s" % imglist_filename_rnet)
497
+ #-----------------------------------------------------------------------------------------------------------------------------------------------#
498
+ def gen_onet_data(data_dir, anno_file, pnet_model_file, rnet_model_file, prefix_path='', use_cuda=True, vis=False):
499
+
500
+
501
+ pnet, rnet, _ = create_mtcnn_net(p_model_path=pnet_model_file, r_model_path=rnet_model_file, use_cuda=use_cuda)
502
+ mtcnn_detector = MtcnnDetector(pnet=pnet, rnet=rnet, min_face_size=12)
503
+
504
+ imagedb = ImageDB(anno_file,mode="test",prefix_path=prefix_path)
505
+ imdb = imagedb.load_imdb()
506
+ image_reader = TestImageLoader(imdb,1,False)
507
+
508
+ all_boxes = list()
509
+ batch_idx = 0
510
+
511
+ print('size:%d' % image_reader.size)
512
+ for databatch in image_reader:
513
+ if batch_idx % 50 == 0:
514
+ print("%d images done" % batch_idx)
515
+
516
+ im = databatch
517
+
518
+ t = time.time()
519
+
520
+ # pnet detection = [x1, y1, x2, y2, score, reg]
521
+ p_boxes, p_boxes_align = mtcnn_detector.detect_pnet(im=im)
522
+
523
+ t0 = time.time() - t
524
+ t = time.time()
525
+ # rnet detection
526
+ boxes, boxes_align = mtcnn_detector.detect_rnet(im=im, dets=p_boxes_align)
527
+
528
+ t1 = time.time() - t
529
+ print('cost time pnet--',t0,' rnet--',t1)
530
+ t = time.time()
531
+
532
+ if boxes_align is None:
533
+ all_boxes.append(np.array([]))
534
+ batch_idx += 1
535
+ continue
536
+ if vis:
537
+ rgb_im = cv2.cvtColor(np.asarray(im), cv2.COLOR_BGR2RGB)
538
+ vision.vis_two(rgb_im, boxes, boxes_align)
539
+
540
+
541
+ all_boxes.append(boxes_align)
542
+ batch_idx += 1
543
+
544
+ save_path = './model_store'
545
+
546
+ if not os.path.exists(save_path):
547
+ os.mkdir(save_path)
548
+
549
+ save_file = os.path.join(save_path, "detections_%d.pkl" % int(time.time()))
550
+ with open(save_file, 'wb') as f:
551
+ cPickle.dump(all_boxes, f, cPickle.HIGHEST_PROTOCOL)
552
+
553
+
554
+ gen_onet_sample_data(data_dir,anno_file,save_file,prefix_path)
555
+
556
+
557
+
558
+ def gen_onet_sample_data(data_dir,anno_file,det_boxs_file,prefix):
559
+
560
+ neg_save_dir = os.path.join(data_dir, "48/negative")
561
+ pos_save_dir = os.path.join(data_dir, "48/positive")
562
+ part_save_dir = os.path.join(data_dir, "48/part")
563
+
564
+ for dir_path in [neg_save_dir, pos_save_dir, part_save_dir]:
565
+ if not os.path.exists(dir_path):
566
+ os.makedirs(dir_path)
567
+
568
+
569
+ # load ground truth from annotation file
570
+ # format of each line: image/path [x1,y1,x2,y2] for each gt_box in this image
571
+
572
+ with open(anno_file, 'r') as f:
573
+ annotations = f.readlines()
574
+
575
+ image_size = 48
576
+ net = "onet"
577
+
578
+ im_idx_list = list()
579
+ gt_boxes_list = list()
580
+ num_of_images = len(annotations)
581
+ print("processing %d images in total" % num_of_images)
582
+
583
+ for annotation in annotations:
584
+ annotation = annotation.strip().split(' ')
585
+ im_idx = os.path.join(prefix,annotation[0])
586
+
587
+ boxes = list(map(float, annotation[1:]))
588
+ boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
589
+ im_idx_list.append(im_idx)
590
+ gt_boxes_list.append(boxes)
591
+
592
+ save_path = './anno_store'
593
+ if not os.path.exists(save_path):
594
+ os.makedirs(save_path)
595
+
596
+ f1 = open(os.path.join(save_path, 'pos_%d.txt' % image_size), 'w')
597
+ f2 = open(os.path.join(save_path, 'neg_%d.txt' % image_size), 'w')
598
+ f3 = open(os.path.join(save_path, 'part_%d.txt' % image_size), 'w')
599
+
600
+ det_handle = open(det_boxs_file, 'rb')
601
+
602
+ det_boxes = cPickle.load(det_handle)
603
+ print(len(det_boxes), num_of_images)
604
+ # assert len(det_boxes) == num_of_images, "incorrect detections or ground truths"
605
+
606
+ # index of neg, pos and part face, used as their image names
607
+ n_idx = 0
608
+ p_idx = 0
609
+ d_idx = 0
610
+ image_done = 0
611
+ for im_idx, dets, gts in zip(im_idx_list, det_boxes, gt_boxes_list):
612
+ if image_done % 100 == 0:
613
+ print("%d images done" % image_done)
614
+ image_done += 1
615
+ if gts.shape[0]==0:
616
+ continue
617
+ if dets.shape[0] == 0:
618
+ continue
619
+ img = cv2.imread(im_idx)
620
+ dets = convert_to_square(dets)
621
+ dets[:, 0:4] = np.round(dets[:, 0:4])
622
+
623
+ for box in dets:
624
+ x_left, y_top, x_right, y_bottom = box[0:4].astype(int)
625
+ width = x_right - x_left + 1
626
+ height = y_bottom - y_top + 1
627
+
628
+ # ignore box that is too small or beyond image border
629
+ if width < 20 or x_left < 0 or y_top < 0 or x_right > img.shape[1] - 1 or y_bottom > img.shape[0] - 1:
630
+ continue
631
+
632
+ # compute intersection over union(IoU) between current box and all gt boxes
633
+ Iou = IoU(box, gts)
634
+ cropped_im = img[y_top:y_bottom + 1, x_left:x_right + 1, :]
635
+ resized_im = cv2.resize(cropped_im, (image_size, image_size),
636
+ interpolation=cv2.INTER_LINEAR)
637
+
638
+ # save negative images and write label
639
+ if np.max(Iou) < 0.3:
640
+ # Iou with all gts must below 0.3
641
+ save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx)
642
+ f2.write(save_file + ' 0\n')
643
+ cv2.imwrite(save_file, resized_im)
644
+ n_idx += 1
645
+ else:
646
+ # find gt_box with the highest iou
647
+ idx = np.argmax(Iou)
648
+ assigned_gt = gts[idx]
649
+ x1, y1, x2, y2 = assigned_gt
650
+
651
+ # compute bbox reg label
652
+ offset_x1 = (x1 - x_left) / float(width)
653
+ offset_y1 = (y1 - y_top) / float(height)
654
+ offset_x2 = (x2 - x_right) / float(width)
655
+ offset_y2 = (y2 - y_bottom) / float(height)
656
+
657
+ # save positive and part-face images and write labels
658
+ if np.max(Iou) >= 0.65:
659
+ save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx)
660
+ f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % (
661
+ offset_x1, offset_y1, offset_x2, offset_y2))
662
+ cv2.imwrite(save_file, resized_im)
663
+ p_idx += 1
664
+
665
+ elif np.max(Iou) >= 0.4:
666
+ save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx)
667
+ f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % (
668
+ offset_x1, offset_y1, offset_x2, offset_y2))
669
+ cv2.imwrite(save_file, resized_im)
670
+ d_idx += 1
671
+ f1.close()
672
+ f2.close()
673
+ f3.close()
674
+
675
+
676
+
677
+ def model_store_path():
678
+ return os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))+"/model_store"
679
+
680
+
681
+ def get_Onet_data(pnet_model, rnet_model):
682
+ gen_onet_data(traindata_store, annotation_file, pnet_model_file = pnet_model, rnet_model_file = rnet_model,prefix_path=prefix_path,use_cuda = True, vis = False)
683
+
684
+
685
+ def assembel_Onet_data():
686
+ anno_list = []
687
+
688
+ anno_list.append(onet_postive_file)
689
+ anno_list.append(onet_part_file)
690
+ anno_list.append(onet_neg_file)
691
+ anno_list.append(onet_landmark_file)
692
+
693
+ chose_count = assemble_data(imglist_filename_onet ,anno_list)
694
+ print("ONet train annotation result file path:%s" % imglist_filename_onet)
695
+
696
+
697
+ def gen_landmark_48(anno_file, data_dir, prefix = ''):
698
+
699
+
700
+ size = 48
701
+ image_id = 0
702
+
703
+ landmark_imgs_save_dir = os.path.join(data_dir,"48/landmark")
704
+ if not os.path.exists(landmark_imgs_save_dir):
705
+ os.makedirs(landmark_imgs_save_dir)
706
+
707
+ anno_dir = './anno_store'
708
+ if not os.path.exists(anno_dir):
709
+ os.makedirs(anno_dir)
710
+
711
+ landmark_anno_filename = "landmark_48.txt"
712
+ save_landmark_anno = os.path.join(anno_dir,landmark_anno_filename)
713
+
714
+ # print(save_landmark_anno)
715
+ # time.sleep(5)
716
+ f = open(save_landmark_anno, 'w')
717
+ # dstdir = "train_landmark_few"
718
+
719
+ with open(anno_file, 'r') as f2:
720
+ annotations = f2.readlines()
721
+
722
+ num = len(annotations)
723
+ print("%d total images" % num)
724
+
725
+ l_idx =0
726
+ idx = 0
727
+ # image_path bbox landmark(5*2)
728
+ for annotation in annotations:
729
+ # print imgPath
730
+
731
+ annotation = annotation.strip().split(' ')
732
+
733
+ assert len(annotation)==15,"each line should have 15 element"
734
+
735
+ im_path = os.path.join('./data_set/face_landmark/CNN_FacePoint/train/',annotation[0].replace("\\", "/"))
736
+
737
+ gt_box = list(map(float, annotation[1:5]))
738
+ # gt_box = [gt_box[0], gt_box[2], gt_box[1], gt_box[3]]
739
+
740
+
741
+ gt_box = np.array(gt_box, dtype=np.int32)
742
+
743
+ landmark = list(map(float, annotation[5:]))
744
+ landmark = np.array(landmark, dtype=float) # builtin float; np.float is deprecated
745
+
746
+ img = cv2.imread(im_path)
747
+ # print(im_path)
748
+ assert (img is not None)
749
+
750
+ height, width, channel = img.shape
751
+ # crop_face = img[gt_box[1]:gt_box[3]+1, gt_box[0]:gt_box[2]+1]
752
+ # crop_face = cv2.resize(crop_face,(size,size))
753
+
754
+ idx = idx + 1
755
+ if idx % 100 == 0:
756
+ print("%d images done, landmark images: %d"%(idx,l_idx))
757
+ # print(im_path)
758
+ # print(gt_box)
759
+ x1, x2, y1, y2 = gt_box # the landmark annotation stores the box as (x1, x2, y1, y2)
760
+ gt_box[1] = y1
761
+ gt_box[2] = x2 # gt_box is now reordered in place to (x1, y1, x2, y2)
762
+ # time.sleep(5)
763
+
764
+ # gt's width
765
+ w = x2 - x1 + 1
766
+ # gt's height
767
+ h = y2 - y1 + 1
768
+ if max(w, h) < 40 or x1 < 0 or y1 < 0:
769
+ continue
770
+ # random shift
771
+ for i in range(10):
772
+ bbox_size = np.random.randint(int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h)))
773
+ delta_x = np.random.randint(-w * 0.2, w * 0.2)
774
+ delta_y = np.random.randint(-h * 0.2, h * 0.2)
775
+ nx1 = max(x1 + w / 2 - bbox_size / 2 + delta_x, 0)
776
+ ny1 = max(y1 + h / 2 - bbox_size / 2 + delta_y, 0)
777
+
778
+ nx2 = nx1 + bbox_size
779
+ ny2 = ny1 + bbox_size
780
+ if nx2 > width or ny2 > height:
781
+ continue
782
+ crop_box = np.array([nx1, ny1, nx2, ny2])
783
+ cropped_im = img[int(ny1):int(ny2) + 1, int(nx1):int(nx2) + 1, :]
784
+ resized_im = cv2.resize(cropped_im, (size, size),interpolation=cv2.INTER_LINEAR)
785
+
786
+ offset_x1 = (x1 - nx1) / float(bbox_size)
787
+ offset_y1 = (y1 - ny1) / float(bbox_size)
788
+ offset_x2 = (x2 - nx2) / float(bbox_size)
789
+ offset_y2 = (y2 - ny2) / float(bbox_size)
790
+
791
+ offset_left_eye_x = (landmark[0] - nx1) / float(bbox_size)
792
+ offset_left_eye_y = (landmark[1] - ny1) / float(bbox_size)
793
+
794
+ offset_right_eye_x = (landmark[2] - nx1) / float(bbox_size)
795
+ offset_right_eye_y = (landmark[3] - ny1) / float(bbox_size)
796
+
797
+ offset_nose_x = (landmark[4] - nx1) / float(bbox_size)
798
+ offset_nose_y = (landmark[5] - ny1) / float(bbox_size)
799
+
800
+ offset_left_mouth_x = (landmark[6] - nx1) / float(bbox_size)
801
+ offset_left_mouth_y = (landmark[7] - ny1) / float(bbox_size)
802
+
803
+ offset_right_mouth_x = (landmark[8] - nx1) / float(bbox_size)
804
+ offset_right_mouth_y = (landmark[9] - ny1) / float(bbox_size)
805
+
806
+
807
+ # cal iou
808
+ iou = IoU(crop_box.astype(float), np.expand_dims(gt_box.astype(float), 0)) # builtin float; np.float is deprecated
809
+ # print(iou)
810
+ if iou > 0.65:
811
+ save_file = os.path.join(landmark_imgs_save_dir, "%s.jpg" % l_idx)
812
+ cv2.imwrite(save_file, resized_im)
813
+
814
+ f.write(save_file + ' -2 %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n' % \
815
+ (offset_x1, offset_y1, offset_x2, offset_y2, \
816
+ offset_left_eye_x,offset_left_eye_y,offset_right_eye_x,offset_right_eye_y,offset_nose_x,offset_nose_y,offset_left_mouth_x,offset_left_mouth_y,offset_right_mouth_x,offset_right_mouth_y))
817
+ # print(save_file)
818
+ # print(save_landmark_anno)
819
+ l_idx += 1
820
+
821
+ f.close()
822
+
823
+
824
+ def parse_args():
825
+ parser = argparse.ArgumentParser(description='Get data',
826
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
827
+
828
+ parser.add_argument('--net', dest='net', help='which net to show', type=str)
829
+ parser.add_argument('--pnet_path', default="./model_store/pnet_epoch_20.pt",help='path to pnet model', type=str)
830
+ parser.add_argument('--rnet_path', default="./model_store/rnet_epoch_20.pt",help='path to rnet model', type=str)
831
+ parser.add_argument('--use_cuda', default=True,help='use cuda', type=bool)
832
+
833
+ args = parser.parse_args()
834
+ return args
835
+
836
+ #-----------------------------------------------------------------------------------------------------------------------------------------------#
837
+ if __name__ == '__main__':
838
+ args = parse_args()
839
+ dir = 'anno_store'
840
+ if not os.path.exists(dir):
841
+ os.makedirs(dir)
842
+ if args.net == "pnet":
843
+ wider_face(txt_from_path, anno_file)
844
+ get_Pnet_data()
845
+ assembel_Pnet_data()
846
+ elif args.net == "rnet":
847
+ get_Rnet_data(args.pnet_path)
848
+ assembel_Rnet_data()
849
+ elif args.net == "onet":
850
+ get_Onet_data(args.pnet_path, args.rnet_path)
851
+ gen_landmark_48(annotation_file_lm, traindata_store, prefix_path_lm)
852
+ assembel_Onet_data()
img/mid.png ADDED
img/onet.png ADDED
img/pnet.png ADDED
img/result.png ADDED
img/rnet.png ADDED
model_store/onet_epoch_20.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:53e8fe6d59c0b3cd75ae24f37756e056e05b9fa555cd9e442543aef54cc5f887
+ size 903910
model_store/pnet_epoch_20.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e818bafbe694390fba4cf59cad9d67a04ed8fb9297e5b4032c3d2af3832e5365
+ size 32056
model_store/rnet_epoch_20.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cfe5d5abf979cb3d7eda838d9d6c8e1b582e4a53a1d20e9b6ff54953ed3ba042
+ size 245871
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ matplotlib==3.5.0
+ matplotlib-inline==0.1.3
+ numpy==1.21.4
+ opencv-python==4.4.0.42
+ opencv-python-headless==4.6.0.66
+ Pillow==9.1.1
+ scikit-image==0.19.3
+ torch==1.13.0+cu116
+ torchaudio==0.13.0+cu116
+ torchvision==0.14.0+cu116
test.py ADDED
@@ -0,0 +1,84 @@
1
+ import cv2
2
+ from utils.detect import create_mtcnn_net, MtcnnDetector
3
+ from utils.vision import vis_face
4
+ import argparse
5
+
6
+
7
+ MIN_FACE_SIZE = 3
8
+
9
+ def parse_args():
10
+ parser = argparse.ArgumentParser(description='Test MTCNN',
11
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
12
+
13
+ parser.add_argument('--net', default='onet', help='which net to show', type=str)
14
+ parser.add_argument('--pnet_path', default="./model_store/pnet_epoch_20.pt",help='path to pnet model', type=str)
15
+ parser.add_argument('--rnet_path', default="./model_store/rnet_epoch_20.pt",help='path to rnet model', type=str)
16
+ parser.add_argument('--onet_path', default="./model_store/onet_epoch_20.pt",help='path to onet model', type=str)
17
+ parser.add_argument('--path', default="./img/mid.png",help='path to image', type=str)
18
+ parser.add_argument('--min_face_size', default=MIN_FACE_SIZE,help='min face size', type=int)
19
+ parser.add_argument('--use_cuda', default=False,help='use cuda', type=bool)
20
+ parser.add_argument('--thresh', default='[0.1, 0.1, 0.1]',help='thresh', type=str)
21
+ parser.add_argument('--save_name', default="result.jpg",help='save name', type=str)
22
+ parser.add_argument('--input_mode', default=1,help='image or video', type=int)
23
+ args = parser.parse_args()
24
+ return args
25
+ if __name__ == '__main__':
26
+ args = parse_args()
27
+ thresh = [float(i) for i in (args.thresh).split('[')[1].split(']')[0].split(',')]
28
+ pnet, rnet, onet = create_mtcnn_net(p_model_path=args.pnet_path, r_model_path=args.rnet_path,o_model_path=args.onet_path, use_cuda=args.use_cuda)
29
+ mtcnn_detector = MtcnnDetector(pnet=pnet, rnet=rnet, onet=onet, min_face_size=args.min_face_size,threshold=thresh)
30
+ if args.input_mode == 1:
31
+ img = cv2.imread(args.path)
32
+ img_bg = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
33
+ p_bboxs, r_bboxs, bboxs, landmarks = mtcnn_detector.detect_face(img)
34
+ # print box_align
35
+ save_name = args.save_name
36
+ if args.net == 'pnet':
37
+ vis_face(img_bg, p_bboxs, landmarks, MIN_FACE_SIZE, save_name)
38
+ elif args.net == 'rnet':
39
+ vis_face(img_bg, r_bboxs, landmarks, MIN_FACE_SIZE, save_name)
40
+ elif args.net == 'onet':
41
+ vis_face(img_bg, bboxs, landmarks, MIN_FACE_SIZE, save_name)
42
+ elif args.input_mode == 0:
43
+ cap=cv2.VideoCapture(0)
44
+ fourcc = cv2.VideoWriter_fourcc(*'XVID')
45
+ out = cv2.VideoWriter('out.mp4' ,fourcc,10,(640,480))
46
+ while True:
47
+ t1=cv2.getTickCount()
48
+ ret,frame = cap.read()
49
+ if ret == True:
50
+ _, _, boxes_c, landmarks = mtcnn_detector.detect_face(frame) # detect_face returns PNet, RNet and ONet boxes plus landmarks; keep the final ONet output
51
+ t2=cv2.getTickCount()
52
+ t=(t2-t1)/cv2.getTickFrequency()
53
+ fps=1.0/t
54
+ for i in range(boxes_c.shape[0]):
55
+ bbox = boxes_c[i, :4]
56
+ score = boxes_c[i, 4]
57
+ corpbbox = [int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])]
58
+
59
+ # draw the face bounding box
60
+ cv2.rectangle(frame, (corpbbox[0], corpbbox[1]),
61
+ (corpbbox[2], corpbbox[3]), (255, 0, 0), 1)
62
+ # draw the detection confidence score
63
+ cv2.putText(frame, '{:.2f}'.format(score),
64
+ (corpbbox[0], corpbbox[1] - 2),
65
+ cv2.FONT_HERSHEY_SIMPLEX,
66
+ 0.5,(0, 0, 255), 2)
67
+ # draw the per-frame time and FPS
68
+ cv2.putText(frame, '{:.4f}'.format(t) + " " + '{:.3f}'.format(fps), (10, 20),
69
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 2)
70
+ # draw the facial landmarks
71
+ for i in range(landmarks.shape[0]):
72
+ for j in range(len(landmarks[i])//2):
73
+ cv2.circle(frame, (int(landmarks[i][2*j]),int(int(landmarks[i][2*j+1]))), 2, (0,0,255))
74
+ a = out.write(frame)
75
+ cv2.imshow("result", frame)
76
+ if cv2.waitKey(1) & 0xFF == ord('q'):
77
+ break
78
+ else:
79
+ break
80
+ cap.release()
81
+ out.release()
82
+ cv2.destroyAllWindows()
83
+
84
+
test.sh ADDED
@@ -0,0 +1,4 @@
+ python test.py --net=pnet --min_face_size=1 --pnet_path=./model_store/pnet_epoch_20.pt --rnet_path=./model_store/rnet_epoch_20.pt --onet_path=./model_store/onet_epoch_20.pt --save_name=pnet
+ python test.py --net=rnet --min_face_size=1 --pnet_path=./model_store/pnet_epoch_20.pt --rnet_path=./model_store/rnet_epoch_20.pt --onet_path=./model_store/onet_epoch_20.pt --save_name=rnet
+ python test.py --net=onet --min_face_size=1 --pnet_path=./model_store/pnet_epoch_20.pt --rnet_path=./model_store/rnet_epoch_20.pt --onet_path=./model_store/onet_epoch_20.pt --save_name=onet
+ echo "Testing finished!"
train.out ADDED
The diff for this file is too large to render. See raw diff
 
train.py ADDED
@@ -0,0 +1,351 @@
1
+ from utils.dataloader import TrainImageReader,convert_image_to_tensor,ImageDB
2
+ import datetime
3
+ import os
4
+ from utils.models import PNet,RNet,ONet,LossFn
5
+ import torch
6
+ # from torch.autograd import Variable  (deprecated in recent PyTorch versions)
7
+ import utils.config as config
8
+ import argparse
9
+ import sys
10
+ sys.path.append(os.getcwd())
11
+ import numpy as np
12
+
13
+
14
+
15
+ def compute_accuracy(prob_cls, gt_cls):
16
+
17
+ prob_cls = torch.squeeze(prob_cls)
18
+ gt_cls = torch.squeeze(gt_cls)
19
+
20
+ #we only need the detection which >= 0
21
+ mask = torch.ge(gt_cls,0)
22
+ #get valid element
23
+ valid_gt_cls = torch.masked_select(gt_cls,mask)
24
+ valid_prob_cls = torch.masked_select(prob_cls,mask)
25
+ size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0])
26
+ prob_ones = torch.ge(valid_prob_cls,0.6).float()
27
+ right_ones = torch.eq(prob_ones,valid_gt_cls).float()
28
+
29
+ ## if size == 0 meaning that your gt_labels are all negative, landmark or part
30
+
31
+ return torch.div(torch.mul(torch.sum(right_ones),float(1.0)),float(size)) ## divided by zero meaning that your gt_labels are all negative, landmark or part
32
+
33
+
34
+ def train_pnet(model_store_path, end_epoch,imdb,
35
+ batch_size,frequent=10,base_lr=0.01,lr_epoch_decay=[9],use_cuda=True,load=''):
36
+
37
+ #create lr_list
38
+ lr_epoch_decay.append(end_epoch+1)
39
+ lr_list = np.zeros(end_epoch)
40
+ lr_t = base_lr
41
+ for i in range(len(lr_epoch_decay)):
42
+ if i==0:
43
+ lr_list[0:lr_epoch_decay[i]-1]=lr_t
44
+ else:
45
+ lr_list[lr_epoch_decay[i-1]-1:lr_epoch_decay[i]-1]=lr_t
46
+ lr_t*=0.1
47
+
48
+
49
+ if not os.path.exists(model_store_path):
50
+ os.makedirs(model_store_path)
51
+
52
+ lossfn = LossFn()
53
+ net = PNet(is_train=True, use_cuda=use_cuda)
54
+ if load!='':
55
+ net.load_state_dict(torch.load(load))
56
+ print('model loaded',load)
57
+ net.train()
58
+
59
+ if use_cuda:
60
+ net.cuda()
61
+
62
+
63
+ optimizer = torch.optim.Adam(net.parameters(), lr=lr_list[0])
64
+ #optimizer = torch.optim.SGD(net.parameters(), lr=lr_list[0])
65
+
66
+ train_data=TrainImageReader(imdb,12,batch_size,shuffle=True)
67
+
68
+ #frequent = 10
69
+ for cur_epoch in range(1,end_epoch+1):
70
+ train_data.reset() # shuffle
71
+ for param in optimizer.param_groups:
72
+ param['lr'] = lr_list[cur_epoch-1]
73
+ for batch_idx,(image,(gt_label,gt_bbox,gt_landmark))in enumerate(train_data):
74
+
75
+ im_tensor = [ convert_image_to_tensor(image[i,:,:,:]) for i in range(image.shape[0]) ]
76
+ im_tensor = torch.stack(im_tensor)
77
+
78
+ im_tensor.requires_grad = True
79
+ gt_label = torch.from_numpy(gt_label).float()
80
+ gt_label.requires_grad = True
81
+
82
+ gt_bbox = torch.from_numpy(gt_bbox).float()
83
+ gt_bbox.requires_grad = True
84
+ # gt_landmark = Variable(torch.from_numpy(gt_landmark).float())
85
+
86
+ if use_cuda:
87
+ im_tensor = im_tensor.cuda()
88
+ gt_label = gt_label.cuda()
89
+ gt_bbox = gt_bbox.cuda()
90
+ # gt_landmark = gt_landmark.cuda()
91
+
92
+ cls_pred, box_offset_pred = net(im_tensor)
93
+ # all_loss, cls_loss, offset_loss = lossfn.loss(gt_label=label_y,gt_offset=bbox_y, pred_label=cls_pred, pred_offset=box_offset_pred)
94
+
95
+ cls_loss = lossfn.cls_loss(gt_label,cls_pred)
96
+ box_offset_loss = lossfn.box_loss(gt_label,gt_bbox,box_offset_pred)
97
+ # landmark_loss = lossfn.landmark_loss(gt_label,gt_landmark,landmark_offset_pred)
98
+
99
+ all_loss = cls_loss*1.0+box_offset_loss*0.5
100
+
101
+ if batch_idx %frequent==0:
102
+ accuracy=compute_accuracy(cls_pred,gt_label)
103
+
104
+ show1 = accuracy.data.cpu().numpy()
105
+ show2 = cls_loss.data.cpu().numpy()
106
+ show3 = box_offset_loss.data.cpu().numpy()
107
+ # show4 = landmark_loss.data.cpu().numpy()
108
+ show5 = all_loss.data.cpu().numpy()
109
+
110
+ print("%s : Epoch: %d, Step: %d, accuracy: %s, det loss: %s, bbox loss: %s, all_loss: %s, lr:%s "%(datetime.datetime.now(),cur_epoch,batch_idx, show1,show2,show3,show5,lr_list[cur_epoch-1]))
111
+
112
+ optimizer.zero_grad()
113
+ all_loss.backward()
114
+ optimizer.step()
115
+
116
+ torch.save(net.state_dict(), os.path.join(model_store_path,"pnet_epoch_%d.pt" % cur_epoch))
117
+ torch.save(net, os.path.join(model_store_path,"pnet_epoch_model_%d.pkl" % cur_epoch))
118
+
119
+
120
+
121
+
122
+ def train_rnet(model_store_path, end_epoch,imdb,
123
+ batch_size,frequent=50,base_lr=0.01,lr_epoch_decay=[9],use_cuda=True,load=''):
124
+
125
+ #create lr_list
126
+ lr_epoch_decay.append(end_epoch+1)
127
+ lr_list = np.zeros(end_epoch)
128
+ lr_t = base_lr
129
+ for i in range(len(lr_epoch_decay)):
130
+ if i==0:
131
+ lr_list[0:lr_epoch_decay[i]-1]=lr_t
132
+ else:
133
+ lr_list[lr_epoch_decay[i-1]-1:lr_epoch_decay[i]-1]=lr_t
134
+ lr_t*=0.1
135
+ #print(lr_list)
136
+ if not os.path.exists(model_store_path):
137
+ os.makedirs(model_store_path)
138
+
139
+ lossfn = LossFn()
140
+ net = RNet(is_train=True, use_cuda=use_cuda)
141
+ net.train()
142
+ if load!='':
143
+ net.load_state_dict(torch.load(load))
144
+ print('model loaded',load)
145
+ if use_cuda:
146
+ net.cuda()
147
+
148
+
149
+ optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)
150
+
151
+ train_data=TrainImageReader(imdb,24,batch_size,shuffle=True)
152
+
153
+
154
+ for cur_epoch in range(1,end_epoch+1):
155
+ train_data.reset()
156
+ for param in optimizer.param_groups:
157
+ param['lr'] = lr_list[cur_epoch-1]
158
+
159
+ for batch_idx,(image,(gt_label,gt_bbox,gt_landmark))in enumerate(train_data):
160
+
161
+ im_tensor = [ convert_image_to_tensor(image[i,:,:,:]) for i in range(image.shape[0]) ]
162
+ im_tensor = torch.stack(im_tensor)
163
+
164
+ im_tensor.requires_grad = True
165
+ gt_label = torch.from_numpy(gt_label).float()
166
+ gt_label.requires_grad = True
167
+
168
+ gt_bbox = torch.from_numpy(gt_bbox).float()
169
+ gt_bbox.requires_grad = True
170
+ gt_landmark = torch.from_numpy(gt_landmark).float()
171
+ gt_landmark.requires_grad = True
172
+
173
+ if use_cuda:
174
+ im_tensor = im_tensor.cuda()
175
+ gt_label = gt_label.cuda()
176
+ gt_bbox = gt_bbox.cuda()
177
+ gt_landmark = gt_landmark.cuda()
178
+
179
+ cls_pred, box_offset_pred = net(im_tensor)
180
+ # all_loss, cls_loss, offset_loss = lossfn.loss(gt_label=label_y,gt_offset=bbox_y, pred_label=cls_pred, pred_offset=box_offset_pred)
181
+
182
+ cls_loss = lossfn.cls_loss(gt_label,cls_pred)
183
+ box_offset_loss = lossfn.box_loss(gt_label,gt_bbox,box_offset_pred)
184
+ # landmark_loss = lossfn.landmark_loss(gt_label,gt_landmark,landmark_offset_pred)
185
+
186
+ all_loss = cls_loss*1.0+box_offset_loss*0.5
187
+
188
+ if batch_idx%frequent==0:
189
+ accuracy=compute_accuracy(cls_pred,gt_label)
190
+
191
+ show1 = accuracy.data.cpu().numpy()
192
+ show2 = cls_loss.data.cpu().numpy()
193
+ show3 = box_offset_loss.data.cpu().numpy()
194
+ # show4 = landmark_loss.data.cpu().numpy()
195
+ show5 = all_loss.data.cpu().numpy()
196
+
197
+ print("%s : Epoch: %d, Step: %d, accuracy: %s, det loss: %s, bbox loss: %s, all_loss: %s, lr:%s "%(datetime.datetime.now(), cur_epoch, batch_idx, show1, show2, show3, show5, lr_list[cur_epoch-1]))
198
+
199
+ optimizer.zero_grad()
200
+ all_loss.backward()
201
+ optimizer.step()
202
+
203
+ torch.save(net.state_dict(), os.path.join(model_store_path,"rnet_epoch_%d.pt" % cur_epoch))
204
+ torch.save(net, os.path.join(model_store_path,"rnet_epoch_model_%d.pkl" % cur_epoch))
205
+
206
+
207
+ def train_onet(model_store_path, end_epoch,imdb,
208
+ batch_size,frequent=50,base_lr=0.01,lr_epoch_decay=[9],use_cuda=True,load=''):
209
+ #create lr_list
210
+ lr_epoch_decay.append(end_epoch+1)
211
+ lr_list = np.zeros(end_epoch)
212
+ lr_t = base_lr
213
+ for i in range(len(lr_epoch_decay)):
214
+ if i==0:
215
+ lr_list[0:lr_epoch_decay[i]-1]=lr_t
216
+ else:
217
+ lr_list[lr_epoch_decay[i-1]-1:lr_epoch_decay[i]-1]=lr_t
218
+ lr_t*=0.1
219
+ #print(lr_list)
220
+
221
+ if not os.path.exists(model_store_path):
222
+ os.makedirs(model_store_path)
223
+
224
+ lossfn = LossFn()
225
+ net = ONet(is_train=True)
226
+ if load!='':
227
+ net.load_state_dict(torch.load(load))
228
+ print('model loaded',load)
229
+ net.train()
230
+ #print(use_cuda)
231
+ if use_cuda:
232
+ net.cuda()
233
+
234
+
235
+ optimizer = torch.optim.Adam(net.parameters(), lr=base_lr)
236
+
237
+ train_data=TrainImageReader(imdb,48,batch_size,shuffle=True)
238
+
239
+
240
+ for cur_epoch in range(1,end_epoch+1):
241
+
242
+ train_data.reset()
243
+ for param in optimizer.param_groups:
244
+ param['lr'] = lr_list[cur_epoch-1]
245
+ for batch_idx,(image,(gt_label,gt_bbox,gt_landmark))in enumerate(train_data):
246
+ # print("batch id {0}".format(batch_idx))
247
+ im_tensor = [ convert_image_to_tensor(image[i,:,:,:]) for i in range(image.shape[0]) ]
248
+ im_tensor = torch.stack(im_tensor)
249
+
250
+ im_tensor.requires_grad = True
251
+ gt_label = torch.from_numpy(gt_label).float()
252
+ gt_label.requires_grad = True
253
+
254
+ gt_bbox = torch.from_numpy(gt_bbox).float()
255
+ gt_bbox.requires_grad = True
256
+ gt_landmark = torch.from_numpy(gt_landmark).float()
257
+ gt_landmark.requires_grad = True
258
+
259
+ if use_cuda:
260
+ im_tensor = im_tensor.cuda()
261
+ gt_label = gt_label.cuda()
262
+ gt_bbox = gt_bbox.cuda()
263
+ gt_landmark = gt_landmark.cuda()
264
+
265
+ cls_pred, box_offset_pred, landmark_offset_pred = net(im_tensor)
266
+
267
+ # all_loss, cls_loss, offset_loss = lossfn.loss(gt_label=label_y,gt_offset=bbox_y, pred_label=cls_pred, pred_offset=box_offset_pred)
268
+
269
+ cls_loss = lossfn.cls_loss(gt_label,cls_pred)
270
+ box_offset_loss = lossfn.box_loss(gt_label,gt_bbox,box_offset_pred)
271
+ landmark_loss = lossfn.landmark_loss(gt_label,gt_landmark,landmark_offset_pred)
272
+
273
+ all_loss = cls_loss*0.8+box_offset_loss*0.6+landmark_loss*1.5
274
+
275
+ if batch_idx%frequent==0:
276
+ accuracy=compute_accuracy(cls_pred,gt_label)
277
+
278
+ show1 = accuracy.data.cpu().numpy()
279
+ show2 = cls_loss.data.cpu().numpy()
280
+ show3 = box_offset_loss.data.cpu().numpy()
281
+ show4 = landmark_loss.data.cpu().numpy()
282
+ show5 = all_loss.data.cpu().numpy()
283
+
284
+ print("%s : Epoch: %d, Step: %d, accuracy: %s, det loss: %s, bbox loss: %s, landmark loss: %s, all_loss: %s, lr:%s "%(datetime.datetime.now(),cur_epoch,batch_idx, show1,show2,show3,show4,show5,base_lr))
285
+ #print("%s : Epoch: %d, Step: %d, accuracy: %s, det loss: %s, bbox loss: %s, all_loss: %s, lr:%s "%(datetime.datetime.now(),cur_epoch,batch_idx, show1,show2,show3,show5,lr_list[cur_epoch-1]))
286
+
287
+ optimizer.zero_grad()
288
+ all_loss.backward()
289
+ optimizer.step()
290
+
291
+ torch.save(net.state_dict(), os.path.join(model_store_path,"onet_epoch_%d.pt" % cur_epoch))
292
+ torch.save(net, os.path.join(model_store_path,"onet_epoch_model_%d.pkl" % cur_epoch))
293
+
294
+
295
+
296
+
297
+
298
+
299
+ def parse_args():
300
+ parser = argparse.ArgumentParser(description='Train MTCNN',
301
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
302
+
303
+ parser.add_argument('--net', dest='net', help='which net to train', type=str)
304
+
305
+ parser.add_argument('--anno_file', dest='annotation_file', help='training data annotation file', type=str)
306
+ parser.add_argument('--model_path', dest='model_store_path', help='training model store directory',
307
+ default=config.MODEL_STORE_DIR, type=str)
308
+ parser.add_argument('--end_epoch', dest='end_epoch', help='end epoch of training',
309
+ default=config.END_EPOCH, type=int)
310
+ parser.add_argument('--frequent', dest='frequent', help='frequency of logging',
311
+ default=200, type=int)
312
+ parser.add_argument('--lr', dest='lr', help='learning rate',
313
+ default=config.TRAIN_LR, type=float)
314
+ parser.add_argument('--batch_size', dest='batch_size', help='train batch size',
315
+ default=config.TRAIN_BATCH_SIZE, type=int)
316
+ parser.add_argument('--gpu', dest='use_cuda', help='train with gpu',
317
+ default=config.USE_CUDA, type=bool)
318
+ parser.add_argument('--load', dest='load', help='load model', default='', type=str)
319
+
320
+ args = parser.parse_args()
321
+ return args
322
+
323
+ def train_net(annotation_file, model_store_path,
324
+ end_epoch=16, frequent=200, lr=0.01,lr_epoch_decay=[9],
325
+ batch_size=128, use_cuda=False,load='',net='pnet'):
326
+ if net=='pnet':
327
+ annotation_file = os.path.join(config.ANNO_STORE_DIR,config.PNET_TRAIN_IMGLIST_FILENAME)
328
+ elif net=='rnet':
329
+ annotation_file = os.path.join(config.ANNO_STORE_DIR,config.RNET_TRAIN_IMGLIST_FILENAME)
330
+ elif net=='onet':
331
+ annotation_file = os.path.join(config.ANNO_STORE_DIR,config.ONET_TRAIN_IMGLIST_FILENAME)
332
+ imagedb = ImageDB(annotation_file)
333
+ gt_imdb = imagedb.load_imdb()
334
+ print('DATASIZE',len(gt_imdb))
335
+ gt_imdb = imagedb.append_flipped_images(gt_imdb)
336
+ print('FLIP DATASIZE',len(gt_imdb))
337
+ if net=="pnet":
338
+ print("Training Pnet:")
339
+ train_pnet(model_store_path=model_store_path, end_epoch=end_epoch, imdb=gt_imdb, batch_size=batch_size, frequent=frequent, base_lr=lr,lr_epoch_decay=lr_epoch_decay, use_cuda=use_cuda,load=load)
340
+ elif net=="rnet":
341
+ print("Training Rnet:")
342
+ train_rnet(model_store_path=model_store_path, end_epoch=end_epoch, imdb=gt_imdb, batch_size=batch_size, frequent=frequent, base_lr=lr,lr_epoch_decay=lr_epoch_decay, use_cuda=use_cuda,load=load)
343
+ elif net=="onet":
344
+ print("Training Onet:")
345
+ train_onet(model_store_path=model_store_path, end_epoch=end_epoch, imdb=gt_imdb, batch_size=batch_size, frequent=frequent, base_lr=lr,lr_epoch_decay=lr_epoch_decay, use_cuda=use_cuda,load=load)
346
+
347
+ if __name__ == '__main__':
348
+ args = parse_args()
349
+ lr_epoch_decay = [9]
350
+ train_net(annotation_file=args.annotation_file, model_store_path=args.model_store_path,
351
+ end_epoch=args.end_epoch, frequent=args.frequent, lr=args.lr,lr_epoch_decay=lr_epoch_decay,batch_size=args.batch_size, use_cuda=args.use_cuda,load=args.load,net=args.net)
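The stepwise learning-rate schedule used by the `train_*` functions above fills `lr_list` once and then assigns `lr_list[cur_epoch-1]` to the optimizer at the start of every epoch. A minimal standalone sketch of that schedule (the helper name `build_lr_list` is ours, not part of the repo):

```python
import numpy as np

def build_lr_list(base_lr, end_epoch, lr_epoch_decay):
    # Mirror train.py: hold base_lr until the first decay epoch,
    # then shrink the rate by 10x at every further decay epoch.
    decay = list(lr_epoch_decay) + [end_epoch + 1]
    lr_list = np.zeros(end_epoch)
    lr_t = base_lr
    for i in range(len(decay)):
        if i == 0:
            lr_list[0:decay[i] - 1] = lr_t
        else:
            lr_list[decay[i - 1] - 1:decay[i] - 1] = lr_t
        lr_t *= 0.1
    return lr_list

# 20 epochs with decay=[9]: epochs 1-8 train at 0.01, epochs 9-20 at 0.001
print(build_lr_list(0.01, 20, [9]))
```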
train.sh ADDED
@@ -0,0 +1,7 @@
1
+ python get_data.py --net=pnet
2
+ python train.py --net=pnet
3
+ python get_data.py --net=rnet --pnet_path=./model_store/pnet_epoch_20.pt
4
+ python train.py --net=rnet
5
+ python get_data.py --net=onet --pnet_path=./model_store/pnet_epoch_20.pt --rnet_path=./model_store/rnet_epoch_20.pt
6
+ python train.py --net=onet
7
+ echo "Training finished!"
utils/config.py ADDED
@@ -0,0 +1,42 @@
1
+ import os
2
+ '''Original hyper-parameters from the provided sample code'''
3
+
4
+ MODEL_STORE_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))+"/model_store"
5
+
6
+
7
+ ANNO_STORE_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))+"/anno_store"
8
+
9
+
10
+ LOG_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))+"/log"
11
+
12
+
13
+ USE_CUDA = True
14
+
15
+
16
+ TRAIN_BATCH_SIZE = 512
17
+
18
+ TRAIN_LR = 0.01
19
+
20
+ END_EPOCH = 20
21
+
22
+
23
+ PNET_POSTIVE_ANNO_FILENAME = "pos_12.txt"
24
+ PNET_NEGATIVE_ANNO_FILENAME = "neg_12.txt"
25
+ PNET_PART_ANNO_FILENAME = "part_12.txt"
26
+ PNET_LANDMARK_ANNO_FILENAME = "landmark_12.txt"
27
+
28
+
29
+ RNET_POSTIVE_ANNO_FILENAME = "pos_24.txt"
30
+ RNET_NEGATIVE_ANNO_FILENAME = "neg_24.txt"
31
+ RNET_PART_ANNO_FILENAME = "part_24.txt"
32
+ RNET_LANDMARK_ANNO_FILENAME = "landmark_24.txt"
33
+
34
+
35
+ ONET_POSTIVE_ANNO_FILENAME = "pos_48.txt"
36
+ ONET_NEGATIVE_ANNO_FILENAME = "neg_48.txt"
37
+ ONET_PART_ANNO_FILENAME = "part_48.txt"
38
+ ONET_LANDMARK_ANNO_FILENAME = "landmark_48.txt"
39
+
40
+ PNET_TRAIN_IMGLIST_FILENAME = "imglist_anno_12.txt"
41
+ RNET_TRAIN_IMGLIST_FILENAME = "imglist_anno_24.txt"
42
+ ONET_TRAIN_IMGLIST_FILENAME = "imglist_anno_48.txt"
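Every directory in `utils/config.py` is resolved relative to the repository root, so the scripts can be launched from any working directory. A small hedged check (assuming the package is importable as `utils.config`):

```python
# Sketch: print where the store directories and the PNet image list resolve to.
import os
from utils import config

for name in ("MODEL_STORE_DIR", "ANNO_STORE_DIR", "LOG_DIR"):
    print(name, "->", getattr(config, name))

# train.py --net=pnet will look for its training list here:
print(os.path.join(config.ANNO_STORE_DIR, config.PNET_TRAIN_IMGLIST_FILENAME))
```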
utils/dataloader.py ADDED
@@ -0,0 +1,347 @@
1
+ import torchvision.transforms as transforms
2
+ import numpy as np
3
+ import os
4
+ import cv2
5
+ def convert_image_to_tensor(image):
6
+ """convert an image to pytorch tensor
7
+
8
+ Parameters:
9
+ ----------
10
+ image: numpy array , h * w * c
11
+
12
+ Returns:
13
+ -------
14
+ image_tensor: pytorch.FloatTensor, c * h * w
15
+ """
16
+ transform = transforms.ToTensor()
17
+
18
+ return transform(image)
19
+
20
+
21
+ def convert_chwTensor_to_hwcNumpy(tensor):
22
+ """convert a group images pytorch tensor(count * c * h * w) to numpy array images(count * h * w * c)
23
+ Parameters:
24
+ ----------
25
+ tensor: numpy array , count * c * h * w
26
+
27
+ Returns:
28
+ -------
29
+ numpy array images: count * h * w * c
30
+ """
31
+ return np.transpose(tensor.detach().numpy(), (0,2,3,1))
32
+
33
+ class ImageDB(object):
34
+ def __init__(self, image_annotation_file, prefix_path='', mode='train'):
35
+ self.prefix_path = prefix_path
36
+ self.image_annotation_file = image_annotation_file
37
+ self.classes = ['__background__', 'face']
38
+ self.num_classes = 2
39
+ self.image_set_index = self.load_image_set_index()
40
+ self.num_images = len(self.image_set_index)
41
+ self.mode = mode
42
+
43
+
44
+ def load_image_set_index(self):
45
+ """Get image index
46
+
47
+ Parameters:
48
+ ----------
49
+ Returns:
50
+ -------
51
+ image_set_index: str
52
+ relative path of image
53
+ """
54
+ assert os.path.exists(self.image_annotation_file), 'Path does not exist: {}'.format(self.image_annotation_file)
55
+ with open(self.image_annotation_file, 'r') as f:
56
+ image_set_index = [x.strip().split(' ')[0] for x in f.readlines()]
57
+ return image_set_index
58
+
59
+
60
+ def load_imdb(self):
61
+ """Get and save ground truth image database
62
+
63
+ Parameters:
64
+ ----------
65
+ Returns:
66
+ -------
67
+ gt_imdb: dict
68
+ image database with annotations
69
+ """
70
+ gt_imdb = self.load_annotations()
71
+ return gt_imdb
72
+
73
+
74
+ def real_image_path(self, index):
75
+ """Given image index, return full path
76
+
77
+ Parameters:
78
+ ----------
79
+ index: str
80
+ relative path of image
81
+ Returns:
82
+ -------
83
+ image_file: str
84
+ full path of image
85
+ """
86
+
87
+ index = index.replace("\\", "/")
88
+
89
+ if not os.path.exists(index):
90
+ image_file = os.path.join(self.prefix_path, index)
91
+ else:
92
+ image_file=index
93
+ if not image_file.endswith('.jpg'):
94
+ image_file = image_file + '.jpg'
95
+ assert os.path.exists(image_file), 'Path does not exist: {}'.format(image_file)
96
+ return image_file
97
+
98
+
99
+ def load_annotations(self,annotion_type=1):
100
+ """Load annotations
101
+
102
+ Parameters:
103
+ ----------
104
+ annotion_type: int
105
+ annotation type selector (not used by this implementation)
107
+ Returns:
108
+ -------
109
+ imdb: dict
110
+ image database with annotations
111
+ """
112
+
113
+ assert os.path.exists(self.image_annotation_file), 'annotations not found at {}'.format(self.image_annotation_file)
114
+ with open(self.image_annotation_file, 'r') as f:
115
+ annotations = f.readlines()
116
+
117
+ imdb = []
118
+ for i in range(self.num_images):
119
+ annotation = annotations[i].strip().split(' ')
120
+ index = annotation[0]
121
+ im_path = self.real_image_path(index)
122
+ imdb_ = dict()
123
+ imdb_['image'] = im_path
124
+
125
+ if self.mode == 'test':
126
+ pass
127
+ else:
128
+ label = annotation[1]
129
+ imdb_['label'] = int(label)
130
+ imdb_['flipped'] = False
131
+ imdb_['bbox_target'] = np.zeros((4,))
132
+ imdb_['landmark_target'] = np.zeros((10,))
133
+ if len(annotation[2:])==4:
134
+ bbox_target = annotation[2:6]
135
+ imdb_['bbox_target'] = np.array(bbox_target).astype(float)
136
+ if len(annotation[2:])==14:
137
+ bbox_target = annotation[2:6]
138
+ imdb_['bbox_target'] = np.array(bbox_target).astype(float)
139
+ landmark = annotation[6:]
140
+ imdb_['landmark_target'] = np.array(landmark).astype(float)
141
+ imdb.append(imdb_)
142
+
143
+ return imdb
144
+
145
+
146
+ def append_flipped_images(self, imdb):
147
+ """append flipped images to imdb
148
+
149
+ Parameters:
150
+ ----------
151
+ imdb: imdb
152
+ image database
153
+ Returns:
154
+ -------
155
+ imdb: dict
156
+ image database with flipped image annotations added
157
+ """
158
+ print('append flipped images to imdb', len(imdb))
159
+ for i in range(len(imdb)):
160
+ imdb_ = imdb[i]
161
+ m_bbox = imdb_['bbox_target'].copy()
162
+ m_bbox[0], m_bbox[2] = -m_bbox[2], -m_bbox[0]
163
+
164
+ landmark_ = imdb_['landmark_target'].copy()
165
+ landmark_ = landmark_.reshape((5, 2))
166
+ landmark_ = np.asarray([(1 - x, y) for (x, y) in landmark_])
167
+ landmark_[[0, 1]] = landmark_[[1, 0]]
168
+ landmark_[[3, 4]] = landmark_[[4, 3]]
169
+
170
+ item = {'image': imdb_['image'],
171
+ 'label': imdb_['label'],
172
+ 'bbox_target': m_bbox,
173
+ 'landmark_target': landmark_.reshape((10)),
174
+ 'flipped': True}
175
+
176
+ imdb.append(item)
177
+ self.image_set_index *= 2
178
+ return imdb
179
+
180
+
181
+
182
+
183
+
184
+ class TrainImageReader:
185
+ def __init__(self, imdb, im_size, batch_size=128, shuffle=False):
186
+
187
+ self.imdb = imdb
188
+ self.batch_size = batch_size
189
+ self.im_size = im_size
190
+ self.shuffle = shuffle
191
+
192
+ self.cur = 0
193
+ self.size = len(imdb)
194
+ self.index = np.arange(self.size)
195
+ self.num_classes = 2
196
+
197
+ self.batch = None
198
+ self.data = None
199
+ self.label = None
200
+
201
+ self.label_names= ['label', 'bbox_target', 'landmark_target']
202
+ self.reset()
203
+ self.get_batch()
204
+
205
+ def reset(self):
206
+ self.cur = 0
207
+ if self.shuffle:
208
+ np.random.shuffle(self.index)
209
+
210
+ def iter_next(self):
211
+ return self.cur + self.batch_size <= self.size
212
+
213
+ def __iter__(self):
214
+ return self
215
+
216
+ def __next__(self):
217
+ return self.next()
218
+
219
+ def next(self):
220
+ if self.iter_next():
221
+ self.get_batch()
222
+ self.cur += self.batch_size
223
+ return self.data,self.label
224
+ else:
225
+ raise StopIteration
226
+
227
+ def getindex(self):
228
+ return self.cur / self.batch_size
229
+
230
+ def getpad(self):
231
+ if self.cur + self.batch_size > self.size:
232
+ return self.cur + self.batch_size - self.size
233
+ else:
234
+ return 0
235
+
236
+ def get_batch(self):
237
+ cur_from = self.cur
238
+ cur_to = min(cur_from + self.batch_size, self.size)
239
+ imdb = [self.imdb[self.index[i]] for i in range(cur_from, cur_to)]
240
+ data, label = get_minibatch(imdb)
241
+ self.data = data['data']
242
+ self.label = [label[name] for name in self.label_names]
243
+
244
+
245
+
246
+ class TestImageLoader:
247
+ def __init__(self, imdb, batch_size=1, shuffle=False):
248
+ self.imdb = imdb
249
+ self.batch_size = batch_size
250
+ self.shuffle = shuffle
251
+ self.size = len(imdb)
252
+ self.index = np.arange(self.size)
253
+
254
+ self.cur = 0
255
+ self.data = None
256
+ self.label = None
257
+
258
+ self.reset()
259
+ self.get_batch()
260
+
261
+ def reset(self):
262
+ self.cur = 0
263
+ if self.shuffle:
264
+ np.random.shuffle(self.index)
265
+
266
+ def iter_next(self):
267
+ return self.cur + self.batch_size <= self.size
268
+
269
+ def __iter__(self):
270
+ return self
271
+
272
+ def __next__(self):
273
+ return self.next()
274
+
275
+ def next(self):
276
+ if self.iter_next():
277
+ self.get_batch()
278
+ self.cur += self.batch_size
279
+ return self.data
280
+ else:
281
+ raise StopIteration
282
+
283
+ def getindex(self):
284
+ return self.cur / self.batch_size
285
+
286
+ def getpad(self):
287
+ if self.cur + self.batch_size > self.size:
288
+ return self.cur + self.batch_size - self.size
289
+ else:
290
+ return 0
291
+
292
+ def get_batch(self):
293
+ cur_from = self.cur
294
+ cur_to = min(cur_from + self.batch_size, self.size)
295
+ imdb = [self.imdb[self.index[i]] for i in range(cur_from, cur_to)]
296
+ data= get_testbatch(imdb)
297
+ self.data=data['data']
298
+
299
+
300
+
301
+
302
+ def get_minibatch(imdb):
303
+
304
+ # im_size: 12, 24 or 48
305
+ num_images = len(imdb)
306
+ processed_ims = list()
307
+ cls_label = list()
308
+ bbox_reg_target = list()
309
+ landmark_reg_target = list()
310
+
311
+ for i in range(num_images):
312
+ im = cv2.imread(imdb[i]['image'])
313
+
314
+ if imdb[i]['flipped']:
315
+ im = im[:, ::-1, :]
316
+
317
+ cls = imdb[i]['label']
318
+ bbox_target = imdb[i]['bbox_target']
319
+ landmark = imdb[i]['landmark_target']
320
+
321
+ processed_ims.append(im)
322
+ cls_label.append(cls)
323
+ bbox_reg_target.append(bbox_target)
324
+ landmark_reg_target.append(landmark)
325
+
326
+ im_array = np.asarray(processed_ims)
327
+
328
+ label_array = np.array(cls_label)
329
+
330
+ bbox_target_array = np.vstack(bbox_reg_target)
331
+
332
+ landmark_target_array = np.vstack(landmark_reg_target)
333
+
334
+ data = {'data': im_array}
335
+ label = {'label': label_array,
336
+ 'bbox_target': bbox_target_array,
337
+ 'landmark_target': landmark_target_array
338
+ }
339
+
340
+ return data, label
341
+
342
+
343
+ def get_testbatch(imdb):
344
+ assert len(imdb) == 1, "Single batch only"
345
+ im = cv2.imread(imdb[0]['image'])
346
+ data = {'data': im}
347
+ return data
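A hedged usage sketch of `ImageDB` and `TrainImageReader`, mirroring how `train.py` builds its input pipeline (the annotation path below is an assumption; it is what `get_data.py --net=pnet` is expected to produce):

```python
from utils.dataloader import ImageDB, TrainImageReader

imagedb = ImageDB("./anno_store/imglist_anno_12.txt")   # PNet training list (assumed path)
imdb = imagedb.load_imdb()
imdb = imagedb.append_flipped_images(imdb)              # doubles the dataset

reader = TrainImageReader(imdb, 12, batch_size=128, shuffle=True)
reader.reset()
for images, (labels, bboxes, landmarks) in reader:
    # images: (batch, 12, 12, 3); labels: (batch,); bboxes: (batch, 4); landmarks: (batch, 10)
    print(images.shape, labels.shape, bboxes.shape, landmarks.shape)
    break
```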
utils/detect.py ADDED
@@ -0,0 +1,758 @@
1
+ import cv2
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+ from utils.models import PNet,RNet,ONet
6
+ import utils.tool as utils
7
+ import utils.dataloader as image_tools
8
+
9
+
10
+ def create_mtcnn_net(p_model_path=None, r_model_path=None, o_model_path=None, use_cuda=True):
11
+
12
+ pnet, rnet, onet = None, None, None
13
+
14
+ if p_model_path is not None:
15
+ pnet = PNet(use_cuda=use_cuda)
16
+ if(use_cuda):
17
+ print('p_model_path:{0}'.format(p_model_path))
18
+ pnet.load_state_dict(torch.load(p_model_path))
19
+ pnet.cuda()
20
+ else:
21
+ # forcing all GPU tensors to be in CPU while loading
22
+ #pnet.load_state_dict(torch.load(p_model_path, map_location=lambda storage, loc: storage))
23
+ pnet.load_state_dict(torch.load(p_model_path, map_location='cpu'))
24
+ pnet.eval()
25
+
26
+ if r_model_path is not None:
27
+ rnet = RNet(use_cuda=use_cuda)
28
+ if (use_cuda):
29
+ print('r_model_path:{0}'.format(r_model_path))
30
+ rnet.load_state_dict(torch.load(r_model_path))
31
+ rnet.cuda()
32
+ else:
33
+ rnet.load_state_dict(torch.load(r_model_path, map_location=lambda storage, loc: storage))
34
+ rnet.eval()
35
+
36
+ if o_model_path is not None:
37
+ onet = ONet(use_cuda=use_cuda)
38
+ if (use_cuda):
39
+ print('o_model_path:{0}'.format(o_model_path))
40
+ onet.load_state_dict(torch.load(o_model_path))
41
+ onet.cuda()
42
+ else:
43
+ onet.load_state_dict(torch.load(o_model_path, map_location=lambda storage, loc: storage))
44
+ onet.eval()
45
+
46
+ return pnet,rnet,onet
47
+
48
+
49
+
50
+
51
+ class MtcnnDetector(object):
52
+ """
53
+ P,R,O net face detection and landmarks align
54
+ """
55
+ def __init__(self,
56
+ pnet = None,
57
+ rnet = None,
58
+ onet = None,
59
+ min_face_size=12,
60
+ stride=2,
61
+ threshold=[0.6, 0.7, 0.7],
62
+ #threshold=[0.1, 0.1, 0.1],
63
+ scale_factor=0.709,
64
+ ):
65
+
66
+ self.pnet_detector = pnet
67
+ self.rnet_detector = rnet
68
+ self.onet_detector = onet
69
+ self.min_face_size = min_face_size
70
+ self.stride=stride
71
+ self.thresh = threshold
72
+ self.scale_factor = scale_factor
73
+
74
+
75
+ def unique_image_format(self,im):
76
+ if not isinstance(im,np.ndarray):
77
+ if im.mode == 'I':
78
+ im = np.array(im, np.int32, copy=False)
79
+ elif im.mode == 'I;16':
80
+ im = np.array(im, np.int16, copy=False)
81
+ else:
82
+ im = np.asarray(im)
83
+ return im
84
+
85
+ def square_bbox(self, bbox):
86
+ """
87
+ convert bbox to square
88
+ Parameters:
89
+ ----------
90
+ bbox: numpy array , shape n x m
91
+ input bbox
92
+ Returns:
93
+ -------
94
+ a square bbox
95
+ """
96
+ square_bbox = bbox.copy()
97
+
98
+ # x2 - x1
99
+ # y2 - y1
100
+ h = bbox[:, 3] - bbox[:, 1] + 1
101
+ w = bbox[:, 2] - bbox[:, 0] + 1
102
+ l = np.maximum(h,w)
103
+ # x1 = x1 + w*0.5 - l*0.5
104
+ # y1 = y1 + h*0.5 - l*0.5
105
+ square_bbox[:, 0] = bbox[:, 0] + w*0.5 - l*0.5
106
+ square_bbox[:, 1] = bbox[:, 1] + h*0.5 - l*0.5
107
+
108
+ # x2 = x1 + l - 1
109
+ # y2 = y1 + l - 1
110
+ square_bbox[:, 2] = square_bbox[:, 0] + l - 1
111
+ square_bbox[:, 3] = square_bbox[:, 1] + l - 1
112
+ return square_bbox
113
+
114
+
115
+ def generate_bounding_box(self, map, reg, scale, threshold):
116
+ """
117
+ generate bbox from feature map
118
+ Parameters:
119
+ ----------
120
+ map: numpy array , n x m x 1
121
+ detect score for each position
122
+ reg: numpy array , n x m x 4
123
+ bbox
124
+ scale: float number
125
+ scale of this detection
126
+ threshold: float number
127
+ detect threshold
128
+ Returns:
129
+ -------
130
+ bbox array
131
+ """
132
+ stride = 2
133
+ cellsize = 12 # receptive field
134
+
135
+ t_index = np.where(map[:,:,0] > threshold)
136
+ # print('shape of t_index:{0}'.format(len(t_index)))
137
+ # print('t_index{0}'.format(t_index))
138
+ # time.sleep(5)
139
+
140
+ # find nothing
141
+ if t_index[0].size == 0:
142
+ return np.array([])
143
+
144
+ # reg = (1, n, m, 4)
145
+ # choose bounding boxes whose scores are larger than the threshold
146
+ dx1, dy1, dx2, dy2 = [reg[0, t_index[0], t_index[1], i] for i in range(4)]
147
+ #print(dx1.shape)
148
+ #exit()
149
+ # time.sleep(5)
150
+ reg = np.array([dx1, dy1, dx2, dy2])
151
+ #print('shape of reg{0}'.format(reg.shape))
152
+ #exit()
153
+
154
+ # lefteye_dx, lefteye_dy, righteye_dx, righteye_dy, nose_dx, nose_dy, \
155
+ # leftmouth_dx, leftmouth_dy, rightmouth_dx, rightmouth_dy = [landmarks[0, t_index[0], t_index[1], i] for i in range(10)]
156
+ #
157
+ # landmarks = np.array([lefteye_dx, lefteye_dy, righteye_dx, righteye_dy, nose_dx, nose_dy, leftmouth_dx, leftmouth_dy, rightmouth_dx, rightmouth_dy])
158
+
159
+ # obtain the classification scores that are larger than the threshold
160
+ # t_index[0]: choose the first column of t_index
161
+ # t_index[1]: choose the second column of t_index
162
+ score = map[t_index[0], t_index[1], 0]
163
+ # hence t_index[1] means column, t_index[1] is the value of x
164
+ # hence t_index[0] means row, t_index[0] is the value of y
165
+ boundingbox = np.vstack([np.round((stride * t_index[1]) / scale), # x1 of prediction box in original image
166
+ np.round((stride * t_index[0]) / scale), # y1 of prediction box in original image
167
+ np.round((stride * t_index[1] + cellsize) / scale), # x2 of prediction box in original image
168
+ np.round((stride * t_index[0] + cellsize) / scale), # y2 of prediction box in original image
169
+ # reconstruct the box in original image
170
+ score,
171
+ reg,
172
+ # landmarks
173
+ ])
174
+
175
+ return boundingbox.T
176
+
177
+
178
+ def resize_image(self, img, scale):
179
+ """
180
+ resize image by the given scale factor
181
+ Parameters:
182
+ ----------
183
+ img: numpy array , height x width x channel
184
+ input image, channels in BGR order here
185
+ scale: float number
186
+ scale factor of resize operation
187
+ Returns:
188
+ -------
189
+ resized image, numpy array, new_height x new_width x channel
190
+ """
191
+ height, width, channels = img.shape
192
+ new_height = int(height * scale) # resized new height
193
+ new_width = int(width * scale) # resized new width
194
+ new_dim = (new_width, new_height)
195
+ img_resized = cv2.resize(img, new_dim, interpolation=cv2.INTER_LINEAR) # resized image
196
+ return img_resized
197
+
198
+
199
+ def pad(self, bboxes, w, h):
200
+ """
201
+ pad the boxes
202
+ Parameters:
203
+ ----------
204
+ bboxes: numpy array, n x 5
205
+ input bboxes
206
+ w: float number
207
+ width of the input image
208
+ h: float number
209
+ height of the input image
210
+ Returns :
211
+ ------
212
+ dy, dx : numpy array, n x 1
213
+ start point of the bbox in target image
214
+ edy, edx : numpy array, n x 1
215
+ end point of the bbox in target image
216
+ y, x : numpy array, n x 1
217
+ start point of the bbox in original image
218
+ ex, ey : numpy array, n x 1
219
+ end point of the bbox in original image
220
+ tmph, tmpw: numpy array, n x 1
221
+ height and width of the bbox
222
+ """
223
+ # width and height
224
+ tmpw = (bboxes[:, 2] - bboxes[:, 0] + 1).astype(np.int32)
225
+ tmph = (bboxes[:, 3] - bboxes[:, 1] + 1).astype(np.int32)
226
+ numbox = bboxes.shape[0]
227
+
228
+ dx = np.zeros((numbox, ))
229
+ dy = np.zeros((numbox, ))
230
+ edx, edy = tmpw.copy()-1, tmph.copy()-1
231
+ # x, y: start point of the bbox in original image
232
+ # ex, ey: end point of the bbox in original image
233
+ x, y, ex, ey = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
234
+
235
+ tmp_index = np.where(ex > w-1)
236
+ edx[tmp_index] = tmpw[tmp_index] + w - 2 - ex[tmp_index]
237
+ ex[tmp_index] = w - 1
238
+
239
+ tmp_index = np.where(ey > h-1)
240
+ edy[tmp_index] = tmph[tmp_index] + h - 2 - ey[tmp_index]
241
+ ey[tmp_index] = h - 1
242
+
243
+ tmp_index = np.where(x < 0)
244
+ dx[tmp_index] = 0 - x[tmp_index]
245
+ x[tmp_index] = 0
246
+
247
+ tmp_index = np.where(y < 0)
248
+ dy[tmp_index] = 0 - y[tmp_index]
249
+ y[tmp_index] = 0
250
+
251
+ return_list = [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph]
252
+ return_list = [item.astype(np.int32) for item in return_list]
253
+
254
+ return return_list
255
+
256
+
257
+ def detect_pnet(self, im):
258
+ """Get face candidates through pnet
259
+
260
+ Parameters:
261
+ ----------
262
+ im: numpy array
263
+ input image array
264
+ one batch
265
+
266
+ Returns:
267
+ -------
268
+ boxes: numpy array
269
+ detected boxes before calibration
270
+ boxes_align: numpy array
271
+ boxes after calibration
272
+ """
273
+
274
+ # im = self.unique_image_format(im)
275
+
276
+ # original wider face data
277
+ h, w, c = im.shape
278
+ net_size = 12
279
+
280
+ current_scale = float(net_size) / self.min_face_size # find initial scale
281
+ #print('imgshape:{0}, current_scale:{1}'.format(im.shape, current_scale))
282
+ im_resized = self.resize_image(im, current_scale) # scale = 1.0
283
+ current_height, current_width, _ = im_resized.shape
284
+ # fcn
285
+ all_boxes = list()
286
+ while min(current_height, current_width) > net_size:
287
+ #print('current:',current_height, current_width)
288
+ feed_imgs = []
289
+ image_tensor = image_tools.convert_image_to_tensor(im_resized)
290
+ feed_imgs.append(image_tensor)
291
+ feed_imgs = torch.stack(feed_imgs)
292
+
293
+ feed_imgs.requires_grad = True
294
+
295
+ if self.pnet_detector.use_cuda:
296
+ feed_imgs = feed_imgs.cuda()
297
+
298
+ # self.pnet_detector is a trained pnet torch model
299
+
300
+ # receptive field is 12×12
301
+ # 12×12 --> score
302
+ # 12×12 --> bounding box
303
+ cls_map, reg = self.pnet_detector(feed_imgs)
304
+
305
+ cls_map_np = image_tools.convert_chwTensor_to_hwcNumpy(cls_map.cpu())
306
+ reg_np = image_tools.convert_chwTensor_to_hwcNumpy(reg.cpu())
307
+ # print(cls_map_np.shape, reg_np.shape) # cls_map_np = (1, n, m, 1) reg_np.shape = (1, n, m 4)
308
+ # time.sleep(5)
309
+ # landmark_np = image_tools.convert_chwTensor_to_hwcNumpy(landmark.cpu())
310
+
311
+ # self.threshold[0] = 0.6
312
+ # print(cls_map_np[0,:,:].shape)
313
+ # time.sleep(4)
314
+
315
+ # boxes = [x1, y1, x2, y2, score, reg]
316
+ boxes = self.generate_bounding_box(cls_map_np[ 0, :, :], reg_np, current_scale, self.thresh[0])
317
+ #cv2.rectangle(im,(300,100),(400,200),color=(0,0,0))
318
+ #cv2.rectangle(im,(400,200),(500,300),color=(0,0,0))
319
+
320
+ # generate pyramid images
321
+ current_scale *= self.scale_factor # self.scale_factor = 0.709
322
+ im_resized = self.resize_image(im, current_scale)
323
+ current_height, current_width, _ = im_resized.shape
324
+
325
+ if boxes.size == 0:
326
+ continue
327
+
328
+ # non-maximum suppression
329
+ keep = utils.nms(boxes[:, :5], 0.5, 'Union')
330
+ boxes = boxes[keep]
331
+ all_boxes.append(boxes)
332
+
333
+ """ img = im.copy()
334
+ bw = boxes[:,2]-boxes[:,0]
335
+ bh = boxes[:,3]-boxes[:,1]
336
+ for i in range(boxes.shape[0]):
337
+ p1=(int(boxes[i][0]+boxes[i][5]*bw[i]),int(boxes[i][1]+boxes[i][6]*bh[i]))
338
+ p2=(int(boxes[i][2]+boxes[i][7]*bw[i]),int(boxes[i][3]+boxes[i][8]*bh[i]))
339
+ cv2.rectangle(img,p1,p2,color=(0,0,0))
340
+ cv2.imshow('ss',img)
341
+ cv2.waitKey(0)
342
+ #ii+=1
343
+ exit() """
344
+
345
+ if len(all_boxes) == 0:
346
+ return None, None
347
+ all_boxes = np.vstack(all_boxes)
348
+ # print("shape of all boxes {0}".format(all_boxes.shape))
349
+ # time.sleep(5)
350
+
351
+ # merge the detection from first stage
352
+ keep = utils.nms(all_boxes[:, 0:5], 0.7, 'Union')
353
+ all_boxes = all_boxes[keep]
354
+ # boxes = all_boxes[:, :5]
355
+
356
+ # x2 - x1
357
+ # y2 - y1
358
+ bw = all_boxes[:, 2] - all_boxes[:, 0] + 1
359
+ bh = all_boxes[:, 3] - all_boxes[:, 1] + 1
360
+
361
+ # landmark_keep = all_boxes[:, 9:].reshape((5,2))
362
+
363
+
364
+ boxes = np.vstack([all_boxes[:,0],
365
+ all_boxes[:,1],
366
+ all_boxes[:,2],
367
+ all_boxes[:,3],
368
+ all_boxes[:,4],
369
+ # all_boxes[:, 0] + all_boxes[:, 9] * bw,
370
+ # all_boxes[:, 1] + all_boxes[:,10] * bh,
371
+ # all_boxes[:, 0] + all_boxes[:, 11] * bw,
372
+ # all_boxes[:, 1] + all_boxes[:, 12] * bh,
373
+ # all_boxes[:, 0] + all_boxes[:, 13] * bw,
374
+ # all_boxes[:, 1] + all_boxes[:, 14] * bh,
375
+ # all_boxes[:, 0] + all_boxes[:, 15] * bw,
376
+ # all_boxes[:, 1] + all_boxes[:, 16] * bh,
377
+ # all_boxes[:, 0] + all_boxes[:, 17] * bw,
378
+ # all_boxes[:, 1] + all_boxes[:, 18] * bh
379
+ ])
380
+
381
+ boxes = boxes.T
382
+
383
+ # boxes = boxes = [x1, y1, x2, y2, score, reg] reg= [px1, py1, px2, py2] (in prediction)
384
+ align_topx = all_boxes[:, 0] + all_boxes[:, 5] * bw
385
+ align_topy = all_boxes[:, 1] + all_boxes[:, 6] * bh
386
+ align_bottomx = all_boxes[:, 2] + all_boxes[:, 7] * bw
387
+ align_bottomy = all_boxes[:, 3] + all_boxes[:, 8] * bh
388
+
389
+ # refine the boxes
390
+ boxes_align = np.vstack([ align_topx,
391
+ align_topy,
392
+ align_bottomx,
393
+ align_bottomy,
394
+ all_boxes[:, 4],
395
+ # align_topx + all_boxes[:,9] * bw,
396
+ # align_topy + all_boxes[:,10] * bh,
397
+ # align_topx + all_boxes[:,11] * bw,
398
+ # align_topy + all_boxes[:,12] * bh,
399
+ # align_topx + all_boxes[:,13] * bw,
400
+ # align_topy + all_boxes[:,14] * bh,
401
+ # align_topx + all_boxes[:,15] * bw,
402
+ # align_topy + all_boxes[:,16] * bh,
403
+ # align_topx + all_boxes[:,17] * bw,
404
+ # align_topy + all_boxes[:,18] * bh,
405
+ ])
406
+ boxes_align = boxes_align.T
407
+
408
+ #remove invalid box
409
+ valindex = [True for _ in range(boxes_align.shape[0])]
410
+ for i in range(boxes_align.shape[0]):
411
+ if boxes_align[i][2]-boxes_align[i][0]<=3 or boxes_align[i][3]-boxes_align[i][1]<=3:
412
+ valindex[i]=False
413
+ #print('pnet has one smaller than 3')
414
+ else:
415
+ if boxes_align[i][2]<1 or boxes_align[i][0]>w-2 or boxes_align[i][3]<1 or boxes_align[i][1]>h-2:
416
+ valindex[i]=False
417
+ #print('pnet has one out')
418
+ boxes_align=boxes_align[valindex,:]
419
+ boxes = boxes[valindex,:]
420
+ return boxes, boxes_align
421
+
422
+ def detect_rnet(self, im, dets):
423
+ """Get face candidates using rnet
424
+
425
+ Parameters:
426
+ ----------
427
+ im: numpy array
428
+ input image array
429
+ dets: numpy array
430
+ detection results of pnet
431
+
432
+ Returns:
433
+ -------
434
+ boxes: numpy array
435
+ detected boxes before calibration
436
+ boxes_align: numpy array
437
+ boxes after calibration
438
+ """
439
+ # im: an input image
440
+ h, w, c = im.shape
441
+
442
+ if dets is None:
443
+ return None,None
444
+ if dets.shape[0]==0:
445
+ return None, None
446
+
447
+ # (705, 5) = [x1, y1, x2, y2, score, reg]
448
+ # print("pnet detection {0}".format(dets.shape))
449
+ # time.sleep(5)
450
+ detss = dets
451
+ # return square boxes
452
+ dets = self.square_bbox(dets)
453
+ detsss = dets
454
+ # rounds
455
+ dets[:, 0:4] = np.round(dets[:, 0:4])
456
+ [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(dets, w, h)
457
+ num_boxes = dets.shape[0]
458
+
459
+ '''
460
+ # helper for setting RNet batch size
461
+ batch_size = self.rnet_detector.batch_size
462
+ ratio = float(num_boxes) / batch_size
463
+ if ratio > 3 or ratio < 0.3:
464
+ print "You may need to reset RNet batch size if this info appears frequently, \
465
+ face candidates:%d, current batch_size:%d"%(num_boxes, batch_size)
466
+ '''
467
+
468
+ # cropped_ims_tensors = np.zeros((num_boxes, 3, 24, 24), dtype=np.float32)
469
+ cropped_ims_tensors = []
470
+ for i in range(num_boxes):
471
+ try:
472
+ tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8)
473
+ tmp[dy[i]:edy[i]+1, dx[i]:edx[i]+1, :] = im[y[i]:ey[i]+1, x[i]:ex[i]+1, :]
474
+ except:
475
+ print(dy[i],edy[i],dx[i],edx[i],y[i],ey[i],x[i],ex[i],tmpw[i],tmph[i])
476
+ print(dets[i])
477
+ print(detss[i])
478
+ print(detsss[i])
479
+ print(h,w)
480
+ exit()
481
+ crop_im = cv2.resize(tmp, (24, 24))
482
+ crop_im_tensor = image_tools.convert_image_to_tensor(crop_im)
483
+ # cropped_ims_tensors[i, :, :, :] = crop_im_tensor
484
+ cropped_ims_tensors.append(crop_im_tensor)
485
+ feed_imgs = torch.stack(cropped_ims_tensors)
486
+ feed_imgs.requires_grad = True
487
+
488
+
489
+ if self.rnet_detector.use_cuda:
490
+ feed_imgs = feed_imgs.cuda()
491
+
492
+ cls_map, reg = self.rnet_detector(feed_imgs)
493
+
494
+ cls_map = cls_map.cpu().data.numpy()
495
+ reg = reg.cpu().data.numpy()
496
+ # landmark = landmark.cpu().data.numpy()
497
+
498
+
499
+ keep_inds = np.where(cls_map > self.thresh[1])[0]
500
+
501
+ if len(keep_inds) > 0:
502
+ boxes = dets[keep_inds]
503
+ cls = cls_map[keep_inds]
504
+ reg = reg[keep_inds]
505
+ # landmark = landmark[keep_inds]
506
+ else:
507
+ return None, None
508
+ keep = utils.nms(boxes, 0.7)
509
+
510
+ if len(keep) == 0:
511
+ return None, None
512
+
513
+ keep_cls = cls[keep]
514
+ keep_boxes = boxes[keep]
515
+ keep_reg = reg[keep]
516
+ # keep_landmark = landmark[keep]
517
+
518
+
519
+ bw = keep_boxes[:, 2] - keep_boxes[:, 0] + 1
520
+ bh = keep_boxes[:, 3] - keep_boxes[:, 1] + 1
521
+
522
+
523
+ boxes = np.vstack([ keep_boxes[:,0],
524
+ keep_boxes[:,1],
525
+ keep_boxes[:,2],
526
+ keep_boxes[:,3],
527
+ keep_cls[:,0],
528
+ # keep_boxes[:,0] + keep_landmark[:, 0] * bw,
529
+ # keep_boxes[:,1] + keep_landmark[:, 1] * bh,
530
+ # keep_boxes[:,0] + keep_landmark[:, 2] * bw,
531
+ # keep_boxes[:,1] + keep_landmark[:, 3] * bh,
532
+ # keep_boxes[:,0] + keep_landmark[:, 4] * bw,
533
+ # keep_boxes[:,1] + keep_landmark[:, 5] * bh,
534
+ # keep_boxes[:,0] + keep_landmark[:, 6] * bw,
535
+ # keep_boxes[:,1] + keep_landmark[:, 7] * bh,
536
+ # keep_boxes[:,0] + keep_landmark[:, 8] * bw,
537
+ # keep_boxes[:,1] + keep_landmark[:, 9] * bh,
538
+ ])
539
+
540
+ align_topx = keep_boxes[:,0] + keep_reg[:,0] * bw
541
+ align_topy = keep_boxes[:,1] + keep_reg[:,1] * bh
542
+ align_bottomx = keep_boxes[:,2] + keep_reg[:,2] * bw
543
+ align_bottomy = keep_boxes[:,3] + keep_reg[:,3] * bh
544
+
545
+ boxes_align = np.vstack([align_topx,
546
+ align_topy,
547
+ align_bottomx,
548
+ align_bottomy,
549
+ keep_cls[:, 0],
550
+ # align_topx + keep_landmark[:, 0] * bw,
551
+ # align_topy + keep_landmark[:, 1] * bh,
552
+ # align_topx + keep_landmark[:, 2] * bw,
553
+ # align_topy + keep_landmark[:, 3] * bh,
554
+ # align_topx + keep_landmark[:, 4] * bw,
555
+ # align_topy + keep_landmark[:, 5] * bh,
556
+ # align_topx + keep_landmark[:, 6] * bw,
557
+ # align_topy + keep_landmark[:, 7] * bh,
558
+ # align_topx + keep_landmark[:, 8] * bw,
559
+ # align_topy + keep_landmark[:, 9] * bh,
560
+ ])
561
+
562
+ boxes = boxes.T
563
+ boxes_align = boxes_align.T
564
+
565
+ #remove invalid box
566
+ valindex = [True for _ in range(boxes_align.shape[0])]
567
+ for i in range(boxes_align.shape[0]):
568
+ if boxes_align[i][2]-boxes_align[i][0]<=3 or boxes_align[i][3]-boxes_align[i][1]<=3:
569
+ valindex[i]=False
570
+ print('rnet has one smaller than 3')
571
+ else:
572
+ if boxes_align[i][2]<1 or boxes_align[i][0]>w-2 or boxes_align[i][3]<1 or boxes_align[i][1]>h-2:
573
+ valindex[i]=False
574
+ print('rnet has one out')
575
+ boxes_align=boxes_align[valindex,:]
576
+ boxes = boxes[valindex,:]
577
+ """ img = im.copy()
578
+ for i in range(boxes_align.shape[0]):
579
+ p1=(int(boxes_align[i,0]),int(boxes_align[i,1]))
580
+ p2=(int(boxes_align[i,2]),int(boxes_align[i,3]))
581
+ cv2.rectangle(img,p1,p2,color=(0,0,0))
582
+ cv2.imshow('ss',img)
583
+ cv2.waitKey(0)
584
+ exit() """
585
+ return boxes, boxes_align
586
+
587
+ def detect_onet(self, im, dets):
588
+ """Get face candidates using onet
589
+
590
+ Parameters:
591
+ ----------
592
+ im: numpy array
593
+ input image array
594
+ dets: numpy array
595
+ detection results of rnet
596
+
597
+ Returns:
598
+ -------
599
+ boxes_align: numpy array
600
+ boxes after calibration
601
+ landmarks_align: numpy array
602
+ landmarks after calibration
603
+
604
+ """
605
+ h, w, c = im.shape
606
+
607
+ if dets is None:
608
+ return None, None
609
+ if dets.shape[0]==0:
610
+ return None, None
611
+
612
+ detss = dets
613
+ dets = self.square_bbox(dets)
614
+
615
+
616
+ dets[:, 0:4] = np.round(dets[:, 0:4])
617
+
618
+ [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(dets, w, h)
619
+ num_boxes = dets.shape[0]
620
+
621
+
622
+ # cropped_ims_tensors = np.zeros((num_boxes, 3, 24, 24), dtype=np.float32)
623
+ cropped_ims_tensors = []
624
+ for i in range(num_boxes):
625
+ try:
626
+ tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8)
627
+ # crop input image
628
+ tmp[dy[i]:edy[i] + 1, dx[i]:edx[i] + 1, :] = im[y[i]:ey[i] + 1, x[i]:ex[i] + 1, :]
629
+ except:
630
+ print(dy[i],edy[i],dx[i],edx[i],y[i],ey[i],x[i],ex[i],tmpw[i],tmph[i])
631
+ print(dets[i])
632
+ print(detss[i])
633
+ print(h,w)
634
+ crop_im = cv2.resize(tmp, (48, 48))
635
+ crop_im_tensor = image_tools.convert_image_to_tensor(crop_im)
636
+ # cropped_ims_tensors[i, :, :, :] = crop_im_tensor
637
+ cropped_ims_tensors.append(crop_im_tensor)
638
+ feed_imgs = torch.stack(cropped_ims_tensors)
639
+ feed_imgs.requires_grad = True
640
+
641
+ if self.rnet_detector.use_cuda:
642
+ feed_imgs = feed_imgs.cuda()
643
+
644
+ cls_map, reg, landmark = self.onet_detector(feed_imgs)
645
+
646
+ cls_map = cls_map.cpu().data.numpy()
647
+ reg = reg.cpu().data.numpy()
648
+ landmark = landmark.cpu().data.numpy()
649
+
650
+ keep_inds = np.where(cls_map > self.thresh[2])[0]
651
+
652
+ if len(keep_inds) > 0:
653
+ boxes = dets[keep_inds]
654
+ cls = cls_map[keep_inds]
655
+ reg = reg[keep_inds]
656
+ landmark = landmark[keep_inds]
657
+ else:
658
+ return None, None
659
+
660
+ keep = utils.nms(boxes, 0.7, mode="Minimum")
661
+
662
+ if len(keep) == 0:
663
+ return None, None
664
+
665
+ keep_cls = cls[keep]
666
+ keep_boxes = boxes[keep]
667
+ keep_reg = reg[keep]
668
+ keep_landmark = landmark[keep]
669
+
670
+ bw = keep_boxes[:, 2] - keep_boxes[:, 0] + 1
671
+ bh = keep_boxes[:, 3] - keep_boxes[:, 1] + 1
672
+
673
+
674
+ align_topx = keep_boxes[:, 0] + keep_reg[:, 0] * bw
675
+ align_topy = keep_boxes[:, 1] + keep_reg[:, 1] * bh
676
+ align_bottomx = keep_boxes[:, 2] + keep_reg[:, 2] * bw
677
+ align_bottomy = keep_boxes[:, 3] + keep_reg[:, 3] * bh
678
+
679
+ align_landmark_topx = keep_boxes[:, 0]
680
+ align_landmark_topy = keep_boxes[:, 1]
681
+
682
+
683
+
684
+
685
+ boxes_align = np.vstack([align_topx,
686
+ align_topy,
687
+ align_bottomx,
688
+ align_bottomy,
689
+ keep_cls[:, 0],
690
+ # align_topx + keep_landmark[:, 0] * bw,
691
+ # align_topy + keep_landmark[:, 1] * bh,
692
+ # align_topx + keep_landmark[:, 2] * bw,
693
+ # align_topy + keep_landmark[:, 3] * bh,
694
+ # align_topx + keep_landmark[:, 4] * bw,
695
+ # align_topy + keep_landmark[:, 5] * bh,
696
+ # align_topx + keep_landmark[:, 6] * bw,
697
+ # align_topy + keep_landmark[:, 7] * bh,
698
+ # align_topx + keep_landmark[:, 8] * bw,
699
+ # align_topy + keep_landmark[:, 9] * bh,
700
+ ])
701
+
702
+ boxes_align = boxes_align.T
703
+
704
+ landmark = np.vstack([
705
+ align_landmark_topx + keep_landmark[:, 0] * bw,
706
+ align_landmark_topy + keep_landmark[:, 1] * bh,
707
+ align_landmark_topx + keep_landmark[:, 2] * bw,
708
+ align_landmark_topy + keep_landmark[:, 3] * bh,
709
+ align_landmark_topx + keep_landmark[:, 4] * bw,
710
+ align_landmark_topy + keep_landmark[:, 5] * bh,
711
+ align_landmark_topx + keep_landmark[:, 6] * bw,
712
+ align_landmark_topy + keep_landmark[:, 7] * bh,
713
+ align_landmark_topx + keep_landmark[:, 8] * bw,
714
+ align_landmark_topy + keep_landmark[:, 9] * bh,
715
+ ])
716
+
717
+ landmark_align = landmark.T
718
+
719
+ return boxes_align, landmark_align
720
+
721
+
722
+ def detect_face(self,img):
723
+ """Detect face over image
724
+ """
725
+ boxes_align = np.array([])
726
+ landmark_align =np.array([])
727
+
728
+ t = time.time()
729
+
730
+ # pnet
731
+ if self.pnet_detector:
732
+ p_boxes, boxes_align = self.detect_pnet(img)
733
+ if boxes_align is None:
734
+ return np.array([]), np.array([])
735
+
736
+ t1 = time.time() - t
737
+ t = time.time()
738
+
739
+ # rnet
740
+ if self.rnet_detector:
741
+ r_boxes, boxes_align = self.detect_rnet(img, boxes_align)
742
+ if boxes_align is None:
743
+ return np.array([]), np.array([])
744
+
745
+ t2 = time.time() - t
746
+ t = time.time()
747
+
748
+ # onet
749
+ if self.onet_detector:
750
+ boxes_align, landmark_align = self.detect_onet(img, boxes_align)
751
+ if boxes_align is None:
752
+ return np.array([]), np.array([])
753
+
754
+ t3 = time.time() - t
755
+ t = time.time()
756
+ print("time cost " + '{:.3f}'.format(t1+t2+t3) + ' pnet {:.3f} rnet {:.3f} onet {:.3f}'.format(t1, t2, t3))
757
+
758
+ return p_boxes,r_boxes,boxes_align, landmark_align
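A hedged end-to-end sketch of the cascade, using the pre-trained weights shipped in `./model_store` and the bundled test image; it assumes at least one face is found, since `detect_face` returns only two empty arrays otherwise:

```python
import cv2
from utils.detect import create_mtcnn_net, MtcnnDetector

pnet, rnet, onet = create_mtcnn_net(
    p_model_path="./model_store/pnet_epoch_20.pt",
    r_model_path="./model_store/rnet_epoch_20.pt",
    o_model_path="./model_store/onet_epoch_20.pt",
    use_cuda=False)                                     # CPU for portability

detector = MtcnnDetector(pnet=pnet, rnet=rnet, onet=onet)
img = cv2.imread("./img/mid.png")                       # BGR, as read by OpenCV
p_boxes, r_boxes, boxes_align, landmarks = detector.detect_face(img)
print(boxes_align.shape, landmarks.shape)               # (n, 5) and (n, 10)
```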
utils/models.py ADDED
@@ -0,0 +1,207 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ '''The models follow the instructor-provided sample code; only three version-compatibility changes were made'''
5
+
6
+ def weights_init(m):
7
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
8
+ nn.init.xavier_uniform_(m.weight.data)
9
+ nn.init.constant_(m.bias, 0.1)
10
+
11
+
12
+
13
+ class LossFn:
14
+ def __init__(self, cls_factor=1, box_factor=1, landmark_factor=1):
15
+ # loss function
16
+ self.cls_factor = cls_factor
17
+ self.box_factor = box_factor
18
+ self.land_factor = landmark_factor
19
+ self.loss_cls = nn.BCELoss() # binary cross entropy
20
+ self.loss_box = nn.MSELoss() # mean square error
21
+ self.loss_landmark = nn.MSELoss()
22
+
23
+
24
+ def cls_loss(self,gt_label,pred_label):
25
+ pred_label = torch.squeeze(pred_label)
26
+ gt_label = torch.squeeze(gt_label)
27
+ # keep mask elements >= 0; only labels 0 and 1 affect the detection loss
28
+ mask = torch.ge(gt_label,0)
29
+ valid_gt_label = torch.masked_select(gt_label,mask)
30
+ valid_pred_label = torch.masked_select(pred_label,mask)
31
+ return self.loss_cls(valid_pred_label,valid_gt_label)*self.cls_factor
32
+
33
+
34
+ def box_loss(self,gt_label,gt_offset,pred_offset):
35
+ pred_offset = torch.squeeze(pred_offset)
36
+ gt_offset = torch.squeeze(gt_offset)
37
+ gt_label = torch.squeeze(gt_label)
38
+
39
+ #get the mask element which != 0
40
+ unmask = torch.eq(gt_label,0)
41
+ mask = torch.eq(unmask,0)
42
+ #convert mask to dim index
43
+ chose_index = torch.nonzero(mask.data)
44
+ chose_index = torch.squeeze(chose_index)
45
+ #only valid element can effect the loss
46
+ valid_gt_offset = gt_offset[chose_index,:]
47
+ valid_pred_offset = pred_offset[chose_index,:]
48
+ return self.loss_box(valid_pred_offset,valid_gt_offset)*self.box_factor
49
+
50
+
51
+ def landmark_loss(self,gt_label,gt_landmark,pred_landmark):
52
+ pred_landmark = torch.squeeze(pred_landmark)
53
+ gt_landmark = torch.squeeze(gt_landmark)
54
+ gt_label = torch.squeeze(gt_label)
55
+ mask = torch.eq(gt_label,-2)
56
+
57
+ chose_index = torch.nonzero(mask.data)
58
+ chose_index = torch.squeeze(chose_index)
59
+
60
+ valid_gt_landmark = gt_landmark[chose_index, :]
61
+ valid_pred_landmark = pred_landmark[chose_index, :]
62
+ return self.loss_landmark(valid_pred_landmark,valid_gt_landmark)*self.land_factor
63
+
64
+
65
+
66
+
67
+
68
+ class PNet(nn.Module):
69
+ ''' PNet '''
70
+
71
+ def __init__(self, is_train=False, use_cuda=True):
72
+ super(PNet, self).__init__()
73
+ self.is_train = is_train
74
+ self.use_cuda = use_cuda
75
+
76
+ # backend
77
+ self.pre_layer = nn.Sequential(
78
+ nn.Conv2d(3, 10, kernel_size=3, stride=1), # conv1
79
+ nn.PReLU(), # PReLU1
80
+ nn.MaxPool2d(kernel_size=2, stride=2), # pool1
81
+ nn.Conv2d(10, 16, kernel_size=3, stride=1), # conv2
82
+ nn.PReLU(), # PReLU2
83
+ nn.Conv2d(16, 32, kernel_size=3, stride=1), # conv3
84
+ nn.PReLU() # PReLU3
85
+ )
86
+ # detection
87
+ self.conv4_1 = nn.Conv2d(32, 1, kernel_size=1, stride=1)
88
+ # bounding box regression
89
+ self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1, stride=1)
90
+ # landmark localization
91
+ self.conv4_3 = nn.Conv2d(32, 10, kernel_size=1, stride=1)
92
+
93
+ # weight initialization with xavier
94
+ self.apply(weights_init)
95
+
96
+ def forward(self, x):
97
+ x = self.pre_layer(x)
98
+ label = torch.sigmoid(self.conv4_1(x))
99
+ offset = self.conv4_2(x)
100
+ # landmark = self.conv4_3(x)
101
+
102
+ if self.is_train is True:
103
+ # label_loss = LossUtil.label_loss(self.gt_label,torch.squeeze(label))
104
+ # bbox_loss = LossUtil.bbox_loss(self.gt_bbox,torch.squeeze(offset))
105
+ return label,offset
106
+ #landmark = self.conv4_3(x)
107
+ return label, offset
108
+
109
+
110
+
111
+
112
+
113
+ class RNet(nn.Module):
114
+ ''' RNet '''
115
+
116
+ def __init__(self,is_train=False, use_cuda=True):
117
+ super(RNet, self).__init__()
118
+ self.is_train = is_train
119
+ self.use_cuda = use_cuda
120
+ # backend
121
+ self.pre_layer = nn.Sequential(
122
+ nn.Conv2d(3, 28, kernel_size=3, stride=1), # conv1
123
+ nn.PReLU(), # prelu1
124
+ nn.MaxPool2d(kernel_size=3, stride=2), # pool1
125
+ nn.Conv2d(28, 48, kernel_size=3, stride=1), # conv2
126
+ nn.PReLU(), # prelu2
127
+ nn.MaxPool2d(kernel_size=3, stride=2), # pool2
128
+ nn.Conv2d(48, 64, kernel_size=2, stride=1), # conv3
129
+ nn.PReLU() # prelu3
130
+
131
+ )
132
+ self.conv4 = nn.Linear(64*2*2, 128) # conv4
133
+ self.prelu4 = nn.PReLU() # prelu4
134
+ # detection
135
+ self.conv5_1 = nn.Linear(128, 1)
136
+ # bounding box regression
137
+ self.conv5_2 = nn.Linear(128, 4)
138
+ # landmark localization
139
+ self.conv5_3 = nn.Linear(128, 10)
140
+ # weight initialization with xavier
141
+ self.apply(weights_init)
142
+
143
+ def forward(self, x):
144
+ # backend
145
+ x = self.pre_layer(x)
146
+ x = x.view(x.size(0), -1)
147
+ x = self.conv4(x)
148
+ x = self.prelu4(x)
149
+ # detection
150
+ det = torch.sigmoid(self.conv5_1(x))
151
+ box = self.conv5_2(x)
152
+ # landmark = self.conv5_3(x)
153
+
154
+ if self.is_train is True:
155
+ return det, box
156
+ #landmard = self.conv5_3(x)
157
+ return det, box
158
+
159
+
160
+
161
+
162
+ class ONet(nn.Module):
163
+ ''' ONet '''
164
+
165
+ def __init__(self,is_train=False, use_cuda=True):
166
+ super(ONet, self).__init__()
167
+ self.is_train = is_train
168
+ self.use_cuda = use_cuda
169
+ # backend
170
+ self.pre_layer = nn.Sequential(
171
+ nn.Conv2d(3, 32, kernel_size=3, stride=1), # conv1
172
+ nn.PReLU(), # prelu1
173
+ nn.MaxPool2d(kernel_size=3, stride=2), # pool1
174
+ nn.Conv2d(32, 64, kernel_size=3, stride=1), # conv2
175
+ nn.PReLU(), # prelu2
176
+ nn.MaxPool2d(kernel_size=3, stride=2), # pool2
177
+ nn.Conv2d(64, 64, kernel_size=3, stride=1), # conv3
178
+ nn.PReLU(), # prelu3
179
+ nn.MaxPool2d(kernel_size=2,stride=2), # pool3
180
+ nn.Conv2d(64,128,kernel_size=2,stride=1), # conv4
181
+ nn.PReLU() # prelu4
182
+ )
183
+ self.conv5 = nn.Linear(128*2*2, 256) # conv5
184
+ self.prelu5 = nn.PReLU() # prelu5
185
+ # detection
186
+ self.conv6_1 = nn.Linear(256, 1)
187
+ # bounding box regression
188
+ self.conv6_2 = nn.Linear(256, 4)
189
+ # landmark localization
190
+ self.conv6_3 = nn.Linear(256, 10)
191
+ # weight initialization with xavier
192
+ self.apply(weights_init)
193
+
194
+ def forward(self, x):
195
+ # backend
196
+ x = self.pre_layer(x)
197
+ x = x.view(x.size(0), -1)
198
+ x = self.conv5(x)
199
+ x = self.prelu5(x)
200
+ # detection
201
+ det = torch.sigmoid(self.conv6_1(x))
202
+ box = self.conv6_2(x)
203
+ landmark = self.conv6_3(x)
204
+ if self.is_train is True:
205
+ return det, box, landmark
206
+ #landmard = self.conv5_3(x)
207
+ return det, box, landmark
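A quick hedged shape check with dummy tensors: PNet is fully convolutional, while RNet and ONet expect fixed 24x24 and 48x48 crops (a sketch, not part of the repo):

```python
import torch
from utils.models import PNet, RNet, ONet

pnet, rnet, onet = PNet(), RNet(), ONet()
with torch.no_grad():
    label, offset = pnet(torch.randn(1, 3, 12, 12))
    print(label.shape, offset.shape)            # (1, 1, 1, 1), (1, 4, 1, 1)
    det, box = rnet(torch.randn(1, 3, 24, 24))
    print(det.shape, box.shape)                 # (1, 1), (1, 4)
    det, box, lmk = onet(torch.randn(1, 3, 48, 48))
    print(det.shape, box.shape, lmk.shape)      # (1, 1), (1, 4), (1, 10)
```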
utils/tool.py ADDED
@@ -0,0 +1,117 @@
1
+ import numpy as np
2
+ import time
3
+
4
+ def IoU(box, boxes):
5
+ """Compute IoU between detect box and gt boxes
6
+
7
+ Parameters:
8
+ ----------
9
+ box: numpy array , shape (5, ): x1, y1, x2, y2, score
10
+ input box
11
+ boxes: numpy array, shape (n, 4): x1, y1, x2, y2
12
+ input ground truth boxes
13
+
14
+ Returns:
15
+ -------
16
+ ovr: numpy.array, shape (n, )
17
+ IoU
18
+ """
19
+ box_area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
20
+ area = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
21
+ xx1 = np.maximum(box[0], boxes[:, 0])
22
+ yy1 = np.maximum(box[1], boxes[:, 1])
23
+ xx2 = np.minimum(box[2], boxes[:, 2])
24
+ yy2 = np.minimum(box[3], boxes[:, 3])
25
+
26
+ # compute the width and height of the bounding box
27
+ w = np.maximum(0, xx2 - xx1 + 1)
28
+ h = np.maximum(0, yy2 - yy1 + 1)
29
+
30
+ inter = w * h
31
+ ovr = np.true_divide(inter,(box_area + area - inter))
32
+ #ovr = inter / (box_area + area - inter)
33
+ return ovr
34
+
35
+
36
+ def convert_to_square(bbox):
37
+ """Convert bbox to square
38
+
39
+ Parameters:
40
+ ----------
41
+ bbox: numpy array , shape n x 5
42
+ input bbox
43
+
44
+ Returns:
45
+ -------
46
+ square bbox
47
+ """
48
+ square_bbox = bbox.copy()
49
+
50
+ h = bbox[:, 3] - bbox[:, 1] + 1
51
+ w = bbox[:, 2] - bbox[:, 0] + 1
52
+ max_side = np.maximum(h,w)
53
+ square_bbox[:, 0] = bbox[:, 0] + w*0.5 - max_side*0.5
54
+ square_bbox[:, 1] = bbox[:, 1] + h*0.5 - max_side*0.5
55
+ square_bbox[:, 2] = square_bbox[:, 0] + max_side - 1
56
+ square_bbox[:, 3] = square_bbox[:, 1] + max_side - 1
57
+ return square_bbox
58
+
59
+ # non-maximum suppression: eliminates boxes that overlap heavily with the highest-scoring box
60
+ def nms(dets, thresh, mode="Union"):
61
+ """
62
+ greedily select boxes with high confidence
63
+ keep boxes overlap <= thresh
64
+ rule out overlap > thresh
65
+ :param dets: [[x1, y1, x2, y2 score]]
66
+ :param thresh: retain overlap <= thresh
67
+ :return: indexes to keep
68
+ """
69
+ x1 = dets[:, 0]
70
+ y1 = dets[:, 1]
71
+ x2 = dets[:, 2]
72
+ y2 = dets[:, 3]
73
+ scores = dets[:, 4]
74
+
75
+ # shape of x1 = (454,), shape of scores = (454,)
76
+ # print("shape of x1 = {0}, shape of scores = {1}".format(x1.shape, scores.shape))
77
+ # time.sleep(5)
78
+
79
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
80
+ order = scores.argsort()[::-1] # argsort: ascending order then [::-1] reverse the order --> descending order
81
+ # print("shape of order {0}".format(order.size)) # (454,)
82
+ # time.sleep(5)
83
+
84
+ # eliminate boxes that overlap heavily with the box that has the largest score in order
85
+ # keep the highest-scoring box and the boxes that do not overlap heavily with it
86
+ keep = []
87
+ while order.size > 0:
88
+ i = order[0]
89
+ keep.append(i)
90
+ xx1 = np.maximum(x1[i], x1[order[1:]])
91
+ yy1 = np.maximum(y1[i], y1[order[1:]])
92
+ xx2 = np.minimum(x2[i], x2[order[1:]])
93
+ yy2 = np.minimum(y2[i], y2[order[1:]])
94
+
95
+ w = np.maximum(0.0, xx2 - xx1 + 1)
96
+ h = np.maximum(0.0, yy2 - yy1 + 1)
97
+ inter = w * h
98
+
99
+ # calculate the IoU between the highest-scoring box and the other boxes
100
+ if mode == "Union":
101
+ # area[i]: the area of largest score
102
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
103
+ elif mode == "Minimum":
104
+ ovr = inter / np.minimum(areas[i], areas[order[1:]])
105
+
106
+
107
+ inds = np.where(ovr <= thresh)[0]
108
+ order = order[inds + 1] # +1: eliminates the first element in order
109
+ # print(inds)
110
+ # print("shape of order {0}".format(order.shape)) # (454,)
111
+ # time.sleep(2)
112
+
113
+ return keep
114
+
115
+
116
+
117
+
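A tiny hedged check of `nms` and `IoU` on synthetic boxes; with `thresh=0.5` the lower-scoring of two heavily overlapping boxes is suppressed:

```python
import numpy as np
from utils.tool import nms, IoU

dets = np.array([[10, 10, 50, 50, 0.9],
                 [12, 12, 52, 52, 0.8],      # IoU ~ 0.83 with the first box
                 [100, 100, 140, 140, 0.7]], dtype=float)

print(nms(dets, 0.5, mode="Union"))           # -> [0, 2]
print(IoU(dets[0], dets[1:, :4]))             # -> approx [0.83, 0.0]
```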
utils/vision.py ADDED
@@ -0,0 +1,58 @@
1
+ from matplotlib.patches import Circle
2
+ import os
3
+ import sys
4
+ import matplotlib.pyplot as plt
5
+ import pylab
6
+ sys.path.append(os.getcwd())
7
+
8
+
9
+ def vis_face(im_array, dets, landmarks, face_size, save_name):
10
+ """Visualize detection results
11
+
12
+ Parameters:
13
+ ----------
14
+ im_array: numpy.ndarray, shape (h, w, c)
15
+ test image in rgb
16
+ dets: numpy.ndarray([[x1 y1 x2 y2 score]])
+ detection boxes to draw
+ landmarks: numpy.ndarray, shape (n, 10), or None
+ five (x, y) landmark points per face
+ face_size: float
+ controls the radius of the landmark markers
+ save_name: str
+ path where the rendered figure is saved
22
+
23
+ Returns:
24
+ -------
25
+ """
26
+
27
+ pylab.imshow(im_array)
28
+
29
+ for i in range(dets.shape[0]):
30
+ bbox = dets[i, :5]
31
+
32
+ rect = pylab.Rectangle((bbox[0], bbox[1]),
33
+ bbox[2] - bbox[0],
34
+ bbox[3] - bbox[1], fill=False,
35
+ edgecolor='red', linewidth=0.9)
36
+ score = bbox[4]
37
+ plt.gca().text(bbox[0], bbox[1] - 2,
38
+ '{:.5f}'.format(score),
39
+ bbox=dict(facecolor='red', alpha=0.5), fontsize=8, color='white')
40
+
41
+ pylab.gca().add_patch(rect)
42
+
43
+ if landmarks is not None:
44
+ for i in range(landmarks.shape[0]):
45
+ landmarks_one = landmarks[i, :]
46
+ landmarks_one = landmarks_one.reshape((5, 2))
47
+ for j in range(5):
48
+
49
+ cir1 = Circle(xy=(landmarks_one[j, 0], landmarks_one[j, 1]), radius=face_size/12, alpha=0.4, color="red")
50
+ pylab.gca().add_patch(cir1)
51
+
52
+ #pylab.savefig(save_name)
53
+ # save only the image content, without the axes
54
+ pylab.axis('off')
55
+ pylab.savefig(save_name, bbox_inches='tight', pad_inches=0.0)
56
+ pylab.show()
57
+ # return the figure object
58
+ return pylab
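`vis_face` only needs an RGB array, `(n, 5)` boxes and `(n, 10)` landmark coordinates, so it can be exercised without the detector; a hedged sketch with synthetic values (the output path is an assumption):

```python
import numpy as np
from utils.vision import vis_face

img = np.full((240, 320, 3), 200, dtype=np.uint8)        # plain grey RGB canvas
boxes = np.array([[80.0, 60.0, 200.0, 190.0, 0.99]])      # x1, y1, x2, y2, score
landmarks = np.array([[110, 100, 170, 100, 140, 130,      # eyes, nose,
                       115, 165, 165, 165]], dtype=float) # mouth corners (x, y pairs)

vis_face(img, boxes, landmarks, face_size=48, save_name="./img/vis_demo.png")
```

When drawing real detector output, remember that OpenCV loads images as BGR, so convert with `cv2.cvtColor(img, cv2.COLOR_BGR2RGB)` before passing the array to `vis_face`.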