# JoJo_Style_Transfer / inference / face_detector.py
# Author: Podtekatel
# Last commit 7a92592: "Fix error for no objects by type checking"
import os
from abc import ABC, abstractmethod
from typing import List
import cv2
import numpy as np
from retinaface import RetinaFace
from retinaface.model import retinaface_model
from .box_utils import convert_to_square
class FaceDetector(ABC):
    """Abstract face-detection pipeline.

    Subclasses provide :meth:`detect_crops` (raw boxes) and
    :meth:`postprocess_crops` (box adjustment). :meth:`detect_faces` chains:
    detection -> largest-first sort -> post-processing -> rounding/clamping to
    the image bounds -> cropping -> :meth:`unify_and_merge`.
    """

    def __init__(self, target_size):
        # Not read by the base class; subclasses may use it (e.g. in their
        # unify_and_merge override).
        self.target_size = target_size

    @abstractmethod
    def detect_crops(self, img, *args, **kwargs) -> List[np.ndarray]:
        """Detect faces in ``img``.

        ``img`` is expected to be a uint8 RGB ndarray with values in [0, 255].
        Implementations return boxes as rows of [x1, y1, x2, y2] (an (N, 4)
        ndarray or a sequence of length-4 arrays); an empty result or None
        means no faces were found.
        """
        pass

    @abstractmethod
    def postprocess_crops(self, crops, *args, **kwargs) -> List[np.ndarray]:
        """Adjust the sorted boxes; the default implementation returns them unchanged."""
        return crops

    def sort_faces(self, crops):
        """Stack ``crops`` into an (N, 4) array ordered by box area, largest first.

        Boxes with equal area keep their input order (Python's sort is stable).
        ``crops`` must be non-empty: ``np.stack`` raises on an empty sequence.
        """
        sorted_faces = sorted(crops, key=lambda x: -(x[2] - x[0]) * (x[3] - x[1]))
        return np.stack(sorted_faces, axis=0)

    def fix_range_crops(self, img, crops):
        """Round box coordinates and clamp them into the image.

        x-coordinates are clamped to [0, W] and y-coordinates to [0, H], where
        H and W are the first two dimensions of ``img``.

        Returns:
            An (N, 4) integer ndarray; shape (0, 4) when ``crops`` is empty.
        """
        height, width = img.shape[:2]
        final_crops = []
        for x1, y1, x2, y2 in crops:
            final_crops.append([
                max(min(round(x1), width), 0),
                max(min(round(y1), height), 0),
                max(min(round(x2), width), 0),
                max(min(round(y2), height), 0),
            ])
        # `np.int` was only an alias of the builtin `int` and was removed in
        # NumPy 1.24, so `int` keeps the original dtype. The reshape keeps the
        # (N, 4) layout even when there are no crops.
        return np.array(final_crops, dtype=int).reshape(-1, 4)

    def crop_faces(self, img, crops) -> List[np.ndarray]:
        """Slice each [x1, y1, x2, y2] box out of ``img``; the results are views of ``img``."""
        return [img[y1:y2, x1:x2] for x1, y1, x2, y2 in crops]

    def unify_and_merge(self, cropped_images):
        """Hook for normalising the face crops; the base class returns them unchanged."""
        return cropped_images

    def __call__(self, img):
        """Alias for :meth:`detect_faces`."""
        return self.detect_faces(img)

    def detect_faces(self, img):
        """Run the full detection pipeline on ``img``.

        Returns:
            ``(faces, boxes)``. ``faces`` is the output of
            :meth:`unify_and_merge` for the face crops. ``boxes`` is the
            matching (N, 4) integer array of clamped [x1, y1, x2, y2] boxes,
            largest original detection first. Returns ``([], [])`` when nothing
            is detected.
        """
        crops = self.detect_crops(img)
        if crops is None or len(crops) == 0:
            return [], []
        crops = self.sort_faces(crops)
        updated_crops = self.postprocess_crops(crops)
        updated_crops = self.fix_range_crops(img, updated_crops)
        cropped_faces = self.crop_faces(img, updated_crops)
        unified_faces = self.unify_and_merge(cropped_faces)
        return unified_faces, updated_crops
class StatRetinaFaceDetector(FaceDetector):
    """RetinaFace-based detector that enlarges each box by fixed relative margins.

    Raw RetinaFace boxes are expanded by ``relative_offsets`` (fractions of the
    box width/height added on the left, top, right and bottom). They are then
    squared with ``convert_to_square``. When ``target_size`` is set, face crops
    are resized to ``target_size`` x ``target_size`` and stacked.
    """

    def __init__(self, target_size=None):
        super().__init__(target_size)
        self.model = retinaface_model.build_model()
        # Margins as fractions of the raw box size: (left, top, right, bottom).
        # The commented-out set is an earlier alternative kept for reference.
        # self.relative_offsets = [0.3258, 0.5225, 0.3258, 0.1290]
        self.relative_offsets = [0.3619, 0.5830, 0.3619, 0.1909]

    def postprocess_crops(self, crops, *args, **kwargs) -> np.ndarray:
        """Expand every box by ``relative_offsets`` and make it square.

        Returns:
            An (N, 4) array of boxes, or an empty (0, 4) array when ``crops``
            is empty. Without that check, ``np.stack`` would raise on an empty
            sequence.
        """
        if len(crops) == 0:
            return np.empty((0, 4))
        left, top, right, bottom = self.relative_offsets
        final_crops = []
        for crop in crops:
            x1, y1, x2, y2 = crop
            w, h = x2 - x1, y2 - y1
            # NOTE(review): casting back to crop.dtype truncates the fractional
            # margins when the boxes are integer (as produced by detect_crops).
            # It is kept as-is to preserve the established crop geometry.
            expanded = np.array(
                [x1 - w * left, y1 - h * top, x2 + w * right, y2 + h * bottom],
                dtype=crop.dtype,
            )
            final_crops.append(convert_to_square(expanded))
        return np.stack(final_crops, axis=0)

    def detect_crops(self, img, *args, **kwargs):
        """Run RetinaFace on ``img`` and return its raw boxes.

        Returns:
            An (N, 4) array of [x1, y1, x2, y2] rows taken from each
            detection's ``'facial_area'``, or an empty list when no face is
            found.
        """
        faces = RetinaFace.detect_faces(img, model=self.model)
        # Detections arrive as a dict (name -> detection). A non-dict result
        # (e.g. the tuple this code previously special-cased) is treated as
        # "no faces" rather than crashing on .items().
        if not isinstance(faces, dict):
            return []
        crops = []
        for face in faces.values():
            # Unpacking enforces exactly four coordinates.
            x1, y1, x2, y2 = face['facial_area']
            crops.append(np.array([x1, y1, x2, y2]))
        if not crops:
            return []
        return np.stack(crops, axis=0)

    def unify_and_merge(self, cropped_images):
        """Resize every crop to ``target_size`` x ``target_size`` and stack them.

        The input is returned unchanged when ``target_size`` is None or when
        there are no crops (``np.stack`` cannot stack an empty sequence).
        Otherwise the result is an ndarray whose first axis indexes the crops.
        """
        if self.target_size is None or len(cropped_images) == 0:
            return cropped_images
        size = (self.target_size, self.target_size)
        resized_images = [
            cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
            for image in cropped_images
        ]
        return np.stack(resized_images, axis=0)