# human-detector / gender_age.py
# Last commit: b5f33fd by George ("upl all codes"); file size 2.73 kB.
# (Hosting-page labels carried over with the file: raw | history | blame | No virus)
import cv2
import shutil
import numpy as np
from dataclasses import dataclass
from tqdm import tqdm
from mivolo.predictor import Predictor
from utils import *
import warnings
# NOTE(review): this silences *every* warning emitted from this point on,
# process-wide (any category, any module). Warnings raised while the imports
# above were executing have already been emitted. If only specific warnings
# are noise, consider narrowing with `category=` / `module=` — confirm which
# ones were intended to be hidden.
warnings.filterwarnings("ignore")
@dataclass()
class Cfg:
    """Settings object passed to ``Predictor`` when the detector is built.

    Only the two weight paths are required positionally; every other
    field falls back to the default this script runs with.
    """

    detector_weights: str  # path to the detection model weights
    checkpoint: str  # path to the age/gender model checkpoint
    device: str = "cuda"  # device string handed to the predictor
    # The flags below are consumed by Predictor; their exact effect is
    # defined there, not in this file.
    with_persons: bool = True
    disable_faces: bool = False
    draw: bool = True
class ValidImgDetector:
    """Classify images by the ages and genders of the people detected in them.

    Wraps a ``mivolo.predictor.Predictor``. An image is "valid" when at least
    one female is detected, no male is detected, and no detected object has an
    estimated age below the adult cutoff (18 by default).
    """

    # Detection mode string -> (use_persons, disable_faces) written onto the
    # age/gender model's ``meta`` before each prediction.
    _MODES = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    # Kept for backward compatibility; every instance overwrites it in __init__.
    predictor = None

    def __init__(
        self,
        detector_path: str = "./model/yolov8x_person_face.pt",
        age_gender_path: str = "./model/model_imdb_cross_person_4.22_99.46.pth.tar",
        device: str = "cuda",
    ):
        """Build the predictor from the detector and age/gender weights.

        Args:
            detector_path: Path to the detector weights (``Cfg.detector_weights``).
            age_gender_path: Path to the age/gender checkpoint (``Cfg.checkpoint``).
            device: Device string stored in ``Cfg.device`` (same default as ``Cfg``).
        """
        predictor_cfg = Cfg(detector_path, age_gender_path, device=device)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor,
        adult_age: float = 18,
    ) -> "tuple[bool, bool, bool]":
        """Run detection plus age/gender estimation on an RGB image.

        Args:
            image: RGB image; its last axis is reversed to BGR before it is
                handed to ``predictor.recognize``.
            score_threshold: Detector confidence threshold (``conf`` kwarg).
            iou_threshold: Detector IoU threshold (``iou`` kwarg).
            mode: One of the keys of ``_MODES``.
            predictor: Predictor to configure and run. Note that its detector
                kwargs and model meta are mutated in place.
            adult_age: Any detected age strictly below this counts as a child.

        Returns:
            ``(has_child, has_female, has_male)``.

        Raises:
            ValueError: if ``mode`` is not a known mode. The check runs before
                the predictor is touched.
        """
        try:
            use_persons, disable_faces = self._MODES[mode]
        except KeyError:
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of {list(self._MODES)}"
            ) from None

        predictor.detector.detector_kwargs['conf'] = score_threshold
        predictor.detector.detector_kwargs['iou'] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        bgr_image = image[:, :, ::-1]  # RGB -> BGR for the predictor
        detected_objects, _ = predictor.recognize(bgr_image)

        has_child, has_female, has_male = False, False, False
        if len(detected_objects.ages) > 0:
            # Skip objects left without an age estimate (None), if any;
            # min() over a list containing None would raise TypeError.
            ages = [age for age in detected_objects.ages if age is not None]
            has_child = bool(ages) and min(ages) < adult_age
            has_female = 'female' in detected_objects.genders
            has_male = 'male' in detected_objects.genders
        return has_child, has_female, has_male

    def valid_img(
        self,
        img_path: str,
        *,
        score_threshold: float = 0.4,
        iou_threshold: float = 0.7,
    ) -> bool:
        """Return True if the image shows a female, no male, and no child.

        Args:
            img_path: Path to an image file readable by ``cv2.imread``.
            score_threshold: Detector confidence threshold.
            iou_threshold: Detector IoU threshold.

        Raises:
            ValueError: if OpenCV cannot read the file.
        """
        image = cv2.imread(img_path)  # OpenCV loads BGR; None on failure
        if image is None:
            raise ValueError(f"Could not read image: {img_path}")
        # _detect expects RGB and flips to BGR itself. Passing the BGR array
        # straight through would hand the predictor an RGB image.
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        has_child, has_female, has_male = self._detect(
            rgb_image, score_threshold, iou_threshold,
            "Use persons and faces", self.predictor)
        return has_female and not has_child and not has_male
if __name__ == "__main__":
    # Sort every image under ./images into ./output/valid or ./output/invalid
    # according to ValidImgDetector.valid_img. Files are moved, not copied.
    import os  # explicit; previously only reachable via `from utils import *`

    image_root = "./images"
    valid_dir = "./output/valid"
    invalid_dir = "./output/invalid"
    image_exts = (".jpg", ".jpeg")  # matched case-insensitively

    detector = ValidImgDetector()
    create_dir(valid_dir)
    create_dir(invalid_dir)
    for root, _, files in os.walk(image_root):
        for file in tqdm(files):
            if not file.lower().endswith(image_exts):
                continue
            # os.walk recurses into sub-directories, so the path must be built
            # from `root`. A hard-coded "./images/" prefix points at the wrong
            # file for anything not directly inside ./images.
            src_path = os.path.join(root, file)
            dst_path = valid_dir if detector.valid_img(src_path) else invalid_dir
            # Moving into a directory raises shutil.Error if a file with the
            # same name already exists there, rather than overwriting it.
            shutil.move(src_path, dst_path)