auto_image_censor / nudenet.py
narugo1992
dev(narugo): remove usage of cv2
08d41dd
from functools import lru_cache
from pathlib import Path
from typing import Tuple, List
import numpy as np
import onnxruntime
from PIL import Image
from huggingface_hub import hf_hub_download
def _read_bgr_data(image) -> np.ndarray:
""" Read an image in BGR format.
Args
path: Path to the image.
"""
data = np.ascontiguousarray(image.convert('RGB'))
return data[:, :, ::-1]
# Per-channel offsets subtracted from BGR pixel data during preprocessing; the
# first value applies to the B channel. The numbers are the widely used Caffe
# ImageNet BGR means.
CAFFE_MASK = np.array([103.939, 116.779, 123.68], dtype=np.float32)
def _preprocess_image(x) -> np.ndarray:
    """Cast ``x`` to float32 and subtract the per-channel ``CAFFE_MASK`` offsets.

    ``x`` must broadcast against the length-3 ``CAFFE_MASK`` (i.e. its last axis
    holds the three BGR channels). A new array is returned; ``x`` is untouched.
    """
    as_float = x.astype(np.float32)
    return as_float - CAFFE_MASK
def _get_resize_scale(size, min_side=800, max_side=1333):
width, height = size
smallest_side = min(width, height)
largest_side = max(width, height)
scale = min_side / smallest_side
if largest_side * scale > max_side:
scale = max_side / largest_side
return scale
def preprocess_image(
    image: Image.Image, min_side=800, max_side=1333,
) -> Tuple[np.ndarray, float]:
    """Resize ``image`` for the detector and convert it to a mean-subtracted BGR array.

    The image is scaled uniformly so its shorter side becomes ``min_side``
    pixels, unless that would push the longer side past ``max_side``, in which
    case the longer side is capped at ``max_side`` instead.

    Args:
        image: Source PIL image.
        min_side: Target length of the shorter side, in pixels.
        max_side: Upper bound on the longer side, in pixels.

    Returns:
        ``(data, scale)``: ``data`` is the resized image as a float32 BGR array
        with ``CAFFE_MASK`` subtracted; ``scale`` is the resize factor that was
        applied (presumably used by callers to map detections back to the
        original image — confirm against callers).
    """
    width, height = image.size
    scale = _get_resize_scale((width, height), min_side=min_side, max_side=max_side)
    # Truncate to whole pixels as before, but never below 1: for extreme aspect
    # ratios (e.g. 10000x1) the short side would otherwise become 0, which
    # either fails in PIL or yields an empty array the detector cannot use.
    new_width, new_height = (max(1, int(side * scale)) for side in (width, height))
    data = _read_bgr_data(image.resize((new_width, new_height)))
    data = _preprocess_image(data)
    return data, scale
# Hugging Face Hub locations of the NudeNet detector assets, keyed by model
# name. Each asset kind maps to a ``(repo_id, filename)`` pair that
# ``open_model_session`` passes to ``hf_hub_download``.
FILE_URLS = dict(
    default=dict(
        checkpoint=('narugo/gchar_models', 'nudenet/baseline/detector_v2_default_checkpoint.onnx'),
        classes=('narugo/gchar_models', 'nudenet/baseline/detector_v2_default_classes'),
    ),
    base=dict(
        checkpoint=('narugo/gchar_models', 'nudenet/baseline/detector_v2_base_checkpoint.onnx'),
        classes=('narugo/gchar_models', 'nudenet/baseline/detector_v2_base_classes'),
    ),
)
@lru_cache()
def open_model_session(model: str = 'default') -> Tuple[onnxruntime.InferenceSession, List[str]]:
    """Fetch a NudeNet detector's assets and load it with its class names.

    The checkpoint and the classes file are obtained with ``hf_hub_download``
    from the locations listed in ``FILE_URLS[model]``. Results are memoized per
    ``model`` by ``lru_cache``, so repeated calls return the same session and
    the same list object; callers must not mutate the returned list.

    Args:
        model: Key into ``FILE_URLS`` (``'default'`` or ``'base'``).

    Returns:
        ``(session, classes)``: the ONNX Runtime session for the checkpoint and
        the non-blank, whitespace-stripped lines of the classes file, in order.

    Raises:
        KeyError: If ``model`` is not a key of ``FILE_URLS``.
    """
    assets = FILE_URLS[model]

    ckpt_repo_id, ckpt_repo_file = assets['checkpoint']
    ckpt_file = hf_hub_download(ckpt_repo_id, ckpt_repo_file)

    classes_repo_id, classes_repo_file = assets['classes']
    classes_file = hf_hub_download(classes_repo_id, classes_repo_file)

    onnx_model = onnxruntime.InferenceSession(ckpt_file)

    # Strip BEFORE filtering: the old `if line` test ran on the raw line, so a
    # whitespace-only line (e.g. a trailing "  ") became an empty class name.
    # Decode as UTF-8 explicitly rather than relying on the locale encoding.
    text = Path(classes_file).read_text(encoding='utf-8')
    stripped = (line.strip() for line in text.splitlines())
    classes = [name for name in stripped if name]
    return onnx_model, classes