# NOTE: this module was extracted from a running Hugging Face Space page;
# the web-page header ("Spaces: Running") has been converted to this comment.
from functools import lru_cache
from pathlib import Path
from typing import Tuple, List

import numpy as np
import onnxruntime
from PIL import Image
from huggingface_hub import hf_hub_download
def _read_bgr_data(image) -> np.ndarray: | |
""" Read an image in BGR format. | |
Args | |
path: Path to the image. | |
""" | |
data = np.ascontiguousarray(image.convert('RGB')) | |
return data[:, :, ::-1] | |
CAFFE_MASK = np.asarray([103.939, 116.779, 123.68]).astype(np.float32) | |
def _preprocess_image(x) -> np.ndarray: | |
return x.astype(np.float32) - CAFFE_MASK | |
def _get_resize_scale(size, min_side=800, max_side=1333): | |
width, height = size | |
smallest_side = min(width, height) | |
largest_side = max(width, height) | |
scale = min_side / smallest_side | |
if largest_side * scale > max_side: | |
scale = max_side / largest_side | |
return scale | |
def preprocess_image(
    image: Image.Image, min_side=800, max_side=1333,
):
    """Prepare a PIL image for detector inference.

    The image is rescaled (shorter side toward ``min_side``, longer side
    capped at ``max_side``), converted to a BGR array, and mean-subtracted.

    Returns:
        Tuple of (preprocessed float32 array, applied resize scale).
    """
    scale = _get_resize_scale(image.size, min_side=min_side, max_side=max_side)
    target_size = tuple(int(side * scale) for side in image.size)
    bgr = _read_bgr_data(image.resize(target_size))
    return _preprocess_image(bgr), scale
# (repo_id, filename) pairs on the Hugging Face Hub for each detector variant.
FILE_URLS = {
    variant: {
        "checkpoint": ('narugo/gchar_models',
                       f'nudenet/baseline/detector_v2_{variant}_checkpoint.onnx'),
        "classes": ('narugo/gchar_models',
                    f'nudenet/baseline/detector_v2_{variant}_classes'),
    }
    for variant in ("default", "base")
}
def open_model_session(model: str = 'default') -> Tuple[onnxruntime.InferenceSession, List[str]]:
    """Download (via the HF Hub cache) and load a detector ONNX model.

    Args:
        model: Variant key in ``FILE_URLS`` (``'default'`` or ``'base'``).

    Returns:
        Tuple of (ONNX inference session, list of class-name strings).

    Raises:
        KeyError: If ``model`` is not a known variant.
    """
    ckpt_repo_id, ckpt_repo_file = FILE_URLS[model]['checkpoint']
    ckpt_file = hf_hub_download(ckpt_repo_id, ckpt_repo_file)
    classes_repo_id, classes_repo_file = FILE_URLS[model]['classes']
    classes_file = hf_hub_download(classes_repo_id, classes_repo_file)
    onnx_model = onnxruntime.InferenceSession(ckpt_file)
    # Strip each line BEFORE filtering: the previous order let a
    # whitespace-only line pass ``if line`` and then strip to an empty
    # class name. Explicit UTF-8 avoids locale-dependent decoding.
    raw_lines = Path(classes_file).read_text(encoding='utf-8').splitlines()
    classes = [name for name in (line.strip() for line in raw_lines) if name]
    return onnx_model, classes