# deep_privacy2/dp2/detection/cse_mask_face_detector.py -- commit 44539fc ("fix") by haakohu
import torch
import lzma
import tops
from pathlib import Path
from dp2.detection.base import BaseDetector
from .utils import combine_cse_maskrcnn_dets
from face_detection import build_detector as build_face_detector
from .models.cse import CSEDetector
from .models.mask_rcnn import MaskRCNNDetector
from .structures import CSEPersonDetection, VehicleDetection, FaceDetection, PersonDetection
from tops import logger
def box1_inside_box2(box1: torch.Tensor, box2: torch.Tensor):
    """Flag each box in ``box1`` that lies strictly inside some box in ``box2``.

    Columns 0-1 of every row are compared as the left/top corner and
    columns 2-3 as the right/bottom corner. A box counts as "inside" only
    when all four comparisons are strict, so a box that touches the border
    of a container is not considered inside it.

    Args:
        box1: 2-D tensor with one candidate box per row.
        box2: 2-D tensor with one container box per row.

    Returns:
        Boolean tensor of shape ``(box1.shape[0],)``; entry ``i`` is True
        when ``box1[i]`` is strictly inside at least one row of ``box2``.
        All entries are False when ``box2`` has no rows.

    Raises:
        AssertionError: if either input is not 2-D.
    """
    assert len(box1.shape) == 2
    assert len(box2.shape) == 2
    # Broadcast to (num_box1, num_box2, 2) so every candidate is compared
    # against every container in one shot instead of a Python-level loop.
    lefttop_inside = (box1[:, None, 0:2] > box2[None, :, 0:2]).all(dim=2)
    rightbot_inside = (box1[:, None, 2:4] < box2[None, :, 2:4]).all(dim=2)
    # A candidate is covered as soon as one container encloses it.
    return (lefttop_inside & rightbot_inside).any(dim=1)
