| from typing import Optional | |
| import gradio as gr | |
| from hfutils.repository import hf_hub_repo_url | |
| from imgutils.generic import MultiLabelTIMMModel | |
# Every tagger repository that gets a tab in the playground, in display order.
KNOWN_MODELS = [
    'animetimm/caformer_b36.dbv4-full',
    'animetimm/caformer_m36.dbv4-full',
    'animetimm/caformer_s18.dbv4-full',
    'animetimm/caformer_s36.dbv4-full',
    'animetimm/convnext_base.dbv4-full',
    'animetimm/eva02_large_patch14_448.dbv4-full',
    'animetimm/mobilenetv3_large_100.dbv4-full',
    'animetimm/mobilenetv3_large_100.dbv4-full.r224',
    'animetimm/mobilenetv3_large_150d.dbv4-full',
    'animetimm/mobilenetv4_conv_aa_large.dbv4-full',
    'animetimm/mobilenetv4_conv_small.dbv4-full',
    'animetimm/mobilenetv4_conv_small_050.dbv4-full',
    'animetimm/mobilevitv2_200.dbv4-full',
    'animetimm/resnet18.dbv4-full',
    'animetimm/resnet34.dbv4-full',
    'animetimm/resnet50.dbv4-full',
    'animetimm/resnet101.dbv4-full',
    'animetimm/resnet152.dbv4-full',
    'animetimm/swinv2_base_window8_256.dbv4-full',
    'animetimm/vit_base_patch16_224.dbv4-full',
]
# Highlighted models, keyed by the short remark appended to their tab label.
# The entry script renders these before the remaining KNOWN_MODELS entries.
SPECIAL_MODELS = {
    'Recommended': 'animetimm/caformer_b36.dbv4-full',
    'Lightweight': 'animetimm/mobilenetv4_conv_aa_large.dbv4-full',
    'Classic EVA02': 'animetimm/eva02_large_patch14_448.dbv4-full',
    'Classic SwinV2': 'animetimm/swinv2_base_window8_256.dbv4-full',
}
def render_model_demo(repo_id, label: Optional[str] = None):
    """Render one tab holding the quick demo for a single tagger model.

    Must be called inside an active Gradio layout context, since it creates
    a ``gr.Tab`` and nested components.

    :param repo_id: Model repository id, e.g. ``'animetimm/resnet50.dbv4-full'``.
    :param label: Title of the tab. When ``None`` or empty, the part of
        ``repo_id`` after the last ``'/'`` is used.
    """
    if not label:
        label = repo_id.rsplit('/', 1)[-1]

    with gr.Tab(label):
        model = MultiLabelTIMMModel(repo_id=repo_id)

        # First row: a short description linking to the model repository.
        with gr.Row(), gr.Column():
            repo_url = hf_hub_repo_url(repo_id=repo_id, repo_type='model')
            intro = f'This is the quick demo for tagger model [{repo_id}]({repo_url}).'
            gr.Markdown(intro)

        # Second row: the UI provided by the model object itself.
        with gr.Row():
            model.make_ui()
if __name__ == '__main__':
    # Build the playground page: a header followed by one tab per model,
    # then start the Gradio server.
    with gr.Blocks() as demo:
        # Header: page title, training dataset link and ranklist link.
        with gr.Row():
            with gr.Column():
                gr.HTML('<h2 style="text-align: center;">Tagger Playground For Dbv4 Full</h2>')
                # The trailing space after the first sentence keeps it from
                # running into "Powered by" once the literals are joined.
                gr.Markdown(
                    'This is the playground for taggers trained on '
                    '[animetimm/danbooru-wdtagger-v4-w640-ws-full]'
                    '(https://huggingface.co/datasets/animetimm/danbooru-wdtagger-v4-w640-ws-full). '
                    "Powered by `dghs-imgutils`'s quick demo module."
                )
                gr.Markdown(
                    'Official ranklist is on [animetimm/dbv4-full-ranklist]'
                    '(https://huggingface.co/spaces/animetimm/dbv4-full-ranklist).'
                )

        with gr.Row():
            with gr.Tabs():
                # Repositories that already have a tab, so none is rendered twice.
                _exist_models = set()

                # Highlighted models are rendered first, with their remark
                # appended to the tab label.
                for title, repo_id in SPECIAL_MODELS.items():
                    render_model_demo(repo_id, f'{repo_id.split("/")[-1]} ({title})')
                    _exist_models.add(repo_id)

                # Then every remaining known model, under its default label.
                for repo_id in KNOWN_MODELS:
                    if repo_id not in _exist_models:
                        render_model_demo(repo_id)
                        _exist_models.add(repo_id)

    demo.launch()