yuna0x0's picture
Init
906e212 unverified
from __future__ import annotations
import pathlib
import warnings
from typing import Optional, Union
import cv2
import mmcv
import numpy as np
import torch.nn as nn
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import inference_top_down_pose_model, init_pose_model
from mmpose.datasets import DatasetInfo
class LandmarkDetector:
def __init__(
self,
landmark_detector_config_or_path: Union[mmcv.Config, str,
pathlib.Path],
landmark_detector_checkpoint_path: Union[str, pathlib.Path],
face_detector_config_or_path: Optional[Union[mmcv.Config, str,
pathlib.Path]] = None,
face_detector_checkpoint_path: Optional[Union[
str, pathlib.Path]] = None,
device: str = 'cuda:0',
flip_test: bool = True,
box_scale_factor: float = 1.1):
landmark_config = self._load_config(landmark_detector_config_or_path)
self.dataset_info = DatasetInfo(
landmark_config.dataset_info) # type: ignore
face_detector_config = self._load_config(face_detector_config_or_path)
self.landmark_detector = self._init_pose_model(
landmark_config, landmark_detector_checkpoint_path, device,
flip_test)
self.face_detector = self._init_face_detector(
face_detector_config, face_detector_checkpoint_path, device)
self.box_scale_factor = box_scale_factor
@staticmethod
def _load_config(
config_or_path: Optional[Union[mmcv.Config, str, pathlib.Path]]
) -> Optional[mmcv.Config]:
if config_or_path is None or isinstance(config_or_path, mmcv.Config):
return config_or_path
return mmcv.Config.fromfile(config_or_path)
@staticmethod
def _init_pose_model(config: mmcv.Config,
checkpoint_path: Union[str, pathlib.Path],
device: str, flip_test: bool) -> nn.Module:
if isinstance(checkpoint_path, pathlib.Path):
checkpoint_path = checkpoint_path.as_posix()
model = init_pose_model(config, checkpoint_path, device=device)
model.cfg.model.test_cfg.flip_test = flip_test
return model
@staticmethod
def _init_face_detector(config: Optional[mmcv.Config],
checkpoint_path: Optional[Union[str,
pathlib.Path]],
device: str) -> Optional[nn.Module]:
if config is not None:
if isinstance(checkpoint_path, pathlib.Path):
checkpoint_path = checkpoint_path.as_posix()
model = init_detector(config, checkpoint_path, device=device)
else:
model = None
return model
def _detect_faces(self, image: np.ndarray) -> list[np.ndarray]:
# predicted boxes using mmdet model have the format of
# [x0, y0, x1, y1, score]
boxes = inference_detector(self.face_detector, image)[0]
# scale boxes by `self.box_scale_factor`
boxes = self._update_pred_box(boxes)
return boxes
def _update_pred_box(self, pred_boxes: np.ndarray) -> list[np.ndarray]:
boxes = []
for pred_box in pred_boxes:
box = pred_box[:4]
size = box[2:] - box[:2] + 1
new_size = size * self.box_scale_factor
center = (box[:2] + box[2:]) / 2
tl = center - new_size / 2
br = tl + new_size
pred_box[:4] = np.concatenate([tl, br])
boxes.append(pred_box)
return boxes
def _detect_landmarks(
self, image: np.ndarray,
boxes: list[dict[str, np.ndarray]]) -> list[dict[str, np.ndarray]]:
preds, _ = inference_top_down_pose_model(
self.landmark_detector,
image,
boxes,
format='xyxy',
dataset_info=self.dataset_info,
return_heatmap=False)
return preds
@staticmethod
def _load_image(
image_or_path: Union[np.ndarray, str, pathlib.Path]) -> np.ndarray:
if isinstance(image_or_path, np.ndarray):
image = image_or_path
elif isinstance(image_or_path, str):
image = cv2.imread(image_or_path)
elif isinstance(image_or_path, pathlib.Path):
image = cv2.imread(image_or_path.as_posix())
else:
raise ValueError
return image
def __call__(
self,
image_or_path: Union[np.ndarray, str, pathlib.Path],
boxes: Optional[list[np.ndarray]] = None
) -> list[dict[str, np.ndarray]]:
"""Detect face landmarks.
Args:
image_or_path: An image with BGR channel order or an image path.
boxes: A list of bounding boxes for faces. Each bounding box
should be of the form [x0, y0, x1, y1, [score]].
Returns: A list of detection results. Each detection result has
bounding box of the form [x0, y0, x1, y1, [score]], and landmarks
of the form [x, y, score].
"""
image = self._load_image(image_or_path)
if boxes is None:
if self.face_detector is not None:
boxes = self._detect_faces(image)
else:
warnings.warn(
'Neither the face detector nor the bounding box is '
'specified. So the entire image is treated as the face '
'region.')
h, w = image.shape[:2]
boxes = [np.array([0, 0, w - 1, h - 1, 1])]
box_list = [{'bbox': box} for box in boxes]
return self._detect_landmarks(image, box_list)