|
from functools import lru_cache |
|
from typing import Mapping |
|
|
|
from huggingface_hub import hf_hub_download |
|
from imgutils.data import ImageTyping, load_image |
|
|
|
from onnx_ import _open_onnx_model |
|
from preprocess import _img_encode |
|
|
|
# Class names, positionally aligned with the scores in the model's output row
# (the classification code pairs the two up with ``zip``).
_LABELS = ['3d', 'bangumi', 'comic', 'illustration']

# Known model names; each one names a ``<name>/model.onnx`` file inside the
# ``deepghs/anime_classification`` Hugging Face repository.
_CLS_MODELS = [
    'caformer_s36', 'caformer_s36_plus',
    'mobilenetv3', 'mobilenetv3_dist',
    'mobilenetv3_sce', 'mobilenetv3_sce_dist',
    'mobilevitv2_150',
]

# Default model name (a member of ``_CLS_MODELS``).
_DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'
|
|
|
|
|
@lru_cache()
def _open_anime_classify_model(model_name):
    """Fetch and open the ONNX classifier stored under ``model_name``.

    The file ``{model_name}/model.onnx`` is obtained from the
    ``deepghs/anime_classification`` Hugging Face repository via
    ``hf_hub_download`` and the resulting local path is passed to
    ``_open_onnx_model``. Results are memoised per ``model_name`` by
    ``lru_cache`` (default ``maxsize=128``), so repeated calls with the same
    name reuse the already-opened model.

    :param model_name: Name of a model subdirectory in the repository, e.g.
        one of ``_CLS_MODELS``.
    :return: Whatever ``_open_onnx_model`` returns. NOTE(review): the
        classification code calls ``.run(...)`` on it, so presumably an ONNX
        Runtime inference session — confirm.
    """
    return _open_onnx_model(hf_hub_download(
        'deepghs/anime_classification',
        f'{model_name}/model.onnx',
    ))
|
|
|
|
|
def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
    """Classify an image into the categories listed in ``_LABELS``.

    :param image: Any image accepted by ``imgutils.data.load_image``; it is
        converted to RGB before encoding.
    :param model_name: Name of the model to run, passed to
        ``_open_anime_classify_model`` (see ``_CLS_MODELS``).
    :param size: Side length of the square ``(size, size)`` input passed to
        ``_img_encode``.
    :return: Mapping from each label in ``_LABELS`` to the model's score for
        it, as a plain Python float. NOTE(review): presumably probabilities
        from a softmax head — confirm against the exported model.
    :raises ValueError: If the model returns a different number of scores
        than there are labels (``zip`` would otherwise silently drop or
        mislabel classes).
    """
    image = load_image(image, mode='RGB')
    # Prepend a length-1 batch axis: the model is run on a single image.
    input_ = _img_encode(image, size=(size, size))[None, ...]
    output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})

    # Row 0 holds the scores for our single-image batch; ``tolist`` converts
    # every element to a plain Python float in one call.
    scores = output[0].tolist()
    if len(scores) != len(_LABELS):
        raise ValueError(
            f'Model {model_name!r} returned {len(scores)} scores, '
            f'expected {len(_LABELS)} ({", ".join(_LABELS)}).'
        )
    return dict(zip(_LABELS, scores))
|
|