#!/usr/bin/env python from __future__ import annotations import gradio as gr import huggingface_hub import numpy as np import onnxruntime as rt import pandas as pd from PIL import Image EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3" MODEL_FILENAME = "model.onnx" LABEL_FILENAME = "selected_tags.csv" def load_labels(dataframe) -> list[str]: name_series = dataframe["name"] tag_names = name_series.tolist() rating_indexes = list(np.where(dataframe["category"] == 9)[0]) general_indexes = list(np.where(dataframe["category"] == 0)[0]) character_indexes = list(np.where(dataframe["category"] == 4)[0]) return tag_names, rating_indexes, general_indexes, character_indexes class Predictor: def __init__(self): self.model_target_size = None self.load_model(EVA02_LARGE_MODEL_DSV3_REPO) def download_model(self, model_repo): csv_path = huggingface_hub.hf_hub_download( model_repo, LABEL_FILENAME, ) model_path = huggingface_hub.hf_hub_download( model_repo, MODEL_FILENAME, ) return csv_path, model_path def load_model(self, model_repo): csv_path, model_path = self.download_model(model_repo) tags_df = pd.read_csv(csv_path) sep_tags = load_labels(tags_df) self.tag_names = sep_tags[0] self.rating_indexes = sep_tags[1] self.general_indexes = sep_tags[2] self.character_indexes = sep_tags[3] model = rt.InferenceSession(model_path) _, height, width, _ = model.get_inputs()[0].shape self.model_target_size = height self.model = model def prepare_image(self, image): target_size = self.model_target_size canvas = Image.new("RGBA", image.size, (255, 255, 255)) canvas.alpha_composite(image) image = canvas.convert("RGB") # Pad image to square image_shape = image.size max_dim = max(image_shape) pad_left = (max_dim - image_shape[0]) // 2 pad_top = (max_dim - image_shape[1]) // 2 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) padded_image.paste(image, (pad_left, pad_top)) # Resize if max_dim != target_size: padded_image = padded_image.resize( (target_size, target_size), Image.BICUBIC, ) # Convert to numpy array image_array = np.asarray(padded_image, dtype=np.float32) # Convert PIL-native RGB to BGR image_array = image_array[:, :, ::-1] return np.expand_dims(image_array, axis=0) def predict(self, image, general_thresh): image = self.prepare_image(image) input_name = self.model.get_inputs()[0].name label_name = self.model.get_outputs()[0].name preds = self.model.run([label_name], {input_name: image})[0] labels = list(zip(self.tag_names, preds[0].astype(float))) # First 4 labels are actually ratings: pick one with argmax ratings_names = [labels[i] for i in self.rating_indexes] ratings_names = dict(ratings_names) ratings_names = sorted( ratings_names.items(), key=lambda x: x[1], reverse=True, ) # Then we have general tags: pick any where prediction confidence > threshold general_names = [labels[i] for i in self.general_indexes] general_res = [x for x in general_names if x[1] > general_thresh] general_res = dict(general_res) ratings = "rating:" + ratings_names[0][0] if ratings_names[0][0] == "general": ratings = "rating:safe" general_res[ratings] = ratings_names[0][1] general_res = sorted( general_res.items(), key=lambda x: x[1], reverse=True, ) return dict(general_res) predictor = Predictor() def genTag(image: PIL.Image.Image, score_threshold: float): return predictor.predict(image, score_threshold)