import io
from typing import Optional

import gradio as gr
import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse

app = FastAPI()

TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = "Demo for the WaifuDiffusion tagger models"

# Dataset v3 series of models
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Dataset v2 series of models
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# Files downloaded from the selected model repository.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tags whose underscores are part of the emoticon itself; these are kept
# verbatim instead of having "_" replaced with a space.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<",
    "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||",
]


class Predictor:
    def __init__(self):
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        # Skip the download if the requested model is already loaded.
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)
        tags_df = pd.read_csv(csv_path)
        name_series = tags_df["name"]
        # Replace underscores with spaces, except for kaomoji tags where the
        # underscore is part of the emoticon.
        name_series = name_series.map(
            lambda x: x.replace("_", " ") if x not in kaomojis else x
        )

        self.tag_names = name_series.tolist()
        # Category codes in selected_tags.csv: 9 = rating, 0 = general, 4 = character.
        self.rating_indexes = list(np.where(tags_df["category"] == 9)[0])
        self.general_indexes = list(np.where(tags_df["category"] == 0)[0])
        self.character_indexes = list(np.where(tags_df["category"] == 4)[0])

        self.model = rt.InferenceSession(model_path)
        # The models take NHWC input with height == width.
        _, height, width, _ = self.model.get_inputs()[0].shape
        self.model_target_size = height
        self.last_loaded_repo = model_repo

    def prepare_image(self, image):
        # Flatten any transparency onto a white background.
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

        # Pad to a centered square so the aspect ratio is preserved.
        max_dim = max(image.size)
        pad_left = (max_dim - image.size[0]) // 2
        pad_top = (max_dim - image.size[1]) // 2

        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))

        if max_dim != self.model_target_size:
            padded_image = padded_image.resize(
                (self.model_target_size, self.model_target_size), Image.BICUBIC
            )

        # Convert RGB -> BGR, the channel order the WD tagger models expect.
        image_array = np.asarray(padded_image, dtype=np.float32)
        image_array = image_array[:, :, ::-1]

        # Add the batch dimension: shape (1, H, W, 3).
        return np.expand_dims(image_array, axis=0)

    def predict(self, image, model_repo=SWINV2_MODEL_DSV3_REPO, threshold=0.05):
        self.load_model(model_repo)

        image = self.prepare_image(image)
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]

        # Pair every tag name with its score, then keep the general tags
        # whose score exceeds the threshold.
        labels = list(zip(self.tag_names, preds[0].astype(float)))
        general_names = [labels[i] for i in self.general_indexes]
        general_res = dict(x for x in general_names if x[1] > threshold)

        sorted_general = sorted(general_res.items(), key=lambda x: x[1], reverse=True)
        return sorted_general, labels


predictor = Predictor()

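# Direct (non-HTTP) usage sketch; "example.png" below is a hypothetical file name:
#
#   img = Image.open("example.png").convert("RGBA")
#   tags, _ = predictor.predict(img, threshold=0.35)
#   print([name for name, score in tags])
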
@app.post("/tagging")
async def tagging_endpoint(
    image: UploadFile = File(...),
    threshold: Optional[float] = Form(0.05),
):
    image_data = await image.read()
    pil_image = Image.open(io.BytesIO(image_data)).convert("RGBA")
    sorted_general, _ = predictor.predict(pil_image, threshold=threshold)
    return JSONResponse(content={"tags": [x[0] for x in sorted_general]})

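# Example request against the /tagging endpoint above (illustrative only;
# assumes the server is running locally on port 7860 and that "example.png"
# is a hypothetical image file):
#
#   import requests
#   with open("example.png", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/tagging",
#           files={"image": f},
#           data={"threshold": "0.35"},
#       )
#   print(resp.json()["tags"])
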
def ui_predict(
    image,
    model_repo,
    general_thresh,
    general_mcut_enabled,
    character_thresh,
    character_mcut_enabled,
):
    sorted_general, all_labels = predictor.predict(image, model_repo, general_thresh)

    # Ratings are reported as-is; no threshold is applied to them.
    ratings = {all_labels[i][0]: all_labels[i][1] for i in predictor.rating_indexes}

    # When MCut is enabled for general tags, raise the threshold to at least the
    # mean general-tag score and re-filter, analogous to the character handling below.
    if general_mcut_enabled:
        general_labels = [all_labels[i] for i in predictor.general_indexes]
        general_probs = np.array([x[1] for x in general_labels])
        general_thresh = max(general_thresh, np.mean(general_probs))
        sorted_general = sorted(
            [x for x in general_labels if x[1] > general_thresh],
            key=lambda x: x[1],
            reverse=True,
        )

    character_labels = [all_labels[i] for i in predictor.character_indexes]
    if character_mcut_enabled:
        character_probs = np.array([x[1] for x in character_labels])
        character_thresh = max(0.15, np.mean(character_probs))
    character_res = {x[0]: x[1] for x in character_labels if x[1] > character_thresh}

    # Escape parentheses so the tag string can be pasted into prompt fields.
    sorted_general_strings = ", ".join(x[0] for x in sorted_general)
    sorted_general_strings = sorted_general_strings.replace("(", r"\(").replace(")", r"\)")
    return sorted_general_strings, ratings, character_res, dict(sorted_general)


def create_demo():
    with gr.Blocks(title=TITLE) as demo:
        gr.Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column(variant="panel"):
                image = gr.Image(type="pil", image_mode="RGBA", label="Input")
                model_repo = gr.Dropdown(
                    choices=[
                        SWINV2_MODEL_DSV3_REPO, CONV_MODEL_DSV3_REPO,
                        VIT_MODEL_DSV3_REPO, VIT_LARGE_MODEL_DSV3_REPO,
                        EVA02_LARGE_MODEL_DSV3_REPO, MOAT_MODEL_DSV2_REPO,
                        SWIN_MODEL_DSV2_REPO, CONV_MODEL_DSV2_REPO,
                        CONV2_MODEL_DSV2_REPO, VIT_MODEL_DSV2_REPO,
                    ],
                    value=SWINV2_MODEL_DSV3_REPO,
                    label="Model",
                )
                with gr.Row():
                    general_thresh = gr.Slider(0, 1, value=0.35, step=0.05, label="General Tags Threshold")
                    general_mcut = gr.Checkbox(value=False, label="Use MCut threshold")
                with gr.Row():
                    character_thresh = gr.Slider(0, 1, value=0.85, step=0.05, label="Character Tags Threshold")
                    character_mcut = gr.Checkbox(value=False, label="Use MCut threshold")
                submit = gr.Button(value="Submit", variant="primary")

            with gr.Column(variant="panel"):
                text_output = gr.Textbox(label="Output (string)")
                rating_output = gr.Label(label="Rating")
                character_output = gr.Label(label="Characters")
                general_output = gr.Label(label="Tags")

        submit.click(
            ui_predict,
            inputs=[image, model_repo, general_thresh, general_mcut,
                    character_thresh, character_mcut],
            outputs=[text_output, rating_output, character_output, general_output],
        )

    demo.queue(max_size=10)
    return demo


# Mount the Gradio UI at "/" on top of the FastAPI app so the JSON endpoint
# and the web demo are served by the same server.
app = gr.mount_gradio_app(app, create_demo(), path="/")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
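
# Alternatively (with <module_name> as a placeholder for this file's module
# name), the app can be served via the uvicorn CLI:
#
#   uvicorn <module_name>:app --host 0.0.0.0 --port 7860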