Spaces:
Running
Running
| from os import getenv | |
| from pathlib import Path | |
| import gradio as gr | |
| from PIL import Image | |
| from rich.traceback import install as traceback_install | |
| from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image | |
| from tagger.model import load_model_and_transform, process_heatmap | |
# --- UI text -----------------------------------------------------------------
TITLE = "WD Tagger Heatmap"
DESCRIPTION = """WD Tagger v3 Heatmap Generator."""

# --- Hugging Face configuration ------------------------------------------------
# Access token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = getenv("HF_TOKEN", None)
# Hub repository the model and its labels are loaded from (see predict()).
MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"

# --- Paths ----------------------------------------------------------------------
# Directory containing this script. When __file__ is not defined (e.g. the code
# is executed interactively, such as in IPython) fall back to the current
# working directory.
if "__file__" in globals():
    WORK_DIR = Path(__file__).parent.resolve()
else:
    WORK_DIR = Path().resolve()

# Lower-case file suffixes accepted as example images.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# Install rich's traceback handler with local variables shown in tracebacks.
_ = traceback_install(show_locals=True, locals_max_length=0)
# Example images for the gr.Examples row: files directly inside
# WORK_DIR/examples whose suffix (case-insensitive) is in IMAGE_EXTENSIONS,
# as paths relative to WORK_DIR, sorted lexicographically.
# A missing (or non-directory) `examples` path yields an empty list instead of
# crashing the whole app at import time with FileNotFoundError.
_examples_dir = WORK_DIR.joinpath("examples")
example_images = sorted(
    str(x.relative_to(WORK_DIR))
    for x in (_examples_dir.iterdir() if _examples_dir.is_dir() else ())
    if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
)
def predict(
    image: Image.Image,
    threshold: float = 0.5,
):
    """Tag *image* with the WD tagger and build per-tag heatmaps.

    Args:
        image: Input image from the ``gr.Image`` component (``type="pil"``).
            Gradio passes ``None`` when the input is empty.
        threshold: Tag score threshold forwarded to ``process_heatmap``.

    Returns:
        A 7-tuple in the order of the Submit handler's ``outputs``:
        a list of ``(heatmap.image, heatmap.label)`` pairs for the gallery,
        the heatmap grid, then ``image_labels.caption``, ``.booru``,
        ``.rating``, ``.character`` and ``.general``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Fail with a readable message in the UI instead of an opaque error
        # from inside preprocess_image.
        raise gr.Error("Please provide an input image.")

    # load the model with its matching input transform, and the label data
    model, transform = load_model_and_transform(MODEL_REPO)
    labels: LabelData = load_labels_hf(MODEL_REPO)

    # preprocess to the model's 448x448 input, apply the transform, then add
    # a leading (batch) dimension via unsqueeze(0)
    prepared = preprocess_image(image, (448, 448))
    inputs = transform(prepared).unsqueeze(0)

    # run the model and derive heatmaps, the heatmap grid, and tag labels
    heatmaps: list[Heatmap]
    image_labels: ImageLabels
    heatmaps, heatmap_grid, image_labels = process_heatmap(model, inputs, labels, threshold)

    # gr.Gallery accepts (image, caption) pairs
    heatmap_images = [(x.image, x.label) for x in heatmaps]
    return (
        heatmap_images,
        heatmap_grid,
        image_labels.caption,
        image_labels.booru,
        image_labels.rating,
        image_labels.character,
        image_labels.general,
    )
# Extra stylesheet passed to gr.Blocks(css=...).
# NOTE(review): no component in this file sets elem_id "use_mcut" or
# "char_mcut", and nothing visible adds a "dimmed" class to #threshold, so
# these rules appear to be unused leftovers — confirm before relying on them.
css = """
#use_mcut, #char_mcut {
    padding-top: var(--scale-3);
}
#threshold.dimmed {
    filter: brightness(75%);
}
"""
# Default tag threshold, shared by the slider and the example rows so the two
# values cannot drift apart.
DEFAULT_THRESHOLD = 0.35

# Page layout. Components render in the order they are created, so statement
# order inside each context manager is significant.
with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # Left column: input image, threshold slider and action buttons.
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
            with gr.Group():
                with gr.Row():
                    threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=DEFAULT_THRESHOLD,
                        step=0.01,
                        label="Tag Threshold",
                        scale=5,
                        elem_id="threshold",
                    )
            with gr.Row():
                # Created with no targets: the output components do not exist
                # yet, so they are registered via clear.add() further down.
                clear = gr.ClearButton(
                    components=[],
                    variant="secondary",
                    size="lg",
                )
                submit = gr.Button(value="Submit", variant="primary", size="lg")
        # Right column: outputs, split across tabs.
        with gr.Column(min_width=720):
            with gr.Tab(label="Heatmaps"):
                heatmap_gallery = gr.Gallery(columns=3, show_label=False)
            with gr.Tab(label="Grid"):
                heatmap_grid = gr.Image(show_label=False)
            with gr.Tab(label="Tags"):
                with gr.Group():
                    caption = gr.Textbox(label="Caption", show_copy_button=True)
                    tags = gr.Textbox(label="Tags", show_copy_button=True)
                with gr.Group():
                    rating = gr.Label(label="Rating")
                with gr.Group():
                    character = gr.Label(label="Character")
                with gr.Group():
                    general = gr.Label(label="General")
    with gr.Row():
        # one [image path, threshold] row per example, matching `inputs`
        example_rows = [[imgpath, DEFAULT_THRESHOLD] for imgpath in example_images]
        examples = gr.Examples(
            examples=example_rows,
            inputs=[img_input, threshold],
        )
    # tell clear button which components to clear
    clear.add([img_input, heatmap_gallery, heatmap_grid, caption, tags, rating, character, general])
    # outputs order must match the tuple returned by predict()
    submit.click(
        predict,
        inputs=[img_input, threshold],
        outputs=[heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
        api_name="predict",
    )
if __name__ == "__main__":
    # Allow at most 10 pending requests in the queue.
    demo.queue(max_size=10)
    # On a Hugging Face Space (SPACE_ID is set) keep the platform's launch
    # defaults; otherwise bind all interfaces on port 7871 with debug=True.
    if getenv("SPACE_ID", None) is None:
        launch_options = {
            "server_name": "0.0.0.0",
            "server_port": 7871,
            "debug": True,
        }
    else:
        launch_options = {}
    demo.launch(**launch_options)