import functools
import imghdr
import os
import shutil
import warnings
from dataclasses import dataclass
from typing import Tuple

import cv2
import gradio as gr
import numpy as np
from mivolo.predictor import Predictor

from utils import is_url, download_file, get_jpg_files, MODEL_DIR

# Scratch directory for images downloaded in "Online Mode".
# NOTE(review): this is ./__pycache__, and infer() deletes it wholesale before
# every download -- confirm nothing else relies on its contents.
TMP_DIR = "./__pycache__"

# User-facing message returned (in the last output slot) for unusable input.
_INVALID_INPUT_MSG = "请正确输入图片 Please input image correctly"

# Detection mode label -> (use_persons, disable_faces). These flags are written
# onto predictor.age_gender_model.meta in ValidImgDetector._detect.
_MODE_FLAGS = {
    "Use persons and faces": (True, False),
    "Use persons only": (True, True),
    "Use faces only": (False, False),
}


@dataclass
class Cfg:
    """Configuration object passed to mivolo's ``Predictor``.

    Only the two weight paths are required. The remaining fields are consumed
    by ``Predictor``; see mivolo for their exact semantics.
    """

    detector_weights: str  # path to the person/face detector weights
    checkpoint: str  # path to the age/gender model checkpoint
    device: str = "cpu"
    with_persons: bool = True
    disable_faces: bool = False
    draw: bool = True  # presumably asks Predictor to return an annotated image


class _UndecodableImageError(ValueError):
    """Raised when OpenCV cannot decode an image file."""


class ValidImgDetector:
    """Report whether an image contains a child, a female and/or a male.

    Wraps a mivolo ``Predictor`` built from the weights under ``MODEL_DIR``.
    """

    predictor = None

    def __init__(self):
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor,
    ) -> Tuple[np.ndarray, bool, bool, bool]:
        """Run detection plus age/gender recognition on ``image``.

        Args:
            image: Image array. ``valid_img`` passes ``cv2.imread`` output
                (BGR); this assumes ``Predictor.recognize`` expects BGR --
                TODO confirm.
            score_threshold: Detector confidence threshold (``conf``).
            iou_threshold: Detector NMS IoU threshold (``iou``).
            mode: One of the keys of ``_MODE_FLAGS``.
            predictor: The predictor to configure and run.

        Returns:
            ``(annotated_image, has_child, has_female, has_male)``. The
            annotated image's channel order is reversed, i.e. BGR -> RGB for
            display, under the assumption above.

        Raises:
            ValueError: If ``mode`` is not a known mode label.
        """
        try:
            use_persons, disable_faces = _MODE_FLAGS[mode]
        except KeyError:
            raise ValueError(f"Unknown detection mode: {mode!r}") from None

        # The predictor is shared across calls, so these settings are
        # re-applied on every call rather than assumed from construction.
        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        detected_objects, out_im = predictor.recognize(image)

        has_child, has_female, has_male = False, False, False
        if len(detected_objects.ages) > 0:
            # NOTE(review): guard against objects left without an age estimate
            # (None entries), which would make min() raise -- confirm whether
            # mivolo can actually produce them.
            known_ages = [age for age in detected_objects.ages if age is not None]
            has_child = bool(known_ages) and min(known_ages) < 18
            has_female = "female" in detected_objects.genders
            has_male = "male" in detected_objects.genders

        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(self, img_path):
        """Load ``img_path`` with OpenCV and run detection with default settings.

        Raises:
            _UndecodableImageError: If OpenCV cannot decode the file.
        """
        image = cv2.imread(img_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise _UndecodableImageError(f"Could not decode image: {img_path}")
        return self._detect(image, 0.4, 0.7, "Use persons and faces", self.predictor)


@functools.lru_cache(maxsize=1)
def _get_detector() -> ValidImgDetector:
    """Build the detector once and reuse it across requests.

    Construction builds a Predictor from both weight files, so it should not
    be repeated on every call.
    """
    return ValidImgDetector()


def infer(photo: str):
    """Gradio handler: detect children / females / males in an image.

    Args:
        photo: A local image file path (Upload Mode) or an image URL
            (Online Mode). May be empty/None when nothing was provided.

    Returns:
        ``(annotated_rgb_image, has_child, has_female, has_male)`` on success,
        or ``(None, None, None, message)`` when the input is not a usable image.
    """
    if not photo:
        return None, None, None, _INVALID_INPUT_MSG

    if is_url(photo):
        if os.path.exists(TMP_DIR):
            shutil.rmtree(TMP_DIR)
        os.makedirs(TMP_DIR, exist_ok=True)
        photo = download_file(photo, f"{TMP_DIR}/download.jpg")

    # isfile (not exists) so a directory path doesn't crash imghdr.what().
    if not photo or not os.path.isfile(photo) or imghdr.what(photo) is None:
        return None, None, None, _INVALID_INPUT_MSG

    try:
        return _get_detector().valid_img(photo)
    except _UndecodableImageError:
        return None, None, None, _INVALID_INPUT_MSG


if __name__ == "__main__":
    with gr.Blocks() as iface:
        warnings.filterwarnings("ignore")
        with gr.Tab("上传模式 Upload Mode"):
            gr.Interface(
                fn=infer,
                inputs=gr.Image(label="上传照片 Upload Photo", type="filepath"),
                outputs=[
                    gr.Image(label="检测结果 Detection Result", type="numpy"),
                    gr.Textbox(label="存在儿童 Has Child"),
                    gr.Textbox(label="存在女性 Has Female"),
                    gr.Textbox(label="存在男性 Has Male"),
                ],
                examples=get_jpg_files(f"{MODEL_DIR}/examples"),
                allow_flagging="never",
                cache_examples=False,
            )

        with gr.Tab("在线模式 Online Mode"):
            gr.Interface(
                fn=infer,
                inputs=gr.Textbox(label="网络图片链接 Online Picture URL"),
                outputs=[
                    gr.Image(label="检测结果 Detection Result", type="numpy"),
                    gr.Textbox(label="存在儿童 Has Child"),
                    gr.Textbox(label="存在女性 Has Female"),
                    gr.Textbox(label="存在男性 Has Male"),
                ],
                allow_flagging="never",
            )

    iface.launch()