File size: 5,849 Bytes
906e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from __future__ import annotations

import pathlib
import warnings
from typing import Optional, Union

import cv2
import mmcv
import numpy as np
import torch.nn as nn
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import inference_top_down_pose_model, init_pose_model
from mmpose.datasets import DatasetInfo


class LandmarkDetector:
    """Face landmark detector combining an mmdet detector and an mmpose model.

    An optional mmdet model proposes face bounding boxes, and an mmpose
    top-down pose model predicts landmarks inside each box.

    Args:
        landmark_detector_config_or_path: Config object or config file path
            of the mmpose landmark model.
        landmark_detector_checkpoint_path: Checkpoint path of the landmark
            model.
        face_detector_config_or_path: Config object or config file path of
            the mmdet face detector. If None, no face detector is created.
        face_detector_checkpoint_path: Checkpoint path of the face detector.
        device: Device string passed to the mmdet/mmpose model initializers.
        flip_test: Value written to ``test_cfg.flip_test`` of the landmark
            model config.
        box_scale_factor: Factor by which detected face boxes are enlarged
            around their centers before landmark detection.
    """
    def __init__(
            self,
            landmark_detector_config_or_path: Union[mmcv.Config, str,
                                                    pathlib.Path],
            landmark_detector_checkpoint_path: Union[str, pathlib.Path],
            face_detector_config_or_path: Optional[Union[mmcv.Config, str,
                                                         pathlib.Path]] = None,
            face_detector_checkpoint_path: Optional[Union[
                str, pathlib.Path]] = None,
            device: str = 'cuda:0',
            flip_test: bool = True,
            box_scale_factor: float = 1.1):
        landmark_config = self._load_config(landmark_detector_config_or_path)
        self.dataset_info = DatasetInfo(
            landmark_config.dataset_info)  # type: ignore
        face_detector_config = self._load_config(face_detector_config_or_path)

        self.landmark_detector = self._init_pose_model(
            landmark_config, landmark_detector_checkpoint_path, device,
            flip_test)
        self.face_detector = self._init_face_detector(
            face_detector_config, face_detector_checkpoint_path, device)

        self.box_scale_factor = box_scale_factor

    @staticmethod
    def _load_config(
        config_or_path: Optional[Union[mmcv.Config, str, pathlib.Path]]
    ) -> Optional[mmcv.Config]:
        """Return a config object for the given config or config path.

        ``None`` and ``mmcv.Config`` instances are returned unchanged;
        anything else is loaded with ``mmcv.Config.fromfile``.
        """
        if config_or_path is None or isinstance(config_or_path, mmcv.Config):
            return config_or_path
        return mmcv.Config.fromfile(config_or_path)

    @staticmethod
    def _init_pose_model(config: mmcv.Config,
                         checkpoint_path: Union[str, pathlib.Path],
                         device: str, flip_test: bool) -> nn.Module:
        """Build the mmpose landmark model and set its flip-test option."""
        if isinstance(checkpoint_path, pathlib.Path):
            checkpoint_path = checkpoint_path.as_posix()
        model = init_pose_model(config, checkpoint_path, device=device)
        model.cfg.model.test_cfg.flip_test = flip_test
        return model

    @staticmethod
    def _init_face_detector(config: Optional[mmcv.Config],
                            checkpoint_path: Optional[Union[str,
                                                            pathlib.Path]],
                            device: str) -> Optional[nn.Module]:
        """Build the mmdet face detector, or return None without a config."""
        if config is None:
            return None
        if isinstance(checkpoint_path, pathlib.Path):
            checkpoint_path = checkpoint_path.as_posix()
        return init_detector(config, checkpoint_path, device=device)

    def _detect_faces(self, image: np.ndarray) -> list[np.ndarray]:
        """Detect faces in ``image`` and return the enlarged boxes."""
        # predicted boxes using mmdet model have the format of
        # [x0, y0, x1, y1, score]
        boxes = inference_detector(self.face_detector, image)[0]
        # scale boxes by `self.box_scale_factor`
        return self._update_pred_box(boxes)

    def _update_pred_box(self, pred_boxes: np.ndarray) -> list[np.ndarray]:
        """Enlarge each box around its center by ``self.box_scale_factor``.

        The box size is computed as ``x1 - x0 + 1`` (and likewise for y),
        scaled, and the new top-left corner is placed so that the center is
        kept; the new bottom-right corner is the top-left plus the scaled
        size. Any values after the first four (e.g. the score) are kept.

        Args:
            pred_boxes: Boxes of the form [x0, y0, x1, y1, ...].

        Returns:
            A list with one scaled box per input box. The input boxes are
            not modified.
        """
        boxes = []
        for pred_box in pred_boxes:
            pred_box = np.asarray(pred_box)
            # Work on a floating-point copy: the caller's array must stay
            # untouched, and an integer dtype would truncate the scaled
            # coordinates on assignment. float32 input stays float32.
            float_dtype = np.result_type(pred_box.dtype, np.float32)
            scaled_box = pred_box.astype(float_dtype)
            top_left_in = scaled_box[:2].copy()
            bottom_right_in = scaled_box[2:4].copy()
            size = bottom_right_in - top_left_in + 1
            new_size = size * self.box_scale_factor
            center = (top_left_in + bottom_right_in) / 2
            tl = center - new_size / 2
            br = tl + new_size
            scaled_box[:4] = np.concatenate([tl, br])
            boxes.append(scaled_box)
        return boxes

    def _detect_landmarks(
            self, image: np.ndarray,
            boxes: list[dict[str, np.ndarray]]) -> list[dict[str, np.ndarray]]:
        """Run the landmark model on ``image`` for the given face boxes.

        Args:
            image: The input image.
            boxes: A list of dicts, each holding a box under the key
                ``'bbox'`` in xyxy format.

        Returns:
            The predictions returned by ``inference_top_down_pose_model``.
        """
        preds, _ = inference_top_down_pose_model(
            self.landmark_detector,
            image,
            boxes,
            format='xyxy',
            dataset_info=self.dataset_info,
            return_heatmap=False)
        return preds

    @staticmethod
    def _load_image(
            image_or_path: Union[np.ndarray, str, pathlib.Path]) -> np.ndarray:
        """Return the image as an array, reading it from disk if needed.

        Args:
            image_or_path: An image array (returned as is) or an image path.

        Returns:
            The image array.

        Raises:
            ValueError: If the argument has an unsupported type, or if the
                image file cannot be read.
        """
        if isinstance(image_or_path, np.ndarray):
            return image_or_path
        if isinstance(image_or_path, pathlib.Path):
            path = image_or_path.as_posix()
        elif isinstance(image_or_path, str):
            path = image_or_path
        else:
            raise ValueError(
                'image_or_path must be a numpy array, str or pathlib.Path, '
                f'but got {type(image_or_path).__name__}.')
        image = cv2.imread(path)
        # cv2.imread signals failure (missing or undecodable file) by
        # returning None instead of raising, so fail here with a clear
        # message rather than deep inside the detectors.
        if image is None:
            raise ValueError(f'Failed to read an image from {path!r}.')
        return image

    def __call__(
        self,
        image_or_path: Union[np.ndarray, str, pathlib.Path],
        boxes: Optional[list[np.ndarray]] = None
    ) -> list[dict[str, np.ndarray]]:
        """Detect face landmarks.

        Args:
            image_or_path: An image with BGR channel order or an image path.
            boxes: A list of bounding boxes for faces. Each bounding box
                should be of the form [x0, y0, x1, y1, [score]]. If None,
                the face detector is used when available; otherwise the
                entire image is treated as a single face region.

        Returns: A list of detection results. Each detection result has
            bounding box of the form [x0, y0, x1, y1, [score]], and landmarks
            of the form [x, y, score].

        Raises:
            ValueError: If `image_or_path` has an unsupported type or the
                image file cannot be read.
        """
        image = self._load_image(image_or_path)
        if boxes is None:
            if self.face_detector is not None:
                boxes = self._detect_faces(image)
            else:
                warnings.warn(
                    'Neither the face detector nor the bounding box is '
                    'specified. So the entire image is treated as the face '
                    'region.')
                h, w = image.shape[:2]
                boxes = [np.array([0, 0, w - 1, h - 1, 1])]
        box_list = [{'bbox': box} for box in boxes]
        return self._detect_landmarks(image, box_list)