from __future__ import annotations import functools import io import urllib from typing import Tuple, List, Any import huggingface_hub import onnxruntime as rt import pandas as pd import numpy as np import PIL.Image import requests import dbimutils import piexif import piexif.helper from urllib.request import urlopen import model HF_TOKEN = "" SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2" CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2" CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2" VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2" MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" def change_model(model_name): global loaded_models if model_name == "SwinV2": model = load_model(SWIN_MODEL_REPO, MODEL_FILENAME) elif model_name == "ConvNext": model = load_model(CONV_MODEL_REPO, MODEL_FILENAME) elif model_name == "ConvNextV2": model = load_model(CONV2_MODEL_REPO, MODEL_FILENAME) elif model_name == "ViT": model = load_model(VIT_MODEL_REPO, MODEL_FILENAME) loaded_models[model_name] = model return loaded_models[model_name] def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession: path = huggingface_hub.hf_hub_download( model_repo, model_filename, use_auth_token=HF_TOKEN ) model = rt.InferenceSession(path) return model def load_labels() -> tuple[list[Any], list[Any], list[Any], list[Any]]: path = huggingface_hub.hf_hub_download( CONV2_MODEL_REPO, LABEL_FILENAME, use_auth_token=HF_TOKEN ) df = pd.read_csv(path) tag_names = df["name"].tolist() rating_indexes = list(np.where(df["category"] == 9)[0]) general_indexes = list(np.where(df["category"] == 0)[0]) character_indexes = list(np.where(df["category"] == 4)[0]) return tag_names, rating_indexes, general_indexes, character_indexes def predict( image: PIL.Image.Image, model_name: str, general_threshold: float, character_threshold: float, tag_names: list[str], rating_indexes: list[np.int64], general_indexes: list[np.int64], character_indexes: list[np.int64], ): global loaded_models if isinstance(image, str): rawimage = dbimutils.read_img_from_url(image) elif isinstance(image, PIL.Image.Image): rawimage = image else: raise Exception("Invalid image type") image = rawimage model = loaded_models[model_name] if model is None: model = change_model(model_name) _, height, width, _ = model.get_inputs()[0].shape # Alpha to white image = image.convert("RGBA") new_image = PIL.Image.new("RGBA", image.size, "WHITE") new_image.paste(image, mask=image) image = new_image.convert("RGB") image = np.asarray(image) # PIL RGB to OpenCV BGR image = image[:, :, ::-1] image = dbimutils.make_square(image, height) image = dbimutils.smart_resize(image, height) image = image.astype(np.float32) image = np.expand_dims(image, 0) input_name = model.get_inputs()[0].name label_name = model.get_outputs()[0].name probs = model.run([label_name], {input_name: image})[0] labels = list(zip(tag_names, probs[0].astype(float))) # First 4 labels are actually ratings: pick one with argmax ratings_names = [labels[i] for i in rating_indexes] rating = dict(ratings_names) # Then we have general tags: pick any where prediction confidence > threshold general_names = [labels[i] for i in general_indexes] general_res = [x for x in general_names if x[1] > general_threshold] general_res = dict(general_res) # Everything else is characters: pick any where prediction confidence > threshold character_names = [labels[i] for i in character_indexes] character_res = [x for x in character_names if x[1] > character_threshold] character_res = dict(character_res) b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True)) a = ( ", ".join(list(b.keys())) .replace("_", " ") .replace("(", "\(") .replace(")", "\)") ) c = ", ".join(list(b.keys())) items = rawimage.info geninfo = "" if "exif" in rawimage.info: exif = piexif.load(rawimage.info["exif"]) exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"") try: exif_comment = piexif.helper.UserComment.load(exif_comment) except ValueError: exif_comment = exif_comment.decode("utf8", errors="ignore") items["exif comment"] = exif_comment geninfo = exif_comment for field in [ "jfif", "jfif_version", "jfif_unit", "jfif_density", "dpi", "exif", "loop", "background", "timestamp", "duration", ]: items.pop(field, None) geninfo = items.get("parameters", geninfo) for key, text in items.items(): print(key) print(text) print("geninfo", geninfo) print("a", a) print("c", c) print("rating", rating) print("character_res", character_res) print("general_res", general_res) character_res = list(filter(lambda x: x['confidence'] > 0.4, [{'tag': tag, 'confidence': score} for tag, score in character_res.items()])) general_res = list(filter(lambda x: x['confidence'] > 0.4, [{'tag': tag, 'confidence': score} for tag, score in general_res.items()])) return {'a': a, 'c': c, 'rating': rating, 'character_res': character_res, 'general_res': general_res} def label_img( image: PIL.Image.Image | str, model: str, # model: (["SwinV2", "ConvNext", "ConvNextV2", "ViT"], value="ConvNextV2", label="Model"), l_score_general_threshold: float, l_score_character_threshold: float, ): if isinstance(image, str) and image.startswith("http"): image = dbimutils.read_img_from_url(image) global loaded_models loaded_models = {"SwinV2": None, "ConvNext": None, "ConvNextV2": None, "ViT": None} change_model("ConvNextV2") tag_names, rating_indexes, general_indexes, character_indexes = load_labels() func = functools.partial( predict, tag_names=tag_names, rating_indexes=rating_indexes, general_indexes=general_indexes, character_indexes=character_indexes, ) return func( image=image, model_name=model, general_threshold=l_score_general_threshold, character_threshold=l_score_character_threshold, ) def write_image_tag(img_id: int, is_valid: bool, tags: List[model.ImageTag], callback_url: str): model.ImageScanCallbackRequest(img_id=img_id, is_valid=is_valid, tags=tags) if __name__ == "__main__": score_slider_step = 0.05 score_general_threshold = 0.35 score_character_threshold = 0.85 ret = label_img( image='https://pub-9747017e9ec54620bfbe2385f14fe4d7.r2.dev/cnGirlYcy_v10_people_network_nannansleep/cnGirlYcy_v10_people_network_nannansleep_r_1679670778_0.png', model="SwinV2", l_score_general_threshold=score_general_threshold, l_score_character_threshold=score_character_threshold, ) print(ret)