"""Sort images into a "valid" and an "invalid" folder using MiVOLO.

An image counts as valid when the MiVOLO age/gender predictor finds at least
one female, no male, and nobody whose estimated age is below 18.
"""
import os
import shutil
import warnings
from dataclasses import dataclass
from typing import Tuple

import cv2
import numpy as np
from tqdm import tqdm

from mivolo.predictor import Predictor
from utils import *

warnings.filterwarnings("ignore")


@dataclass
class Cfg:
    """Configuration object handed to ``mivolo.predictor.Predictor``."""

    detector_weights: str  # path to the YOLOv8 person/face detector weights
    checkpoint: str  # path to the MiVOLO age/gender checkpoint
    device: str = "cuda"
    with_persons: bool = True
    disable_faces: bool = False
    draw: bool = True


class ValidImgDetector:
    """Decide whether an image passes the age/gender filter."""

    # Detection mode name -> (use_persons, disable_faces) written into the
    # age/gender model's meta before each prediction.
    _MODES = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    # Kept as a class attribute for backward compatibility; __init__ always
    # replaces it with an instance-level Predictor.
    predictor = None

    def __init__(
        self,
        detector_path: str = "./model/yolov8x_person_face.pt",
        age_gender_path: str = "./model/model_imdb_cross_person_4.22_99.46.pth.tar",
        device: str = "cuda",
    ):
        """Load the detector and age/gender model.

        Args:
            detector_path: path to the person/face detector weights.
            age_gender_path: path to the MiVOLO age/gender checkpoint.
            device: device string forwarded to ``Cfg`` (default ``"cuda"``).
        """
        predictor_cfg = Cfg(detector_path, age_gender_path, device=device)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor,
    ) -> Tuple[bool, bool, bool]:
        """Run detection plus age/gender estimation on an RGB image.

        Note: this mutates ``predictor.detector.detector_kwargs`` and
        ``predictor.age_gender_model.meta`` in place.

        Args:
            image: RGB image; it is flipped to BGR before being passed to
                ``predictor.recognize``.
            score_threshold: detector confidence threshold (``conf``).
            iou_threshold: detector NMS IoU threshold (``iou``).
            mode: one of the keys of ``_MODES``.
            predictor: the MiVOLO predictor to run.

        Returns:
            ``(has_child, has_female, has_male)``.

        Raises:
            ValueError: if ``mode`` is not a supported mode name.
        """
        try:
            use_persons, disable_faces = self._MODES[mode]
        except KeyError:
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of {list(self._MODES)}"
            ) from None

        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        bgr_image = image[:, :, ::-1]  # RGB -> BGR
        detected_objects, _ = predictor.recognize(bgr_image)

        has_child, has_female, has_male = False, False, False
        if len(detected_objects.ages) > 0:
            # Entries may be None for detections without an age estimate;
            # min() would raise TypeError on them.
            ages = [age for age in detected_objects.ages if age is not None]
            has_child = bool(ages) and min(ages) < 18
            has_female = "female" in detected_objects.genders
            has_male = "male" in detected_objects.genders
        return has_child, has_female, has_male

    def valid_img(
        self,
        img_path: str,
        score_threshold: float = 0.4,
        iou_threshold: float = 0.7,
    ) -> bool:
        """Return True if the image shows a female, no male and no child.

        Unreadable or missing images are reported as not valid.
        """
        bgr_image = cv2.imread(img_path)
        if bgr_image is None:
            # cv2.imread signals failure by returning None, not by raising.
            return False
        # _detect expects RGB and converts back to BGR for the predictor.
        rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        has_child, has_female, has_male = self._detect(
            rgb_image,
            score_threshold,
            iou_threshold,
            "Use persons and faces",
            self.predictor,
        )
        return (not has_child) and has_female and (not has_male)


def main(images_dir: str = "./images", output_dir: str = "./output") -> None:
    """Move every JPEG under ``images_dir`` into ``<output_dir>/valid`` or ``/invalid``."""
    detector = ValidImgDetector()
    valid_dir = os.path.join(output_dir, "valid")
    invalid_dir = os.path.join(output_dir, "invalid")
    create_dir(valid_dir)
    create_dir(invalid_dir)

    for root, _, files in os.walk(images_dir):
        for file in tqdm(files):
            if not file.lower().endswith((".jpg", ".jpeg")):
                continue
            # Build the path from the directory os.walk is visiting so files
            # inside subdirectories are found too.
            src_path = os.path.join(root, file)
            dst_path = valid_dir if detector.valid_img(src_path) else invalid_dir
            shutil.move(src_path, dst_path)


if __name__ == "__main__":
    main()