class CSeMaskFaceDetector(BaseDetector):
    """Detector that combines Mask R-CNN, a CSE detector and a face detector.

    ``forward`` runs all three detectors on one image and returns a list of
    detection objects: faces not already covered by a person box, non-person
    Mask R-CNN detections, persons without a CSE match, and persons with a
    CSE match.
    """

    def __init__(
        self,
        mask_rcnn_cfg,
        face_detector_cfg: dict,
        cse_cfg: dict,
        face_post_process_cfg: dict,
        cse_post_process_cfg,
        score_threshold: float,
        **kwargs
    ) -> None:
        """Build the three underlying detectors.

        Args:
            mask_rcnn_cfg: keyword arguments for ``MaskRCNNDetector``.
            face_detector_cfg: keyword arguments for ``build_face_detector``;
                ``confidence_threshold`` defaults to ``score_threshold``.
            cse_cfg: keyword arguments for ``CSEDetector``; ``score_thres``
                defaults to ``score_threshold``.
            face_post_process_cfg: keyword arguments forwarded to
                ``FaceDetection``.
            cse_post_process_cfg: mapping forwarded to the person detection
                classes; must contain ``iou_combine_threshold``, which is
                split off and kept separately.
            score_threshold: fallback score threshold for all detectors.
            **kwargs: forwarded to ``BaseDetector.__init__``.

        Raises:
            KeyError: if ``cse_post_process_cfg`` has no
                ``iou_combine_threshold`` entry.
        """
        super().__init__(**kwargs)
        self.mask_rcnn = MaskRCNNDetector(**mask_rcnn_cfg, score_thres=score_threshold)
        # Fill in the default thresholds on copies: the caller's config
        # objects must stay untouched so they can be reused.
        face_detector_cfg = {"confidence_threshold": score_threshold, **face_detector_cfg}
        cse_cfg = {"score_thres": score_threshold, **cse_cfg}
        self.cse_detector = CSEDetector(**cse_cfg)
        self.face_detector = build_face_detector(**face_detector_cfg, clip_boxes=True)
        # Copy before popping; popping from the caller's mapping would make
        # a second construction from the same config fail with KeyError.
        self.cse_post_process_cfg = dict(cse_post_process_cfg)
        self.face_mean = tops.to_cuda(torch.from_numpy(self.face_detector.mean).view(3, 1, 1))
        self.mask_cse_iou_combine_threshold = self.cse_post_process_cfg.pop("iou_combine_threshold")
        self.face_post_process_cfg = face_post_process_cfg

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

    def _detect_faces(self, im: torch.Tensor):
        """Run the face detector on a single image.

        Args:
            im: 3-D image tensor; the last two dimensions are read as
                height and width. Presumably channels-first with 3 channels,
                since the mean is reshaped to (3, 1, 1) -- TODO confirm.

        Returns:
            Integer (``long``) tensor with one box per row, scaled by the
            image width (columns 0 and 2) and height (columns 1 and 3).
        """
        H, W = im.shape[1:]
        im = im.float() - self.face_mean
        im = self.face_detector.resize(im[None], 1.0)
        boxes_XYXY = self.face_detector._batched_detect(im)[0][:, :-1]  # Remove score
        # NOTE(review): the scaling implies the detector returns coordinates
        # relative to the image size -- confirm against face_detection.
        boxes_XYXY[:, [0, 2]] *= W
        boxes_XYXY[:, [1, 3]] *= H
        return boxes_XYXY.round().long()

    def load_from_cache(self, cache_path: Path):
        """Rebuild detection objects from an LZMA-compressed cache file.

        Args:
            cache_path: path to a file written with ``lzma`` + ``torch.save``
                that holds an iterable of state dicts, each with a ``"cls"``
                entry providing ``from_state_dict``.

        Returns:
            List with one reconstructed detection object per cached state.
        """
        logger.log(f"Loading detection from cache path: {cache_path}",)
        with lzma.open(cache_path, "rb") as fp:
            # SECURITY: torch.load unpickles arbitrary objects; only load
            # cache files that come from a trusted source.
            state_dict = torch.load(fp, map_location="cpu")
        kwargs = dict(
            post_process_cfg=self.cse_post_process_cfg,
            embed_map=self.cse_detector.embed_map,
            **self.face_post_process_cfg
        )
        return [
            state["cls"].from_state_dict(**kwargs, state_dict=state)
            for state in state_dict
        ]

    @torch.no_grad()
    def forward(self, im: torch.Tensor):
        """Detect faces, persons and other objects in one image.

        Args:
            im: image tensor passed to all three detectors; its shape is
                forwarded as ``orig_imshape_CHW``.

        Returns:
            ``[FaceDetection, VehicleDetection, PersonDetection,
            CSEPersonDetection]`` in that order.
        """
        maskrcnn_dets = self.mask_rcnn(im)
        cse_dets = self.cse_detector(im)
        embed_map = self.cse_detector.embed_map
        print("Calling face detector.")
        face_boxes = self._detect_faces(im).cpu()

        # Split the Mask R-CNN output into person / non-person entries.
        is_person = maskrcnn_dets["is_person"]
        is_other = is_person.logical_not()
        maskrcnn_person = {k: v[is_person] for k, v in maskrcnn_dets.items()}
        maskrcnn_other = {k: v[is_other] for k, v in maskrcnn_dets.items()}
        maskrcnn_other = VehicleDetection(maskrcnn_other["segmentation"])

        person_segmentation = maskrcnn_person["segmentation"]
        combined_segmentation, cse_dets, matches = combine_cse_maskrcnn_dets(
            person_segmentation, cse_dets, self.mask_cse_iou_combine_threshold)
        persons_with_cse = CSEPersonDetection(
            combined_segmentation, cse_dets, **self.cse_post_process_cfg,
            embed_map=embed_map, orig_imshape_CHW=im.shape
        )
        persons_with_cse.pre_process()

        # Column 0 of `matches` holds the indices of the person masks that
        # were matched; build the set once instead of scanning the tensor
        # for every index.
        matched_idx = set(matches[:, 0].tolist())
        not_matched = [
            i for i in range(person_segmentation.shape[0]) if i not in matched_idx
        ]
        persons_without_cse = PersonDetection(
            person_segmentation[not_matched], **self.cse_post_process_cfg,
            orig_imshape_CHW=im.shape
        )
        persons_without_cse.pre_process()

        # Drop faces that are already enclosed by a dilated person box.
        face_boxes_covered = box1_inside_box2(face_boxes, persons_with_cse.dilated_boxes).logical_or(
            box1_inside_box2(face_boxes, persons_without_cse.dilated_boxes)
        )
        face_boxes = face_boxes[face_boxes_covered.logical_not()]
        face_boxes = FaceDetection(face_boxes, **self.face_post_process_cfg)
        # Order matters. The anonymizer will anonymize FIFO.
        # Later detections will overwrite.
        all_detections = [face_boxes, maskrcnn_other, persons_without_cse, persons_with_cse]
        return all_detections