|
import cv2 |
|
import shutil |
|
import numpy as np |
|
from dataclasses import dataclass |
|
from tqdm import tqdm |
|
from mivolo.predictor import Predictor |
|
from utils import * |
|
|
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
@dataclass
class Cfg:
    """Settings object handed to ``Predictor`` when it is constructed.

    Only the two model paths are required; everything else has a default.
    The field names are presumably the attribute names ``Predictor`` reads
    from its config (confirm against mivolo).
    """

    # Path to the person/face detector weights file.
    detector_weights: str
    # Path to the age/gender model checkpoint.
    checkpoint: str
    # Device string for inference; defaults to "cuda".
    device: str = "cuda"
    # Presumably: whether person (body) crops are used by the age/gender model.
    with_persons: bool = True
    # Presumably: whether face crops are ignored by the age/gender model.
    disable_faces: bool = False
    # Presumably: whether the predictor renders an annotated output image.
    draw: bool = True
|
|
|
|
|
class ValidImgDetector:
    """Decide whether an image passes the age/gender filter.

    An image is "valid" when at least one detected subject is predicted
    ``'female'``, none is predicted ``'male'``, and no predicted age is
    below 18.
    """

    # Class-level placeholder; every instance sets its own in ``__init__``.
    predictor = None

    # Accepted ``mode`` strings -> (use_persons, disable_faces) flags written
    # onto the age/gender model's ``meta`` before each prediction.
    _MODES = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    def __init__(
        self,
        detector_path: str = "./model/yolov8x_person_face.pt",
        age_gender_path: str = "./model/model_imdb_cross_person_4.22_99.46.pth.tar",
        device: str = "cuda",
    ):
        """Build the underlying ``Predictor``.

        Args:
            detector_path: Path to the person/face detector weights.
            age_gender_path: Path to the age/gender model checkpoint.
            device: Device string stored in ``Cfg.device`` (default "cuda").
        """
        predictor_cfg = Cfg(detector_path, age_gender_path, device=device)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: "Predictor",
    ) -> "tuple[bool, bool, bool]":
        """Configure *predictor*, run it on *image* and summarise the result.

        Args:
            image: Image to analyse (as returned by ``cv2.imread`` in
                ``valid_img``).
            score_threshold: Value written to the detector's ``conf`` kwarg.
            iou_threshold: Value written to the detector's ``iou`` kwarg.
            mode: One of the keys of ``_MODES``.
            predictor: Predictor to run. Its ``detector.detector_kwargs`` and
                ``age_gender_model.meta`` are mutated in place.

        Returns:
            ``(has_child, has_female, has_male)``.

        Raises:
            ValueError: If *mode* is not a recognised mode string.
        """
        # Validate before touching the predictor so a bad mode leaves it
        # unchanged (the old if/elif chain left the flags unbound instead).
        try:
            use_persons, disable_faces = self._MODES[mode]
        except KeyError:
            raise ValueError(
                f"unknown mode {mode!r}; expected one of {sorted(self._MODES)}"
            ) from None

        predictor.detector.detector_kwargs['conf'] = score_threshold
        predictor.detector.detector_kwargs['iou'] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        # The second return value (a rendered image) is not needed here.
        detected_objects, _ = predictor.recognize(image)

        has_child = has_female = has_male = False
        if len(detected_objects.ages) > 0:
            # NOTE(review): entries may presumably be None for objects the
            # model did not score; skip them so the age comparison cannot raise
            # TypeError.
            ages = [age for age in detected_objects.ages if age is not None]
            has_child = any(age < 18 for age in ages)
            has_female = 'female' in detected_objects.genders
            has_male = 'male' in detected_objects.genders

        return has_child, has_female, has_male

    def valid_img(
        self,
        img_path,
        score_threshold: float = 0.4,
        iou_threshold: float = 0.7,
        mode: str = "Use persons and faces",
    ):
        """Return True if the image at *img_path* passes the filter.

        Images ``cv2.imread`` cannot load are reported as invalid (False).

        Args:
            img_path: Path of the image file.
            score_threshold: Detector confidence threshold.
            iou_threshold: Detector IoU threshold.
            mode: Detection mode, see ``_MODES``.
        """
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals an unreadable/unsupported file by returning
            # None rather than raising; such an image cannot pass the filter.
            return False

        has_child, has_female, has_male = self._detect(
            image, score_threshold, iou_threshold, mode, self.predictor)
        return has_female and not has_child and not has_male
|
|
|
|
|
def filter_img(input_dir="./images", output_dir="./output", extensions=(".jpg",)):
    """Sort images into ``<output_dir>/valid`` and ``<output_dir>/invalid``.

    Every file under *input_dir* (searched recursively) whose name ends with
    one of *extensions* is classified with ``ValidImgDetector.valid_img`` and
    moved into the matching destination folder.

    Args:
        input_dir: Directory to scan (recursively) for images.
        output_dir: Parent directory of the ``valid``/``invalid`` folders.
        extensions: Suffix or tuple of suffixes to process; matched
            case-insensitively, so ``.JPG`` is handled like ``.jpg``.

    Raises:
        shutil.Error: If a file with the same name already exists in the
            destination folder.
    """
    import os  # explicit instead of relying on ``from utils import *``

    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    detector = ValidImgDetector()
    valid_dir = os.path.join(output_dir, "valid")
    invalid_dir = os.path.join(output_dir, "invalid")
    create_dir(valid_dir)
    create_dir(invalid_dir)

    # Collect full paths first: os.walk recurses, so each file must be joined
    # with its own ``root`` (not assumed to sit directly in input_dir). This
    # also gives a single progress bar over all files, and the moves below
    # cannot disturb an in-progress walk.
    candidates = []
    for root, _, names in os.walk(input_dir):
        candidates.extend(
            os.path.join(root, name)
            for name in names
            if name.lower().endswith(suffixes)
        )

    for src_path in tqdm(candidates):
        dst_dir = valid_dir if detector.valid_img(src_path) else invalid_dir
        # Moving into a directory keeps the original file name.
        shutil.move(src_path, dst_dir)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: sort ./images into ./output/valid and ./output/invalid.
    filter_img()
|
|