File size: 2,285 Bytes
d3adc66
349bed6
 
 
 
 
 
 
 
 
08d41dd
349bed6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08d41dd
 
 
 
349bed6
 
 
 
 
 
 
 
 
 
 
08d41dd
 
 
 
 
 
349bed6
 
 
 
d3adc66
 
349bed6
 
d3adc66
 
349bed6
 
 
 
 
 
d3adc66
 
 
 
349bed6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from functools import lru_cache
from pathlib import Path
from typing import Tuple, List

import numpy as np
import onnxruntime
from PIL import Image
from huggingface_hub import hf_hub_download


def _read_bgr_data(image) -> np.ndarray:
    """ Return the pixels of an image with the channels in BGR order.
    Args
        image: Object exposing a PIL-style ``convert('RGB')`` method.
    Returns
        The converted pixel array with its last (channel) axis reversed.
        This is a reversed view, not a fresh copy.
    """
    rgb_image = image.convert('RGB')
    rgb_pixels = np.ascontiguousarray(rgb_image)
    # Reversing the channel axis turns R, G, B into B, G, R.
    return rgb_pixels[:, :, ::-1]


# Per-channel offsets that ``_preprocess_image`` subtracts from every pixel.
# NOTE(review): these look like the usual Caffe-style ImageNet channel means in
# B, G, R order (matching the output of ``_read_bgr_data``) -- confirm against
# the preprocessing the model was trained with.
CAFFE_MASK = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def _preprocess_image(x) -> np.ndarray:
    """ Convert pixel data to float32 and subtract ``CAFFE_MASK`` from it.
    Args
        x: Array of pixel values that broadcasts against ``CAFFE_MASK``.
    Returns
        A new float32 array; ``x`` itself is not modified.
    """
    as_float = x.astype(np.float32)
    return np.subtract(as_float, CAFFE_MASK)


def _get_resize_scale(size, min_side=800, max_side=1333):
    """ Compute the factor by which an image should be resized.
    Args
        size: ``(width, height)`` of the image.
        min_side: Target length of the shorter side after resizing.
        max_side: Upper bound for the longer side after resizing.
    Returns
        The factor that brings the shorter side to ``min_side``, unless that
        would push the longer side past ``max_side``; in that case the factor
        that brings the longer side to exactly ``max_side``.
    """
    w, h = size
    short_edge, long_edge = min(w, h), max(w, h)

    ratio = min_side / short_edge
    # Fall back to fitting the longer side when the preferred ratio overshoots.
    return max_side / long_edge if long_edge * ratio > max_side else ratio


def preprocess_image(
        image: Image.Image, min_side=800, max_side=1333,
) -> Tuple[np.ndarray, float]:
    """ Resize an image and convert it into the detector's input array.
    Args
        image: PIL image to prepare.
        min_side: Target length of the shorter side, forwarded to ``_get_resize_scale``.
        max_side: Upper bound for the longer side, forwarded to ``_get_resize_scale``.
    Returns
        A ``(data, scale)`` tuple: ``data`` is the resized image as a float32
        array in BGR channel order with ``CAFFE_MASK`` subtracted, and ``scale``
        is the resize factor computed by ``_get_resize_scale``.
    Raises
        ZeroDivisionError: If the image has a side of length 0.
    """
    width, height = image.size
    scale = _get_resize_scale((width, height), min_side=min_side, max_side=max_side)
    # int() truncates, so the short side of a very elongated image (e.g. 1x5000)
    # could become 0, which PIL refuses to resize to. Keep every side >= 1 px.
    new_width, new_height = (max(1, int(side * scale)) for side in (width, height))
    data = _read_bgr_data(image.resize((new_width, new_height)))
    data = _preprocess_image(data)
    return data, scale


# Hugging Face Hub locations of the detector files, keyed by model variant.
# Every artefact is stored as a ``(repo_id, filename)`` pair, which is the form
# ``open_model_session`` unpacks and hands to ``hf_hub_download``.
FILE_URLS = {
    model_name: {
        "checkpoint": ('narugo/gchar_models', checkpoint_path),
        "classes": ('narugo/gchar_models', classes_path),
    }
    for model_name, checkpoint_path, classes_path in (
        (
            "default",
            'nudenet/baseline/detector_v2_default_checkpoint.onnx',
            'nudenet/baseline/detector_v2_default_classes',
        ),
        (
            "base",
            'nudenet/baseline/detector_v2_base_checkpoint.onnx',
            'nudenet/baseline/detector_v2_base_classes',
        ),
    )
}


@lru_cache()
def open_model_session(model: str = 'default') -> Tuple[onnxruntime.InferenceSession, List[str]]:
    """ Open the ONNX detector for a model variant together with its class names.
    Both files are fetched with ``hf_hub_download`` from the locations listed in
    ``FILE_URLS``. Results are memoised by ``lru_cache``, so repeated calls with
    the same ``model`` return the very same session and list objects; treat the
    returned list as read-only.
    Args
        model: Key of ``FILE_URLS`` selecting the variant to load.
    Returns
        A ``(session, classes)`` tuple: the ``onnxruntime.InferenceSession`` for
        the checkpoint and the class names, one per non-blank line of the
        classes file, in file order.
    Raises
        KeyError: If ``model`` is not a key of ``FILE_URLS``.
    """
    # Each FILE_URLS artefact is a (repo_id, filename) pair.
    ckpt_repo_id, ckpt_repo_file = FILE_URLS[model]['checkpoint']
    ckpt_file = hf_hub_download(ckpt_repo_id, ckpt_repo_file)
    classes_repo_id, classes_repo_file = FILE_URLS[model]['classes']
    classes_file = hf_hub_download(classes_repo_id, classes_repo_file)

    onnx_model = onnxruntime.InferenceSession(ckpt_file)
    # Decode explicitly so the result does not depend on the platform's default
    # encoding, and strip *before* filtering so whitespace-only lines are
    # dropped instead of turning into empty class names.
    class_lines = Path(classes_file).read_text(encoding='utf-8').splitlines()
    classes = [name for name in map(str.strip, class_lines) if name]
    return onnx_model, classes