# Source provenance (scraped from the hosting page): author narugo1992,
# commit 2023a9f "dev(narugo): init commit", file size 1.07 kB.
from functools import lru_cache
from typing import Mapping
from huggingface_hub import hf_hub_download
from imgutils.data import ImageTyping, load_image
from onnx_ import _open_onnx_model
from preprocess import _img_encode
_LABELS = ['3d', 'bangumi', 'comic', 'illustration']
_CLS_MODELS = [
'caformer_s36',
'caformer_s36_plus',
'mobilenetv3',
'mobilenetv3_dist',
'mobilenetv3_sce',
'mobilenetv3_sce_dist',
'mobilevitv2_150',
]
_DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'
@lru_cache()
def _open_anime_classify_model(model_name: str):
    """Fetch and open the ONNX classifier for *model_name*.

    Downloads ``{model_name}/model.onnx`` from the
    ``deepghs/anime_classification`` Hugging Face repository with
    ``hf_hub_download`` and passes the local file path to ``_open_onnx_model``.
    The result is memoised by ``lru_cache``, so repeated calls with the same
    name return the cached object instead of re-opening the model.

    :param model_name: Model variant; used as the sub-directory in the repo.
    :return: Whatever ``_open_onnx_model`` returns (callers use its ``run``).
    """
    # The repo id has no placeholders, so a plain string literal is enough.
    return _open_onnx_model(hf_hub_download(
        'deepghs/anime_classification',
        f'{model_name}/model.onnx',
    ))
def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
    """Classify *image* with the named model and return one score per label.

    :param image: Anything ``load_image`` accepts; it is loaded in RGB mode.
    :param model_name: Model variant, presumably one of ``_CLS_MODELS``; the
        model is obtained through the cached ``_open_anime_classify_model``.
    :param size: Passed to ``_img_encode`` as ``size=(size, size)``
        (presumably the square side length the image is resized to; confirm).
    :return: A dict mapping each label in ``_LABELS`` to the model's score as
        a Python float. Whether the scores are probabilities or raw logits
        depends on the exported model and is not visible here.
    """
    image = load_image(image, mode='RGB')
    # Prepend a length-1 leading axis: a batch containing this single image.
    input_ = _img_encode(image, size=(size, size))[None, ...]
    output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})
    # output[0] holds the scores of the only batch item. ``tolist()`` converts
    # them to plain Python floats, equivalent to calling ``.item()`` on each.
    return dict(zip(_LABELS, output[0].tolist()))