# -- coding: utf-8 --
# @Time : 2021/11/10
import numpy as np
import cv2
from cv2box.utils.math import Normalize
from cv2box import CVImage
from .scrfd_insightface import SCRFD
from face_detect.face_align_utils import norm_crop, apply_roi_func

# https://github.com/deepinsight/insightface/tree/master/detection/scrfd
SCRFD_MODEL_PATH = '../pretrain_models/'


class FaceDetect5Landmarks:
    """Detect faces with SCRFD and produce landmark-aligned crops.

    Typical use is ``get_bboxes(image)`` followed by ``get_single_face`` /
    ``get_multi_face`` / ``draw_face``, which all operate on the state
    (``self.image``, ``self.bboxes``, ``self.kpss``) left by the last
    ``get_bboxes`` call.

    With ``tracking=True`` only one face is kept per frame: the detection
    closest to the face kept on the previous frame.
    """

    def __init__(self, mode='scrfd_500m', tracking=False):
        """
        Args:
            mode: one of 'scrfd', 'scrfd_500m', 'mtcnn'. Only the two scrfd
                modes load a model here; 'mtcnn' passes validation but no
                detector is created for it in this file.
            tracking: keep a single face across successive ``get_bboxes`` calls.
        """
        self.mode = mode
        self.tracking = tracking
        self.dis_list = []
        self.last_bboxes_ = []
        assert self.mode in ['scrfd', 'scrfd_500m', 'mtcnn']
        self.bboxes = self.kpss = self.image = None
        # Always defined so that non-scrfd modes do not hit an AttributeError.
        self.det_model_scrfd = None
        if 'scrfd' in self.mode:
            if self.mode == 'scrfd_500m':
                scrfd_model_path = SCRFD_MODEL_PATH + 'scrfd_500m_bnkps_shape640x640.onnx'
            else:
                scrfd_model_path = SCRFD_MODEL_PATH + 'scrfd_10g_bnkps.onnx'
            self.det_model_scrfd = SCRFD(scrfd_model_path)
            self.det_model_scrfd.prepare(ctx_id=0, input_size=(640, 640))

    def get_bboxes(self, image, nms_thresh=0.5, max_num=0, min_bbox_size=None):
        """Run detection on ``image`` and cache the result on the instance.

        Args:
            image: RGB image path or Numpy array load by cv2; it is passed
                through ``CVImage(image).rgb()`` before detection.
            nms_thresh: forwarded to ``SCRFD.detect`` as its ``thresh`` argument.
                NOTE(review): in upstream insightface that argument is the score
                threshold, not the NMS threshold - confirm against the local wrapper.
            max_num: forwarded to ``SCRFD.detect`` when tracking is off. When
                tracking is on it is ignored (1 for the seeding frame, 0 afterwards).
            min_bbox_size: if given, detections whose box area is smaller than
                ``min_bbox_size ** 2`` are dropped.

        Returns:
            (bboxes, kpss) as stored in ``self.bboxes`` / ``self.kpss``.
        """
        self.image = CVImage(image).rgb()

        if 'scrfd' not in self.mode:
            # No detector exists for this mode; leave the cached state untouched.
            return self.bboxes, self.kpss

        seeding = self.tracking and len(self.last_bboxes_) == 0
        if self.tracking:
            # Seed with a single face, then look at every face so the tracker
            # can pick the one nearest to the previous position.
            effective_max_num = 1 if seeding else 0
        else:
            effective_max_num = max_num

        # Both modes must see the same converted image; the raw argument may be
        # a file path or use a different channel order.
        self.bboxes, self.kpss = self.det_model_scrfd.detect(
            self.image, thresh=nms_thresh, max_num=effective_max_num, metric='default')

        if min_bbox_size is not None:
            self.bboxes_filter(min_bbox_size)

        if self.tracking:
            if seeding:
                self.last_bboxes_ = self.bboxes
            else:
                self.bboxes, self.kpss = self.tracking_filter()

        return self.bboxes, self.kpss

    def tracking_filter(self):
        """Keep only the detection closest to the previously tracked box.

        Distance is the L2 norm between ``Normalize(box).np_norm()`` of each
        current box and of ``self.last_bboxes_[0]``.

        Returns:
            (bboxes, kpss) holding a single detection. Both keep their leading
            axis (length 1) so callers can treat them like ordinary detector
            output. When there are no detections the current (empty)
            ``self.bboxes`` / ``self.kpss`` are returned and the previous
            tracked box is left unchanged.
        """
        if len(self.bboxes) == 0:
            return self.bboxes, self.kpss

        anchor = Normalize(self.last_bboxes_[0]).np_norm()
        distances = [np.linalg.norm(Normalize(box).np_norm() - anchor) for box in self.bboxes]
        best_index = int(np.argmin(distances))
        self.dis_list = []

        # Slice instead of index so the result stays array-like with a face axis.
        tracked_bboxes = self.bboxes[best_index:best_index + 1]
        tracked_kpss = None if self.kpss is None else self.kpss[best_index:best_index + 1]
        self.last_bboxes_ = tracked_bboxes
        return tracked_bboxes, tracked_kpss

    def bboxes_filter(self, min_bbox_size):
        """Drop detections whose box area is below ``min_bbox_size ** 2``.

        Box columns 0..3 are read as x1, y1, x2, y2. ``self.bboxes`` and, when
        present, ``self.kpss`` are replaced by the filtered arrays.
        """
        min_area = np.power(min_bbox_size, 2)
        boxes = np.asarray(self.bboxes)
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        keep = ~(areas < min_area)
        self.bboxes = boxes[keep]
        if self.kpss is not None:
            self.kpss = np.asarray(self.kpss)[keep]

    def get_single_face(self, crop_size, mode='mtcnn_512', apply_roi=False):
        """Align and crop one face from the last ``get_bboxes`` result.

        Without tracking the detection with the highest score (box column 4)
        is used; with tracking the single tracked detection is used.

        Args:
            crop_size: forwarded to ``norm_crop``.
            mode: default mtcnn_512 mtcnn_256 arcface_512 arcface default_95
            apply_roi: crop a region with ``apply_roi_func`` first and align
                inside it.

        Returns:
            ``(align_img, M)``, or ``(align_img, mat_rev, roi_box)`` when
            ``apply_roi`` is true. ``align_img`` is converted RGB -> BGR (cv2
            image). When no face is available every element is ``None``, with
            the same tuple length as the success case.
        """
        assert mode in ['default', 'mtcnn_512', 'mtcnn_256', 'arcface_512', 'arcface', 'default_95']
        if self.bboxes is None or len(self.bboxes) == 0:
            return (None, None, None) if apply_roi else (None, None)

        if self.tracking:
            # get_bboxes already reduced the detections to the tracked face.
            best_index = 0
        else:
            det_score = np.asarray(self.bboxes)[..., 4]
            best_index = int(np.argmax(det_score))
        kpss = None if self.kpss is None else self.kpss[best_index]

        if apply_roi:
            roi, roi_box, roi_kpss = apply_roi_func(self.image, self.bboxes[best_index], kpss)
            align_img, mat_rev = norm_crop(roi, roi_kpss, crop_size, mode=mode)
            align_img = cv2.cvtColor(align_img, cv2.COLOR_RGB2BGR)
            return align_img, mat_rev, roi_box

        align_img, mat = norm_crop(self.image, kpss, crop_size, mode=mode)
        align_img = cv2.cvtColor(align_img, cv2.COLOR_RGB2BGR)
        return align_img, mat

    def get_multi_face(self, crop_size, mode='mtcnn_512'):
        """Align and crop every face from the last ``get_bboxes`` result.

        Args:
            crop_size: forwarded to ``norm_crop``.
            mode: default mtcnn_512 arcface_512 arcface

        Returns:
            ``(align_img_list, M_list)``, or a bare ``None`` when there are no
            faces (kept for backward compatibility). Unlike ``get_single_face``
            the crops are returned as produced by ``norm_crop``, without the
            RGB -> BGR conversion.
        """
        if self.bboxes is None or len(self.bboxes) == 0:
            return None

        align_img_list = []
        M_list = []
        for index in range(len(self.bboxes)):
            kps = None if self.kpss is None else self.kpss[index]
            align_img, mat = norm_crop(self.image, kps, crop_size, mode=mode)
            align_img_list.append(align_img)
            M_list.append(mat)
        return align_img_list, M_list

    def draw_face(self):
        """Draw boxes and landmarks onto ``self.image`` (in place) and show it."""
        for index, bbox in enumerate(self.bboxes):
            # Plain Python ints: OpenCV drawing functions expect integer points.
            x1, y1, x2, y2 = np.asarray(bbox[:4]).astype(int).tolist()
            cv2.rectangle(self.image, (x1, y1), (x2, y2), (255, 0, 0), 2)
            if self.kpss is not None:
                for kp in self.kpss[index]:
                    center = tuple(np.asarray(kp).astype(int).tolist())
                    cv2.circle(self.image, center, 1, (0, 0, 255), 2)
        CVImage(self.image, image_format='cv2').show()