File size: 1,066 Bytes
2023a9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from functools import lru_cache
from typing import Mapping

from huggingface_hub import hf_hub_download
from imgutils.data import ImageTyping, load_image

from onnx_ import _open_onnx_model
from preprocess import _img_encode

# Class names reported by the classifier. They are paired positionally with the
# model's per-image output scores, so this order must match the model's output order.
_LABELS = [
    '3d',
    'bangumi',
    'comic',
    'illustration',
]

# Names of the available classifier variants. Each name is used as a directory
# prefix when locating that variant's ``model.onnx`` file.
_CLS_MODELS = [
    # CAFormer family
    'caformer_s36', 'caformer_s36_plus',
    # MobileNetV3 family
    'mobilenetv3', 'mobilenetv3_dist', 'mobilenetv3_sce', 'mobilenetv3_sce_dist',
    # MobileViTv2 family
    'mobilevitv2_150',
]

# Variant to use when the caller does not pick one (must be a member of _CLS_MODELS).
_DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'


@lru_cache()
def _open_anime_classify_model(model_name):
    """
    Download and open the ONNX classifier for a given model variant.

    The file ``<model_name>/model.onnx`` is fetched from the Hugging Face repo
    ``deepghs/anime_classification`` and handed to ``_open_onnx_model``.
    Results are memoized per ``model_name`` by ``lru_cache``. Repeated calls
    with the same name therefore return the same opened model object instead
    of re-opening the file.

    :param model_name: Variant name, used as the directory inside the repo
        (presumably one of ``_CLS_MODELS`` — not enforced here).
    :return: Whatever ``_open_onnx_model`` returns for the downloaded file.
    """
    # Plain string literal: the repo id has no interpolated parts.
    return _open_onnx_model(hf_hub_download(
        repo_id='deepghs/anime_classification',
        filename=f'{model_name}/model.onnx',
    ))


def _gr_classification(image: ImageTyping, model_name: str, size: int = 384) -> Mapping[str, float]:
    """
    Score an image against every label in ``_LABELS`` with the chosen model.

    :param image: Any value accepted by ``imgutils.data.load_image``. It is
        loaded in ``'RGB'`` mode.
    :param model_name: Classifier variant to run, passed to
        ``_open_anime_classify_model``.
    :param size: Passed to ``_img_encode`` as ``(size, size)``. Presumably this
        is the square input resolution the model expects — confirm per model.
    :return: Dict mapping each label in ``_LABELS`` to the model's output value
        for it, as a Python float. Whether these are probabilities or raw
        logits depends on the exported model and is not visible here.
    :raises ValueError: If the model returns a different number of scores
        than there are labels. Previously ``zip`` would silently truncate or
        drop values in that case.
    """
    image = load_image(image, mode='RGB')
    # Prepend a batch axis of length 1: the model is run on a single image.
    input_ = _img_encode(image, size=(size, size))[None, ...]
    output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})

    # Scores for the only image in the batch, converted to plain Python floats.
    scores = [value.item() for value in output[0]]
    if len(scores) != len(_LABELS):
        raise ValueError(
            f'Model {model_name!r} returned {len(scores)} scores, '
            f'but {len(_LABELS)} labels are defined: {_LABELS!r}.'
        )

    return dict(zip(_LABELS, scores))