from functools import lru_cache
from typing import Mapping

from huggingface_hub import hf_hub_download
from imgutils.data import ImageTyping, load_image

from onnx_ import _open_onnx_model
from preprocess import _img_encode

# Output labels of the anime classification models.
_LABELS = ['3d', 'bangumi', 'comic', 'illustration']

# Available model variants hosted in the deepghs/anime_classification repository.
_CLS_MODELS = [
    'caformer_s36',
    'caformer_s36_plus',
    'mobilenetv3',
    'mobilenetv3_dist',
    'mobilenetv3_sce',
    'mobilenetv3_sce_dist',
    'mobilevitv2_150',
]
_DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'


@lru_cache()
def _open_anime_classify_model(model_name):
    # Download the requested ONNX model from the Hugging Face Hub and open it
    # (cached across calls via lru_cache).
    return _open_onnx_model(hf_hub_download(
        'deepghs/anime_classification',
        f'{model_name}/model.onnx',
    ))


def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
    # Preprocess the image, run the classifier, and map each label to its score.
    image = load_image(image, mode='RGB')
    input_ = _img_encode(image, size=(size, size))[None, ...]
    output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_})
    values = dict(zip(_LABELS, map(lambda x: x.item(), output[0])))
    return values
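

# Hedged usage sketch (not part of the original module): assuming the local
# `onnx_` and `preprocess` helpers resolve and the `deepghs/anime_classification`
# repository is reachable on the Hugging Face Hub, the classifier can be
# exercised as below. The image path `sample.jpg` is hypothetical.
if __name__ == '__main__':
    scores = _gr_classification('sample.jpg', _DEFAULT_CLS_MODEL)
    for label, score in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
        print(f'{label}: {score:.4f}')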