|  | from __future__ import annotations | 
					
						
						|  |  | 
					
						
						|  | import argparse | 
					
						
						|  | import functools | 
					
						
						|  | import html | 
					
						
						|  | import os | 
					
						
						|  |  | 
					
						
						|  | import gradio as gr | 
					
						
						|  | import huggingface_hub | 
					
						
						|  | import numpy as np | 
					
						
						|  | import onnxruntime as rt | 
					
						
						|  | import pandas as pd | 
					
						
						|  | import piexif | 
					
						
						|  | import piexif.helper | 
					
						
						|  | import PIL.Image | 
					
						
						|  |  | 
					
						
						|  | from Utils import dbimutils | 
					
						
						|  |  | 
					
						
						|  | TITLE = "WaifuDiffusion v1.4 Tags" | 
					
						
						|  | DESCRIPTION = """ | 
					
						
						|  | Demo for: | 
					
						
						|  | - [SmilingWolf/wd-v1-4-moat-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-moat-tagger-v2) | 
					
						
						|  | - [SmilingWolf/wd-v1-4-swinv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2) | 
					
						
						|  | - [SmilingWolf/wd-v1-4-convnext-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2) | 
					
						
						|  | - [SmilingWolf/wd-v1-4-convnextv2-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-convnextv2-tagger-v2) | 
					
						
						|  | - [SmilingWolf/wd-v1-4-vit-tagger-v2](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2) | 
					
						
						|  |  | 
					
						
						|  | Includes "ready to copy" prompt and a prompt analyzer. | 
					
						
						|  |  | 
					
						
						|  | Modified from [NoCrypt/DeepDanbooru_string](https://huggingface.co/spaces/NoCrypt/DeepDanbooru_string) | 
					
						
						|  | Modified from [hysts/DeepDanbooru](https://huggingface.co/spaces/hysts/DeepDanbooru) | 
					
						
						|  |  | 
					
						
						|  | PNG Info code forked from [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) | 
					
						
						|  |  | 
					
						
						|  | Example image by [γ»γβββ](https://www.pixiv.net/en/users/43565085) | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | HF_TOKEN = os.environ["HF_TOKEN"] | 
					
						
						|  | MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2" | 
					
						
						|  | SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" | 
					
						
						|  | CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" | 
					
						
						|  | CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" | 
					
						
						|  | VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2" | 
					
						
						|  | MODEL_FILENAME = "model.onnx" | 
					
						
						|  | LABEL_FILENAME = "selected_tags.csv" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def parse_args() -> argparse.Namespace: | 
					
						
						|  | parser = argparse.ArgumentParser() | 
					
						
						|  | parser.add_argument("--score-slider-step", type=float, default=0.05) | 
					
						
						|  | parser.add_argument("--score-general-threshold", type=float, default=0.35) | 
					
						
						|  | parser.add_argument("--score-character-threshold", type=float, default=0.85) | 
					
						
						|  | parser.add_argument("--share", action="store_true") | 
					
						
						|  | return parser.parse_args() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession: | 
					
						
						|  | path = huggingface_hub.hf_hub_download( | 
					
						
						|  | model_repo, model_filename, use_auth_token=HF_TOKEN | 
					
						
						|  | ) | 
					
						
						|  | model = rt.InferenceSession(path) | 
					
						
						|  | return model | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def change_model(model_name): | 
					
						
						|  | global loaded_models | 
					
						
						|  |  | 
					
						
						|  | if model_name == "MOAT": | 
					
						
						|  | model = load_model(MOAT_MODEL_REPO, MODEL_FILENAME) | 
					
						
						|  | elif model_name == "SwinV2": | 
					
						
						|  | model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME) | 
					
						
						|  | elif model_name == "ConvNext": | 
					
						
						|  | model = load_model(CONV_MODEL_REPO, MODEL_FILENAME) | 
					
						
						|  | elif model_name == "ConvNextV2": | 
					
						
						|  | model = load_model(CONV2_MODEL_REPO, MODEL_FILENAME) | 
					
						
						|  | elif model_name == "ViT": | 
					
						
						|  | model = load_model(VIT_MODEL_REPO, MODEL_FILENAME) | 
					
						
						|  |  | 
					
						
						|  | loaded_models[model_name] = model | 
					
						
						|  | return loaded_models[model_name] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def load_labels() -> list[str]: | 
					
						
						|  | path = huggingface_hub.hf_hub_download( | 
					
						
						|  | MOAT_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN | 
					
						
						|  | ) | 
					
						
						|  | df = pd.read_csv(path) | 
					
						
						|  |  | 
					
						
						|  | tag_names = df["name"].tolist() | 
					
						
						|  | rating_indexes = list(np.where(df["category"] == 9)[0]) | 
					
						
						|  | general_indexes = list(np.where(df["category"] == 0)[0]) | 
					
						
						|  | character_indexes = list(np.where(df["category"] == 4)[0]) | 
					
						
						|  | return tag_names, rating_indexes, general_indexes, character_indexes | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def plaintext_to_html(text): | 
					
						
						|  | text = ( | 
					
						
						|  | "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split("\n")]) + "</p>" | 
					
						
						|  | ) | 
					
						
						|  | return text | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def predict( | 
					
						
						|  | image: PIL.Image.Image, | 
					
						
						|  | model_name: str, | 
					
						
						|  | general_threshold: float, | 
					
						
						|  | character_threshold: float, | 
					
						
						|  | tag_names: list[str], | 
					
						
						|  | rating_indexes: list[np.int64], | 
					
						
						|  | general_indexes: list[np.int64], | 
					
						
						|  | character_indexes: list[np.int64], | 
					
						
						|  | ): | 
					
						
						|  | global loaded_models | 
					
						
						|  |  | 
					
						
						|  | rawimage = image | 
					
						
						|  |  | 
					
						
						|  | model = loaded_models[model_name] | 
					
						
						|  | if model is None: | 
					
						
						|  | model = change_model(model_name) | 
					
						
						|  |  | 
					
						
						|  | _, height, width, _ = model.get_inputs()[0].shape | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | image = image.convert("RGBA") | 
					
						
						|  | new_image = PIL.Image.new("RGBA", image.size, "WHITE") | 
					
						
						|  | new_image.paste(image, mask=image) | 
					
						
						|  | image = new_image.convert("RGB") | 
					
						
						|  | image = np.asarray(image) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | image = image[:, :, ::-1] | 
					
						
						|  |  | 
					
						
						|  | image = dbimutils.make_square(image, height) | 
					
						
						|  | image = dbimutils.smart_resize(image, height) | 
					
						
						|  | image = image.astype(np.float32) | 
					
						
						|  | image = np.expand_dims(image, 0) | 
					
						
						|  |  | 
					
						
						|  | input_name = model.get_inputs()[0].name | 
					
						
						|  | label_name = model.get_outputs()[0].name | 
					
						
						|  | probs = model.run([label_name], {input_name: image})[0] | 
					
						
						|  |  | 
					
						
						|  | labels = list(zip(tag_names, probs[0].astype(float))) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ratings_names = [labels[i] for i in rating_indexes] | 
					
						
						|  | rating = dict(ratings_names) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | general_names = [labels[i] for i in general_indexes] | 
					
						
						|  | general_res = [x for x in general_names if x[1] > general_threshold] | 
					
						
						|  | general_res = dict(general_res) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | character_names = [labels[i] for i in character_indexes] | 
					
						
						|  | character_res = [x for x in character_names if x[1] > character_threshold] | 
					
						
						|  | character_res = dict(character_res) | 
					
						
						|  |  | 
					
						
						|  | b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True)) | 
					
						
						|  | a = ( | 
					
						
						|  | ", ".join(list(b.keys())) | 
					
						
						|  | .replace("_", " ") | 
					
						
						|  | .replace("(", "\(") | 
					
						
						|  | .replace(")", "\)") | 
					
						
						|  | ) | 
					
						
						|  | c = ", ".join(list(b.keys())) | 
					
						
						|  |  | 
					
						
						|  | items = rawimage.info | 
					
						
						|  | geninfo = "" | 
					
						
						|  |  | 
					
						
						|  | if "exif" in rawimage.info: | 
					
						
						|  | exif = piexif.load(rawimage.info["exif"]) | 
					
						
						|  | exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"") | 
					
						
						|  | try: | 
					
						
						|  | exif_comment = piexif.helper.UserComment.load(exif_comment) | 
					
						
						|  | except ValueError: | 
					
						
						|  | exif_comment = exif_comment.decode("utf8", errors="ignore") | 
					
						
						|  |  | 
					
						
						|  | items["exif comment"] = exif_comment | 
					
						
						|  | geninfo = exif_comment | 
					
						
						|  |  | 
					
						
						|  | for field in [ | 
					
						
						|  | "jfif", | 
					
						
						|  | "jfif_version", | 
					
						
						|  | "jfif_unit", | 
					
						
						|  | "jfif_density", | 
					
						
						|  | "dpi", | 
					
						
						|  | "exif", | 
					
						
						|  | "loop", | 
					
						
						|  | "background", | 
					
						
						|  | "timestamp", | 
					
						
						|  | "duration", | 
					
						
						|  | ]: | 
					
						
						|  | items.pop(field, None) | 
					
						
						|  |  | 
					
						
						|  | geninfo = items.get("parameters", geninfo) | 
					
						
						|  |  | 
					
						
						|  | info = f""" | 
					
						
						|  | <p><h4>PNG Info</h4></p> | 
					
						
						|  | """ | 
					
						
						|  | for key, text in items.items(): | 
					
						
						|  | info += ( | 
					
						
						|  | f""" | 
					
						
						|  | <div> | 
					
						
						|  | <p><b>{plaintext_to_html(str(key))}</b></p> | 
					
						
						|  | <p>{plaintext_to_html(str(text))}</p> | 
					
						
						|  | </div> | 
					
						
						|  | """.strip() | 
					
						
						|  | + "\n" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if len(info) == 0: | 
					
						
						|  | message = "Nothing found in the image." | 
					
						
						|  | info = f"<div><p>{message}<p></div>" | 
					
						
						|  |  | 
					
						
						|  | return (a, c, rating, character_res, general_res, info) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def main(): | 
					
						
						|  | global loaded_models | 
					
						
						|  | loaded_models = { | 
					
						
						|  | "MOAT": None, | 
					
						
						|  | "SwinV2": None, | 
					
						
						|  | "ConvNext": None, | 
					
						
						|  | "ConvNextV2": None, | 
					
						
						|  | "ViT": None, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | args = parse_args() | 
					
						
						|  |  | 
					
						
						|  | change_model("MOAT") | 
					
						
						|  |  | 
					
						
						|  | tag_names, rating_indexes, general_indexes, character_indexes = load_labels() | 
					
						
						|  |  | 
					
						
						|  | func = functools.partial( | 
					
						
						|  | predict, | 
					
						
						|  | tag_names=tag_names, | 
					
						
						|  | rating_indexes=rating_indexes, | 
					
						
						|  | general_indexes=general_indexes, | 
					
						
						|  | character_indexes=character_indexes, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | gr.Interface( | 
					
						
						|  | fn=func, | 
					
						
						|  | inputs=[ | 
					
						
						|  | gr.Image(type="pil", label="Input"), | 
					
						
						|  | gr.Radio( | 
					
						
						|  | ["MOAT", "SwinV2", "ConvNext", "ConvNextV2", "ViT"], | 
					
						
						|  | value="MOAT", | 
					
						
						|  | label="Model", | 
					
						
						|  | ), | 
					
						
						|  | gr.Slider( | 
					
						
						|  | 0, | 
					
						
						|  | 1, | 
					
						
						|  | step=args.score_slider_step, | 
					
						
						|  | value=args.score_general_threshold, | 
					
						
						|  | label="General Tags Threshold", | 
					
						
						|  | ), | 
					
						
						|  | gr.Slider( | 
					
						
						|  | 0, | 
					
						
						|  | 1, | 
					
						
						|  | step=args.score_slider_step, | 
					
						
						|  | value=args.score_character_threshold, | 
					
						
						|  | label="Character Tags Threshold", | 
					
						
						|  | ), | 
					
						
						|  | ], | 
					
						
						|  | outputs=[ | 
					
						
						|  | gr.Textbox(label="Output (string)"), | 
					
						
						|  | gr.Textbox(label="Output (raw string)"), | 
					
						
						|  | gr.Label(label="Rating"), | 
					
						
						|  | gr.Label(label="Output (characters)"), | 
					
						
						|  | gr.Label(label="Output (tags)"), | 
					
						
						|  | gr.HTML(), | 
					
						
						|  | ], | 
					
						
						|  | examples=[["power.jpg", "MOAT", 0.35, 0.85]], | 
					
						
						|  | title=TITLE, | 
					
						
						|  | description=DESCRIPTION, | 
					
						
						|  | allow_flagging="never", | 
					
						
						|  | ).launch( | 
					
						
						|  | enable_queue=True, | 
					
						
						|  | share=args.share, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | main() | 
					
						
						|  |  |