#!/usr/bin/env python
"""Tag a single image with a SmilingWolf WD v3 tagger model from the Hugging Face Hub."""
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Optional, Dict

import numpy as np
import pandas as pd
import timm
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from PIL import Image
from timm.data import create_transform, resolve_data_config
from torch import Tensor, nn
from torch.nn import functional as F

# Short model names accepted by get_image_tags -> Hugging Face Hub repo ids.
MODEL_REPO_MAP = {
    "vit": "SmilingWolf/wd-vit-tagger-v3",
    "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
    "convnext": "SmilingWolf/wd-convnext-tagger-v3",
}


@dataclass
class LabelData:
    """Tag vocabulary of a tagger model, as read from its selected_tags.csv."""

    # Tag names, in the same order as the model's output logits.
    names: list[str]
    # Row indices (into `names`) of tags whose CSV category is 9 (rating tags).
    rating: list[np.int64]
    # Row indices of tags whose CSV category is 0 (general tags).
    general: list[np.int64]
    # Row indices of tags whose CSV category is 4 (character tags).
    character: list[np.int64]


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return an RGB version of `image`, compositing any transparency onto white.

    Images that carry alpha information are first converted to RGBA. That covers
    an explicit alpha band (e.g. "LA", "PA") or a "transparency" entry in
    `image.info`. They are then alpha-composited over an opaque white canvas, so
    transparent regions become white rather than whatever colour values happen
    to sit under the alpha.
    """
    if image.mode not in ["RGB", "RGBA"]:
        # Previously only "transparency" was checked, so modes with a real alpha
        # band (LA/PA) were flattened to RGB and lost their transparency.
        has_alpha = "A" in image.getbands() or "transparency" in image.info
        image = image.convert("RGBA") if has_alpha else image.convert("RGB")
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def pil_pad_square(image: Image.Image) -> Image.Image:
    """Center `image` on a white RGB square whose side is max(width, height)."""
    w, h = image.size
    px = max(image.size)
    canvas = Image.new("RGB", (px, px), (255, 255, 255))
    canvas.paste(image, ((px - w) // 2, (px - h) // 2))
    return canvas


def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download `selected_tags.csv` from `repo_id` and split tags by category.

    Args:
        repo_id: Hugging Face Hub repository id of the tagger model.
        revision: Optional git revision (branch, tag or commit) to download from.
        token: Optional Hugging Face access token.

    Returns:
        LabelData holding all tag names plus the row indices of the rating
        (category 9), general (category 0) and character (category 4) tags.

    Raises:
        FileNotFoundError: if the CSV download fails with an HTTP error.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
        )
        csv_path = Path(csv_path).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    tag_data = LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )
    return tag_data


def get_tags(probs: Tensor, labels: LabelData, gen_threshold: float, char_threshold: float) -> dict:
    """Turn per-tag probabilities into labelled, thresholded tag dictionaries.

    Args:
        probs: 1-D tensor with one probability per entry of `labels.names`.
            It may live on any device and may require grad.
        labels: Tag vocabulary matching the order of `probs`.
        gen_threshold: General tags are kept only if probability > this value.
        char_threshold: Character tags are kept only if probability > this value.

    Returns:
        Dict with keys:
            - 'caption': kept general tags then kept character tags, each group
              sorted by descending probability, joined with ", ".
            - 'rating_labels': every rating tag -> probability (no threshold).
            - 'character_labels': kept character tags -> probability, descending.
            - 'general_labels': kept general tags -> probability, descending.
    """
    # .numpy() requires a CPU tensor that does not track gradients.
    scored = list(zip(labels.names, probs.detach().cpu().numpy()))

    rating_labels = dict(scored[i] for i in labels.rating)

    def _select(indices: list, threshold: float) -> dict:
        """Keep (name, prob) pairs above `threshold`, ordered by descending prob."""
        kept = dict(scored[i] for i in indices if scored[i][1] > threshold)
        return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))

    gen_labels = _select(labels.general, gen_threshold)
    char_labels = _select(labels.character, char_threshold)

    caption = ", ".join([*gen_labels, *char_labels])

    return {
        'caption': caption,
        'rating_labels': rating_labels,
        'character_labels': char_labels,
        'general_labels': gen_labels,
    }


@lru_cache(maxsize=4)
def _load_tagger(repo_id: str, device: torch.device) -> tuple:
    """Build the model, labels and input transform for `repo_id` on `device`.

    Results are cached (bounded, keyed on repo id and device) so repeated calls
    to get_image_tags do not rebuild the network and reload its weights each
    time. Failed loads raise and are not cached.
    """
    model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
    state_dict = timm.models.load_state_dict_from_hf(repo_id)
    model.load_state_dict(state_dict)
    model.to(device)

    labels = load_labels_hf(repo_id=repo_id)
    transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
    return model, labels, transform


def get_image_tags(
    image_path: str | Path,
    model_name: str = "vit",
    gen_threshold: float = 0.35,
    char_threshold: float = 0.75,
    device: Optional[str] = None,
) -> Dict:
    """
    Process a single image and return its tags.

    Args:
        image_path: Path to the image file
        model_name: Model to use ('vit', 'swinv2', or 'convnext')
        gen_threshold: Threshold for general tags
        char_threshold: Threshold for character tags
        device: Device to use ('cuda', 'cpu', or None for auto-detection)

    Returns:
        Dictionary containing:
        - caption: Combined tags as a comma-separated string
        - rating_labels: Dictionary of rating labels and their confidence scores
        - character_labels: Dictionary of character tags and their confidence scores
        - general_labels: Dictionary of general tags and their confidence scores

    Raises:
        ValueError: if `model_name` is not a key of MODEL_REPO_MAP.
        FileNotFoundError: if the label CSV cannot be downloaded.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    repo_id = MODEL_REPO_MAP.get(model_name)
    if repo_id is None:
        raise ValueError(f"Model name '{model_name}' not recognized. Available models: {list(MODEL_REPO_MAP.keys())}")

    # Load model, labels and transform (cached across calls)
    model, labels, transform = _load_tagger(repo_id, device)

    # Process image; the context manager closes the underlying file handle.
    # Both helpers return an image independent of the file before it is closed
    # (pil_pad_square always pastes into a new canvas).
    with Image.open(image_path) as opened:
        img_input: Image.Image = pil_pad_square(pil_ensure_rgb(opened))

    inputs: Tensor = transform(img_input).unsqueeze(0)
    inputs = inputs[:, [2, 1, 0]]  # RGB to BGR (channel order the tagger expects)

    with torch.inference_mode():
        inputs = inputs.to(device)
        # Call the module (not .forward) so registered hooks run.
        outputs = model(inputs)
        outputs = F.sigmoid(outputs)
        outputs = outputs.to("cpu")

    return get_tags(
        probs=outputs.squeeze(0),
        labels=labels,
        gen_threshold=gen_threshold,
        char_threshold=char_threshold,
    )


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python wdv3_single.py <image_path>", file=sys.stderr)
        sys.exit(1)

    result = get_image_tags(sys.argv[1])
    print(result['caption'])