from typing import Optional import gradio as gr from hfutils.repository import hf_hub_repo_url from imgutils.generic import MultiLabelTIMMModel KNOWN_MODELS = ['animetimm/caformer_b36.dbv4-full', 'animetimm/caformer_m36.dbv4-full', 'animetimm/caformer_s18.dbv4-full', 'animetimm/caformer_s36.dbv4-full', 'animetimm/convnext_base.dbv4-full', 'animetimm/eva02_large_patch14_448.dbv4-full', 'animetimm/mobilenetv3_large_100.dbv4-full', 'animetimm/mobilenetv3_large_100.dbv4-full.r224', 'animetimm/mobilenetv3_large_150d.dbv4-full', 'animetimm/mobilenetv4_conv_aa_large.dbv4-full', 'animetimm/mobilenetv4_conv_small.dbv4-full', 'animetimm/mobilenetv4_conv_small_050.dbv4-full', 'animetimm/mobilevitv2_200.dbv4-full', 'animetimm/resnet18.dbv4-full', 'animetimm/resnet34.dbv4-full', 'animetimm/resnet50.dbv4-full', 'animetimm/resnet101.dbv4-full', 'animetimm/resnet152.dbv4-full', 'animetimm/swinv2_base_window8_256.dbv4-full', 'animetimm/vit_base_patch16_224.dbv4-full'] SPECIAL_MODELS = {'Recommended': 'animetimm/caformer_b36.dbv4-full', 'Lightweight': 'animetimm/mobilenetv4_conv_aa_large.dbv4-full', 'Classic EVA02': 'animetimm/eva02_large_patch14_448.dbv4-full', 'Classic SwinV2': 'animetimm/swinv2_base_window8_256.dbv4-full'} def render_model_demo(repo_id, label: Optional[str] = None): label = label or repo_id.split('/')[-1] with gr.Tab(label): model = MultiLabelTIMMModel(repo_id=repo_id) with gr.Row(): with gr.Column(): repo_url = hf_hub_repo_url(repo_id=repo_id, repo_type='model') gr.Markdown(f'This is the quick demo for tagger model [{repo_id}]({repo_url}).') with gr.Row(): model.make_ui() if __name__ == '__main__': with gr.Blocks() as demo: with gr.Row(): with gr.Column(): gr.HTML(f'