neggles's picture
move caption into tags tab
802ae2a
raw
history blame
4.63 kB
from os import getenv
from pathlib import Path
import gradio as gr
from PIL import Image
from rich.traceback import install as traceback_install
from tagger.common import Heatmap, ImageLabels, LabelData, load_labels_hf, preprocess_image
from tagger.model import load_model_and_transform, process_heatmap
TITLE = "WD Tagger Heatmap"
DESCRIPTION = """WD Tagger v3 Heatmap Generator."""

# Hugging Face access token read from the environment; None when the variable is unset.
HF_TOKEN = getenv("HF_TOKEN")

# Model repository; passed to both the model/transform loader and the label loader.
MODEL_REPO = "SmilingWolf/wd-vit-tagger-v3"

# Directory holding this script. When `__file__` is not defined (e.g. the code is run
# interactively, such as in IPython), fall back to the current working directory.
if "__file__" in globals():
    WORK_DIR = Path(__file__).parent.resolve()
else:
    WORK_DIR = Path().resolve()

# Lower-case file suffixes (leading dot included) that count as example images.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# Install rich's traceback handler with local-variable display switched on.
_ = traceback_install(show_locals=True, locals_max_length=0)
# Collect the bundled example images as path strings relative to WORK_DIR, sorted
# lexicographically. Only regular files whose lower-cased suffix is listed in
# IMAGE_EXTENSIONS are kept. A missing "examples" folder yields an empty list instead
# of raising FileNotFoundError at import time and taking the whole app down.
_examples_dir = WORK_DIR / "examples"
example_images = sorted(
    str(path.relative_to(WORK_DIR))
    for path in (_examples_dir.iterdir() if _examples_dir.is_dir() else ())
    if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
)
def predict(
    image: Image.Image,
    threshold: float = 0.5,
):
    """Tag one image and build its per-tag heatmaps.

    Args:
        image: The PIL image from the "Input" ``gr.Image`` component (``type="pil"``),
            or ``None`` when the user pressed Submit without providing an image.
        threshold: Tag score threshold forwarded to ``process_heatmap``. The UI always
            passes the slider value; this default only applies to direct Python calls.

    Returns:
        A 7-tuple in the order of the Submit button's ``outputs``:
        gallery items as ``(image, label)`` pairs, the heatmap grid, the caption,
        the booru-style tag text, and the rating / character / general values shown
        in the ``gr.Label`` components.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of a traceback from preprocessing.
    """
    # Fail fast with a user-facing message before doing any model work.
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Load the model and its input transform for MODEL_REPO
    # (NOTE(review): presumably cached inside tagger.model — confirm).
    model, transform = load_model_and_transform(MODEL_REPO)
    # Load the label set that goes with the same repository.
    labels: LabelData = load_labels_hf(MODEL_REPO)

    # Preprocess to 448x448, apply the model transform, and add a leading batch dimension.
    image = preprocess_image(image, (448, 448))
    image = transform(image).unsqueeze(0)

    # Run the model and build the heatmaps, the combined grid, and the tag labels.
    heatmaps: list[Heatmap]
    image_labels: ImageLabels
    heatmaps, heatmap_grid, image_labels = process_heatmap(model, image, labels, threshold)

    # gr.Gallery takes (image, caption) pairs.
    heatmap_images = [(x.image, x.label) for x in heatmaps]
    return (
        heatmap_images,
        heatmap_grid,
        image_labels.caption,
        image_labels.booru,
        image_labels.rating,
        image_labels.character,
        image_labels.general,
    )
# Custom CSS for the Blocks app (passed as `css=css` to gr.Blocks).
# NOTE(review): no component in this file sets elem_id "use_mcut" or "char_mcut", and
# nothing here adds a "dimmed" class to the #threshold slider. These rules may be
# leftovers from another app — confirm before removing.
css = """
#use_mcut, #char_mcut {
padding-top: var(--scale-3);
}
#threshold.dimmed {
filter: brightness(75%);
}
"""
# Default tag threshold, shared by the slider and every example row so they cannot drift apart.
DEFAULT_THRESHOLD = 0.35

with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # Left column: image input, threshold control and action buttons.
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
            with gr.Group():
                with gr.Row():
                    threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=DEFAULT_THRESHOLD,
                        step=0.01,
                        label="Tag Threshold",
                        scale=5,
                        elem_id="threshold",
                    )
            with gr.Row():
                # Created empty; its targets are registered below, once the output components exist.
                clear = gr.ClearButton(
                    components=[],
                    variant="secondary",
                    size="lg",
                )
                submit = gr.Button(value="Submit", variant="primary", size="lg")

        # Right column: results, split across tabs.
        with gr.Column(min_width=720):
            with gr.Tab(label="Heatmaps"):
                heatmap_gallery = gr.Gallery(columns=3, show_label=False)
            with gr.Tab(label="Grid"):
                heatmap_grid = gr.Image(show_label=False)
            with gr.Tab(label="Tags"):
                with gr.Group():
                    caption = gr.Textbox(label="Caption", show_copy_button=True)
                    tags = gr.Textbox(label="Tags", show_copy_button=True)
                with gr.Group():
                    rating = gr.Label(label="Rating")
                with gr.Group():
                    character = gr.Label(label="Character")
                with gr.Group():
                    general = gr.Label(label="General")

    with gr.Row():
        # Each example row fills [input image, threshold].
        examples = gr.Examples(
            examples=[[imgpath, DEFAULT_THRESHOLD] for imgpath in example_images],
            inputs=[img_input, threshold],
        )

    # tell clear button which components to clear
    clear.add([img_input, heatmap_gallery, heatmap_grid, caption, tags, rating, character, general])

    # The outputs order must match the tuple returned by predict().
    submit.click(
        predict,
        inputs=[img_input, threshold],
        outputs=[heatmap_gallery, heatmap_grid, caption, tags, rating, character, general],
        api_name="predict",
    )
if __name__ == "__main__":
    # Bound the queue to at most 10 pending events before launching.
    demo.queue(max_size=10)

    # Only the presence of SPACE_ID matters (presumably set when running on Hugging Face
    # Spaces). When it is set, launch with Gradio's defaults. Otherwise listen on all
    # interfaces on a fixed local port with debug mode enabled.
    launch_kwargs = (
        {}
        if getenv("SPACE_ID") is not None
        else {"server_name": "0.0.0.0", "server_port": 7871, "debug": True}
    )
    demo.launch(**launch_kwargs)