from functools import lru_cache
from pathlib import Path
from typing import Tuple, List

import numpy as np
import onnxruntime
from PIL import Image
from huggingface_hub import hf_hub_download


def _read_bgr_data(image) -> np.ndarray:
    """
    Convert a PIL image into a BGR ``numpy`` array.

    Args:
        image: A ``PIL.Image.Image``. It is converted to RGB first, so
            non-RGB modes are normalized before the channel swap.

    Returns:
        A ``(height, width, 3)`` array whose channel axis is reversed
        (RGB -> BGR). The result is a reversed-stride view of the
        contiguous RGB buffer, not a fresh copy.
    """
    data = np.ascontiguousarray(image.convert('RGB'))
    return data[:, :, ::-1]


# Per-channel values subtracted from BGR pixel data, in BGR order.
# NOTE(review): these match the usual Caffe-style ImageNet BGR means; the
# name "MASK" is historical and kept for backward compatibility.
CAFFE_MASK = np.asarray([103.939, 116.779, 123.68]).astype(np.float32)


def _preprocess_image(x) -> np.ndarray:
    """
    Cast ``x`` to float32 and subtract ``CAFFE_MASK`` from each channel.

    The subtraction broadcasts over the last axis, so ``x`` is expected to
    be channel-last with 3 channels in BGR order (as produced by
    ``_read_bgr_data``).
    """
    return x.astype(np.float32) - CAFFE_MASK


def _get_resize_scale(size, min_side=800, max_side=1333):
    """
    Compute the scale factor applied to an image before detection.

    The shorter side is scaled to ``min_side`` unless that would push the
    longer side beyond ``max_side``. In that case the longer side is scaled
    to ``max_side`` instead.

    Args:
        size: ``(width, height)`` of the original image.
        min_side: Target length for the shorter side.
        max_side: Upper bound for the longer side after scaling.

    Returns:
        The multiplicative scale factor (a float).

    Raises:
        ValueError: If either side of ``size`` is not positive.
    """
    width, height = size
    smallest_side = min(width, height)
    largest_side = max(width, height)
    if smallest_side <= 0:
        # Without this guard a zero side fails with an opaque ZeroDivisionError.
        raise ValueError(f'Image size must be positive, but {size!r} found.')

    scale = min_side / smallest_side
    if largest_side * scale > max_side:
        scale = max_side / largest_side
    return scale


def preprocess_image(
        image: Image.Image,
        min_side=800,
        max_side=1333,
) -> Tuple[np.ndarray, float]:
    """
    Resize ``image`` and turn it into a float32 BGR array for the detector.

    Args:
        image: Source PIL image.
        min_side: Target length of the shorter side (see ``_get_resize_scale``).
        max_side: Upper bound of the longer side (see ``_get_resize_scale``).

    Returns:
        ``(data, scale)``:

        - ``data`` is the resized, mean-subtracted BGR float32 array.
        - ``scale`` is the factor applied to the original size, so callers
          can map detections back to original coordinates.
    """
    width, height = image.size
    scale = _get_resize_scale((width, height), min_side=min_side, max_side=max_side)
    # int() truncates, so very elongated images (e.g. 1x10000) would get a
    # zero-length side. Clamp each side to at least one pixel.
    new_width = max(1, int(width * scale))
    new_height = max(1, int(height * scale))
    data = _read_bgr_data(image.resize((new_width, new_height)))
    data = _preprocess_image(data)
    return data, scale


# Model name -> (Hugging Face repo id, file path in repo) for the ONNX
# checkpoint and the newline-separated class-name list.
FILE_URLS = {
    "default": {
        "checkpoint": ('narugo/gchar_models', 'nudenet/baseline/detector_v2_default_checkpoint.onnx'),
        "classes": ('narugo/gchar_models', 'nudenet/baseline/detector_v2_default_classes'),
    },
    "base": {
        "checkpoint": ('narugo/gchar_models', 'nudenet/baseline/detector_v2_base_checkpoint.onnx'),
        "classes": ('narugo/gchar_models', 'nudenet/baseline/detector_v2_base_classes'),
    },
}


@lru_cache()
def open_model_session(model: str = 'default') -> Tuple[onnxruntime.InferenceSession, List[str]]:
    """
    Download a detector model and open an ONNX Runtime session for it.

    The checkpoint and class-name files are fetched with
    ``hf_hub_download``. Results are memoized with ``lru_cache``, so
    repeated calls with the same ``model`` return the same objects.

    Args:
        model: A key of ``FILE_URLS`` (``'default'`` or ``'base'``).

    Returns:
        ``(session, classes)``:

        - ``session`` is an ``onnxruntime.InferenceSession`` on the
          checkpoint.
        - ``classes`` is the list of non-blank, whitespace-stripped lines
          of the classes file.

    Raises:
        KeyError: If ``model`` is not a known model name.
    """
    if model not in FILE_URLS:
        raise KeyError(f'Unknown model {model!r}, available models: {sorted(FILE_URLS)!r}.')

    ckpt_repo_id, ckpt_repo_file = FILE_URLS[model]['checkpoint']
    ckpt_file = hf_hub_download(ckpt_repo_id, ckpt_repo_file)
    classes_repo_id, classes_repo_file = FILE_URLS[model]['classes']
    classes_file = hf_hub_download(classes_repo_id, classes_repo_file)

    onnx_model = onnxruntime.InferenceSession(ckpt_file)
    # Decode explicitly as UTF-8 rather than relying on the platform locale.
    lines = Path(classes_file).read_text(encoding='utf-8').splitlines()
    # Filter after stripping so whitespace-only lines don't become '' classes.
    classes = [name for name in (line.strip() for line in lines) if name]
    return onnx_model, classes