# wd-tagger / app.py
from typing import Optional
import io
import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
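# FastAPI app that serves the JSON /tagging endpoint; the Gradio UI is mounted on it at the web root below.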
app = FastAPI()
TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = "Demo for the WaifuDiffusion tagger models"
# Dataset v3 models
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
# Dataset v2 models
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
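# Files fetched from each model repo: the ONNX graph and the tag list with category metadata.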
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
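# Kaomoji tags keep their underscores; all other tags have underscores replaced with spaces.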
kaomojis = ["0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<",
"3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||"]
class Predictor:
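    """Downloads WaifuDiffusion tagger models from the Hugging Face Hub on demand and runs
    ONNX Runtime inference; the most recently loaded model is cached across calls."""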
def __init__(self):
self.model_target_size = None
self.last_loaded_repo = None
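    # Fetch the tag CSV and ONNX model from the Hub (cached locally by huggingface_hub).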
def download_model(self, model_repo):
csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
return csv_path, model_path
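    # Read tag names and per-category indexes from the CSV and open an ONNX Runtime
    # session; does nothing if the requested repo is already loaded.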
def load_model(self, model_repo):
if model_repo == self.last_loaded_repo:
return
csv_path, model_path = self.download_model(model_repo)
tags_df = pd.read_csv(csv_path)
name_series = tags_df["name"]
name_series = name_series.map(lambda x: x.replace("_", " ") if x not in kaomojis else x)
self.tag_names = name_series.tolist()
self.rating_indexes = list(np.where(tags_df["category"] == 9)[0])
self.general_indexes = list(np.where(tags_df["category"] == 0)[0])
self.character_indexes = list(np.where(tags_df["category"] == 4)[0])
self.model = rt.InferenceSession(model_path)
_, height, width, _ = self.model.get_inputs()[0].shape
self.model_target_size = height
self.last_loaded_repo = model_repo
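    # Composite transparency onto a white background, pad to a square with white, resize to
    # the model input size, and return a float32 BGR batch of one (the channel order these taggers expect).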
def prepare_image(self, image):
canvas = Image.new("RGBA", image.size, (255, 255, 255))
canvas.alpha_composite(image)
image = canvas.convert("RGB")
max_dim = max(image.size)
pad_left = (max_dim - image.size[0]) // 2
pad_top = (max_dim - image.size[1]) // 2
padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
padded_image.paste(image, (pad_left, pad_top))
if max_dim != self.model_target_size:
padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)
image_array = np.asarray(padded_image, dtype=np.float32)
image_array = image_array[:, :, ::-1]
return np.expand_dims(image_array, axis=0)
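    # Run inference and return the general tags above `threshold` (sorted by confidence,
    # descending) together with the full list of (tag, score) pairs.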
def predict(self, image, model_repo=SWINV2_MODEL_DSV3_REPO, threshold=0.05):
self.load_model(model_repo)
image = self.prepare_image(image)
input_name = self.model.get_inputs()[0].name
label_name = self.model.get_outputs()[0].name
preds = self.model.run([label_name], {input_name: image})[0]
labels = list(zip(self.tag_names, preds[0].astype(float)))
general_names = [labels[i] for i in self.general_indexes]
general_res = [x for x in general_names if x[1] > threshold]
general_res = dict(general_res)
sorted_general = sorted(general_res.items(), key=lambda x: x[1], reverse=True)
return sorted_general, labels
predictor = Predictor()
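# JSON API: accepts a multipart image upload and returns the general tags above `threshold`.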
@app.post("/tagging")
async def tagging_endpoint(
image: UploadFile = File(...),
threshold: Optional[float] = Form(0.05)
):
image_data = await image.read()
pil_image = Image.open(io.BytesIO(image_data)).convert("RGBA")
sorted_general, _ = predictor.predict(pil_image, threshold=threshold)
return JSONResponse(content={"tags": [x[0] for x in sorted_general]})
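# Example request against a local instance (illustrative; "example.png" is a placeholder):
#   curl -X POST http://localhost:7860/tagging -F "image=@example.png" -F "threshold=0.35"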
def ui_predict(
image,
model_repo,
general_thresh,
general_mcut_enabled,
character_thresh,
character_mcut_enabled,
):
    sorted_general, all_labels = predictor.predict(image, model_repo, general_thresh)
    # General tags: when enabled, replace the fixed slider threshold with an MCut threshold
    # (cut at the largest gap between adjacent probabilities sorted in descending order).
    if general_mcut_enabled:
        general_labels = [all_labels[i] for i in predictor.general_indexes]
        general_probs = np.sort(np.array([x[1] for x in general_labels]))[::-1]
        gap = int(np.argmax(general_probs[:-1] - general_probs[1:]))
        general_thresh = float((general_probs[gap] + general_probs[gap + 1]) / 2)
        sorted_general = sorted(
            (x for x in general_labels if x[1] > general_thresh),
            key=lambda x: x[1],
            reverse=True,
        )
# Ratings
ratings = {all_labels[i][0]: all_labels[i][1] for i in predictor.rating_indexes}
# Characters
character_labels = [all_labels[i] for i in predictor.character_indexes]
    if character_mcut_enabled:
        # MCut threshold over the character probabilities, floored at 0.15 to limit
        # spurious character tags on images without a recognisable character.
        character_probs = np.sort(np.array([x[1] for x in character_labels]))[::-1]
        gap = int(np.argmax(character_probs[:-1] - character_probs[1:]))
        character_thresh = max(0.15, float((character_probs[gap] + character_probs[gap + 1]) / 2))
character_res = {x[0]: x[1] for x in character_labels if x[1] > character_thresh}
# Format output
    sorted_general_strings = ", ".join(x[0] for x in sorted_general).replace("(", "\\(").replace(")", "\\)")
return sorted_general_strings, ratings, character_res, dict(sorted_general)
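# Gradio UI: model picker, general/character thresholds with optional MCut, and string/rating/character/tag outputs.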
def create_demo():
with gr.Blocks(title=TITLE) as demo:
gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(variant="panel"):
image = gr.Image(type="pil", image_mode="RGBA", label="Input")
model_repo = gr.Dropdown(
choices=[
SWINV2_MODEL_DSV3_REPO, CONV_MODEL_DSV3_REPO,
VIT_MODEL_DSV3_REPO, VIT_LARGE_MODEL_DSV3_REPO,
EVA02_LARGE_MODEL_DSV3_REPO, MOAT_MODEL_DSV2_REPO,
SWIN_MODEL_DSV2_REPO, CONV_MODEL_DSV2_REPO,
CONV2_MODEL_DSV2_REPO, VIT_MODEL_DSV2_REPO
],
value=SWINV2_MODEL_DSV3_REPO,
label="Model"
)
with gr.Row():
general_thresh = gr.Slider(0, 1, value=0.35, step=0.05, label="General Tags Threshold")
general_mcut = gr.Checkbox(value=False, label="Use MCut threshold")
with gr.Row():
character_thresh = gr.Slider(0, 1, value=0.85, step=0.05, label="Character Tags Threshold")
character_mcut = gr.Checkbox(value=False, label="Use MCut threshold")
submit = gr.Button(value="Submit", variant="primary")
with gr.Column(variant="panel"):
text_output = gr.Textbox(label="Output (string)")
rating_output = gr.Label(label="Rating")
character_output = gr.Label(label="Characters")
general_output = gr.Label(label="Tags")
submit.click(
ui_predict,
inputs=[image, model_repo, general_thresh, general_mcut,
character_thresh, character_mcut],
outputs=[text_output, rating_output, character_output, general_output]
)
demo.queue(max_size=10)
return demo
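# Mount the Gradio UI at the web root; the /tagging JSON endpoint stays available on the same app.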
app = gr.mount_gradio_app(app, create_demo(), path="/")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)