{ "cells": [ { "cell_type": "code", "execution_count": 41, "id": "761173b3", "metadata": {}, "outputs": [], "source": [ "#import libraries\n", "\n", "import argparse\n", "import json\n", "import logging\n", "from dataclasses import asdict, dataclass\n", "from os import PathLike, getenv\n", "from pathlib import Path\n", "from typing import Any, Dict, List, Optional, Tuple\n", "\n", "import numpy as np\n", "import onnxruntime as rt\n", "from huggingface_hub import snapshot_download\n", "from pandas import read_csv\n", "from PIL import Image\n", "from torch.utils.data import DataLoader, Dataset\n", "from tqdm.auto import tqdm\n", "import csv \n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 42, "id": "ffc17b6b", "metadata": {}, "outputs": [], "source": [ "# allowed extensions\n", "IMAGE_EXTENSIONS = [\".jpg\", \".jpeg\", \".png\", \".gif\", \".webp\", \".bmp\", \".tiff\", \".tif\"]" ] }, { "cell_type": "code", "execution_count": 43, "id": "d289afd3", "metadata": {}, "outputs": [], "source": [ "# model input shape\n", "IMAGE_SIZE = 448" ] }, { "cell_type": "code", "execution_count": 44, "id": "465fcb74", "metadata": {}, "outputs": [], "source": [ "# hf hub insists on putting things in the cache dir then hardlinking and unlinking\n", "# which breaks across mount points, so we override it here unless an explicit path is given in args\n", "HF_HOME = getenv(\"HF_HOME\", Path.cwd().joinpath(\".cache\"))\n", "CACHE_DIR = HF_HOME.joinpath(\"huggingface_hub\")" ] }, { "cell_type": "code", "execution_count": 45, "id": "4734bb95", "metadata": {}, "outputs": [], "source": [ "class DictJsonMixin:\n", " def asdict(self, *args, **kwargs) -> Dict[str, Any]:\n", " return asdict(self, *args, **kwargs)\n", "\n", " def asjson(self, *args, **kwargs):\n", " return json.dumps(asdict(self, *args, **kwargs))" ] }, { "cell_type": "code", "execution_count": 46, "id": "fb1d91b2", "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class 
LabelData(DictJsonMixin):\n", " \"\"\"\n", " A class that represents label data.\n", " \"\"\"\n", " names: List[str]\n", " rating: List[np.int64]\n", " general: List[np.int64]\n", " character: List[np.int64]" ] }, { "cell_type": "code", "execution_count": 47, "id": "3b7bd9e9", "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class ImageLabels(DictJsonMixin):\n", " \"\"\"\n", " A class that represents image labels.\n", " \"\"\"\n", " caption: str\n", " booru: str\n", " rating: str\n", " general: Dict[str, float]\n", " character: Dict[str, float]\n", " ratings: Dict[str, float]" ] }, { "cell_type": "code", "execution_count": 48, "id": "1a40cc56", "metadata": {}, "outputs": [], "source": [ "logging.basicConfig(level=logging.INFO)\n", "\n", "logger = logging.getLogger(__name__)\n", "logger.setLevel(logging.INFO)" ] }, { "cell_type": "code", "execution_count": 49, "id": "59cbbf48", "metadata": {}, "outputs": [], "source": [ "def get_model_repo(base_model: str = \"convnextv2\") -> str:\n", " return f\"SmilingWolf/wd-v1-4-{base_model}-tagger-v2\"" ] }, { "cell_type": "code", "execution_count": 50, "id": "5518fa1b", "metadata": {}, "outputs": [], "source": [ "def collate_fn_remove_corrupted(batch):\n", " \"\"\"Collate function that allows to remove corrupted examples in the\n", " dataloader. 
It expects that the dataloader returns 'None' when that occurs.\n", " The 'None's in the batch are removed.\n", " \"\"\"\n", " # Filter out all the Nones (corrupted examples)\n", " return [x for x in batch if x is not None]" ] }, { "cell_type": "code", "execution_count": 51, "id": "8481068c", "metadata": {}, "outputs": [], "source": [ "def load_labels(model_path: Path) -> LabelData:\n", " path = model_path.joinpath(\"selected_tags.csv\")\n", " df = read_csv(path)\n", "\n", " tag_data = LabelData(\n", " names=df[\"name\"].tolist(),\n", " rating=list(np.where(df[\"category\"] == 9)[0]),\n", " general=list(np.where(df[\"category\"] == 0)[0]),\n", " character=list(np.where(df[\"category\"] == 4)[0]),\n", " )\n", "\n", " return tag_data" ] }, { "cell_type": "code", "execution_count": 52, "id": "593bf897", "metadata": {}, "outputs": [], "source": [ "def preprocess_image(image: Image.Image, size_px: int = IMAGE_SIZE, upscale=True) -> Image.Image:\n", " \"\"\"\n", " Preprocess an image to be square and centered on a white background.\n", " \"\"\"\n", " # make tuple for PIL\n", " size = (size_px, size_px)\n", "\n", " # scale up or down (maintaining aspect ratio) as needed\n", " if image.width > size_px or image.height > size_px:\n", " image.thumbnail(size, Image.Resampling.LANCZOS)\n", " elif upscale is True:\n", " ratio = size_px / max(image.width, image.height)\n", " scale_to = (int(image.width * ratio), int(image.height * ratio))\n", " image = image.resize(scale_to, Image.LANCZOS)\n", "\n", " # work out where to paste the image to make it square\n", " delta_h = (size_px - image.height) // 2\n", " delta_w = (size_px - image.width) // 2\n", "\n", " # paste image onto square white canvas, centered\n", " image = image.convert(\"RGBA\")\n", " canvas = Image.new(\"RGBA\", size, (255, 255, 255))\n", " canvas.paste(image, box=(delta_w, delta_h), mask=image)\n", "\n", " # convert to 24-bit BGR for OpenCV and return\n", " canvas = canvas.convert(\"RGB\").convert(\"BGR;24\")\n", " 
return canvas" ] }, { "cell_type": "code", "execution_count": 53, "id": "6fd3593f", "metadata": {}, "outputs": [], "source": [ "class ImageDataset(Dataset):\n", " def __init__(self, image_paths: List[Path], size_px: int = IMAGE_SIZE, upscale: bool = True):\n", " self.size_px = size_px\n", " self.upscale = upscale\n", " self.images = [p for p in image_paths if p.suffix.lower() in IMAGE_EXTENSIONS]\n", "\n", " def __len__(self):\n", " return len(self.images)\n", "\n", " def __getitem__(self, idx):\n", " image_path: Path = self.images[idx]\n", " try:\n", " image = Image.open(image_path)\n", " image = preprocess_image(image, self.size_px, self.upscale)\n", " image = np.asarray(image)\n", " image = image.astype(np.float32)\n", " image = np.expand_dims(image, axis=0)\n", " except Exception as e:\n", " logging.exception(f\"Could not load image from {image_path}, error: {e}\")\n", " return None\n", " return image, image_path" ] }, { "cell_type": "code", "execution_count": 54, "id": "be8dc9e4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['1girl',\n", " 'long_hair',\n", " 'breasts',\n", " 'blush',\n", " 'smile',\n", " 'short_hair',\n", " 'open_mouth',\n", " 'bangs',\n", " 'blue_eyes',\n", " 'skirt',\n", " 'blonde_hair',\n", " 'large_breasts',\n", " 'brown_hair',\n", " 'shirt',\n", " 'black_hair',\n", " 'hair_ornament',\n", " 'red_eyes',\n", " 'thighhighs',\n", " 'gloves',\n", " 'long_sleeves',\n", " '1boy',\n", " 'hat',\n", " 'dress',\n", " 'bow',\n", " 'ribbon',\n", " 'navel',\n", " 'holding',\n", " '2girls',\n", " 'animal_ears',\n", " 'cleavage',\n", " 'hair_between_eyes',\n", " 'bare_shoulders',\n", " 'twintails',\n", " 'brown_eyes',\n", " 'medium_breasts',\n", " 'sitting',\n", " 'very_long_hair',\n", " 'closed_mouth',\n", " 'underwear',\n", " 'nipples',\n", " 'school_uniform',\n", " 'green_eyes',\n", " 'blue_hair',\n", " 'standing',\n", " 'purple_eyes',\n", " 'collarbone',\n", " 'panties',\n", " 'jacket',\n", " 'tail',\n", " 'swimsuit',\n", " 
'hair_ribbon',\n", " 'yellow_eyes',\n", " 'white_shirt',\n", " 'ponytail',\n", " 'weapon',\n", " 'pink_hair',\n", " 'purple_hair',\n", " 'ass',\n", " 'braid',\n", " 'flower',\n", " 'ahoge',\n", " 'white_hair',\n", " 'short_sleeves',\n", " ':d',\n", " 'hetero',\n", " 'hair_bow',\n", " 'grey_hair',\n", " 'male_focus',\n", " 'heart',\n", " 'pantyhose',\n", " 'sidelocks',\n", " 'bikini',\n", " 'thighs',\n", " 'red_hair',\n", " 'multicolored_hair',\n", " 'cowboy_shot',\n", " 'sweat',\n", " 'pleated_skirt',\n", " 'hairband',\n", " 'earrings',\n", " 'small_breasts',\n", " 'boots',\n", " 'lying',\n", " 'frills',\n", " 'parted_lips',\n", " 'detached_sleeves',\n", " 'one_eye_closed',\n", " 'japanese_clothes',\n", " 'green_hair',\n", " 'multiple_boys',\n", " 'open_clothes',\n", " 'wings',\n", " 'necktie',\n", " 'horns',\n", " 'sky',\n", " 'penis',\n", " 'shoes',\n", " 'glasses',\n", " 'shorts',\n", " 'barefoot',\n", " 'teeth',\n", " 'pussy',\n", " 'serafuku',\n", " 'sleeveless',\n", " 'alternate_costume',\n", " 'choker',\n", " 'tongue',\n", " 'pointy_ears',\n", " 'black_gloves',\n", " 'socks',\n", " 'hairclip',\n", " 'elbow_gloves',\n", " 'fang',\n", " 'midriff',\n", " 'striped',\n", " 'puffy_sleeves',\n", " 'collared_shirt',\n", " 'belt',\n", " 'pants',\n", " 'sword',\n", " 'black_thighhighs',\n", " 'virtual_youtuber',\n", " 'cat_ears',\n", " 'tears',\n", " 'white_gloves',\n", " 'hand_up',\n", " 'hair_flower',\n", " '3girls',\n", " 'spread_legs',\n", " 'cum',\n", " 'hood',\n", " '2boys',\n", " 'sex',\n", " 'tongue_out',\n", " 'miniskirt',\n", " 'wide_sleeves',\n", " 'blunt_bangs',\n", " 'on_back',\n", " 'fingerless_gloves',\n", " 'bowtie',\n", " 'black_skirt',\n", " 'medium_hair',\n", " 'pink_eyes',\n", " 'armpits',\n", " 'sailor_collar',\n", " 'kimono',\n", " 'grey_background',\n", " 'necklace',\n", " 'off_shoulder',\n", " 'stomach',\n", " 'bag',\n", " 'hair_bun',\n", " 'clothes_lift',\n", " 'star_(symbol)',\n", " 'scarf',\n", " 'cape',\n", " 'nail_polish',\n", " 
'black_footwear',\n", " 'holding_weapon',\n", " 'bra',\n", " 'white_dress',\n", " 'orange_hair',\n", " 'yuri',\n", " 'sweatdrop',\n", " 'armor',\n", " 'rabbit_ears',\n", " 'mole',\n", " 'white_panties',\n", " 'hair_over_one_eye',\n", " 'grin',\n", " 'huge_breasts',\n", " 'looking_at_another',\n", " ':o',\n", " 'uniform',\n", " 'black_eyes',\n", " 'apron',\n", " 'character_name',\n", " 'vest',\n", " 'black_dress',\n", " 'arm_up',\n", " 'vaginal',\n", " 'red_bow',\n", " 'high_heels',\n", " 'twin_braids',\n", " 'arms_up',\n", " 'flat_chest',\n", " 'side_ponytail',\n", " 'collar',\n", " 'bracelet',\n", " 'feet',\n", " 'covered_nipples',\n", " 'two-tone_hair',\n", " 'aqua_eyes',\n", " 'sweater',\n", " 'speech_bubble',\n", " 'white_thighhighs',\n", " 'leotard',\n", " 'open_jacket',\n", " 'official_alternate_costume',\n", " 'red_ribbon',\n", " 'tree',\n", " 'cup',\n", " 'puffy_short_sleeves',\n", " 'lips',\n", " 'blue_skirt',\n", " 'zettai_ryouiki',\n", " 'streaked_hair',\n", " 'coat',\n", " 'black_jacket',\n", " 'crop_top',\n", " 'groin',\n", " 'fingernails',\n", " 'v-shaped_eyebrows',\n", " 'cat_tail',\n", " 'neckerchief',\n", " 'orange_eyes',\n", " 'animal_ear_fluff',\n", " 'head_tilt',\n", " 'see-through',\n", " 'hand_on_hip',\n", " 'gun',\n", " 'legs',\n", " 'one-piece_swimsuit',\n", " 'sleeves_past_wrists',\n", " 'parted_bangs',\n", " 'wrist_cuffs',\n", " 'grey_eyes',\n", " 'torn_clothes',\n", " 'plaid',\n", " 'black_pantyhose',\n", " 'maid',\n", " 'symbol-shaped_pupils',\n", " 'hands_up',\n", " 'sash',\n", " 'fur_trim',\n", " 'kneehighs',\n", " 'maid_headdress',\n", " 'black_panties',\n", " 'cosplay',\n", " 'bare_arms',\n", " 'petals',\n", " 'pubic_hair',\n", " 'black_shirt',\n", " 'fox_ears',\n", " 'loli',\n", " 'short_shorts',\n", " 'ascot',\n", " 'clothing_cutout',\n", " 'completely_nude',\n", " 'dutch_angle',\n", " 'eyelashes',\n", " 'bar_censor',\n", " 'mole_under_eye',\n", " 'pokemon_(creature)',\n", " 'no_humans',\n", " 'bare_legs',\n", " 'window',\n", " 
'open_shirt',\n", " 'sparkle',\n", " 'dress_shirt',\n", " 'kneeling',\n", " 'sleeveless_shirt',\n", " 'single_braid',\n", " 'v',\n", " 'black_headwear',\n", " 'strapless',\n", " '4girls',\n", " 'bell',\n", " 'hug',\n", " 'no_bra',\n", " 'saliva',\n", " 'double_bun',\n", " 'black_ribbon',\n", " 'uncensored',\n", " 'aqua_hair',\n", " 'bodysuit',\n", " 'blood',\n", " 'bed',\n", " 'hoodie',\n", " 'military_uniform',\n", " 'sideboob',\n", " 'black_bow',\n", " 'covered_navel',\n", " 'tattoo',\n", " 'gradient_hair',\n", " 'skindentation',\n", " 'neck_ribbon',\n", " 'pussy_juice',\n", " 'profile',\n", " 'makeup',\n", " 'thigh_strap',\n", " 'leaning_forward',\n", " 'multiple_views',\n", " '4koma',\n", " 'capelet',\n", " 'mask',\n", " 'muscular',\n", " 'anus',\n", " 'no_panties',\n", " 'witch_hat',\n", " 'detached_collar',\n", " 'toes',\n", " ':3',\n", " 'copyright_name',\n", " 'alternate_hairstyle',\n", " 'underboob',\n", " 'night',\n", " 'buttons',\n", " 'floating_hair',\n", " 'fruit',\n", " 'sleeveless_dress',\n", " 'depth_of_field',\n", " 'feet_out_of_frame',\n", " 'headband',\n", " 'fake_animal_ears',\n", " '^_^',\n", " 'blue_dress',\n", " 'cameltoe',\n", " 'cum_in_pussy',\n", " 'fox_tail',\n", " 'swept_bangs',\n", " 'shadow',\n", " 'black_bikini',\n", " 'red_skirt',\n", " 'nose_blush',\n", " 'bottomless',\n", " 'glowing',\n", " 'side-tie_bikini_bottom',\n", " 'rose',\n", " 'bed_sheet',\n", " 'colored_skin',\n", " 'turtleneck',\n", " 'holding_hands',\n", " 'facial_hair',\n", " 'chain',\n", " 'headgear',\n", " 'bird',\n", " 'pov',\n", " 'siblings',\n", " 'headphones',\n", " 'ocean',\n", " '6+girls',\n", " 'low_twintails',\n", " 'heterochromia',\n", " 'arm_support',\n", " 'animal',\n", " 'halterneck',\n", " 'frown',\n", " 'leaf',\n", " 'beret',\n", " 'white_headwear',\n", " 'umbrella',\n", " 'on_bed',\n", " 'one_side_up',\n", " 'embarrassed',\n", " 'thigh_boots',\n", " 'fangs',\n", " 'upper_teeth_only',\n", " 'watermark',\n", " 'from_above',\n", " 'back',\n", " 
'highleg',\n", " 'blue_background',\n", " 'ass_visible_through_thighs',\n", " 'wavy_hair',\n", " 'garter_straps',\n", " 'black_choker',\n", " 'halo',\n", " 'blue_bow',\n", " 'scar',\n", " 'white_bikini',\n", " 'on_side',\n", " 'plaid_skirt',\n", " 'chair',\n", " 'transparent_background',\n", " 'wariza',\n", " 'facial_mark',\n", " 'mouth_hold',\n", " 'looking_away',\n", " 'traditional_media',\n", " 'beach',\n", " 'bandages',\n", " 'parody',\n", " 'female_pubic_hair',\n", " 'expressionless',\n", " 'brown_footwear',\n", " 'blush_stickers',\n", " 'shirt_lift',\n", " 'thick_thighs',\n", " 'no_shoes',\n", " 'holding_sword',\n", " 'hair_tubes',\n", " 'chinese_clothes',\n", " 'drill_hair',\n", " 'grabbing',\n", " 'arms_behind_back',\n", " 'soles',\n", " 'obi',\n", " 'heart-shaped_pupils',\n", " 'eating',\n", " 'clothes_pull',\n", " 'looking_down',\n", " 'phone',\n", " 'black_shorts',\n", " 'thigh_gap',\n", " 'black_pants',\n", " 'short_dress',\n", " 'topless',\n", " 'piercing',\n", " 'pantyshot',\n", " 'hair_intakes',\n", " 'eyepatch',\n", " 'border',\n", " 'skirt_lift',\n", " 'floral_print',\n", " 'stuffed_toy',\n", " 'bound',\n", " 'formal',\n", " 'playboy_bunny',\n", " 'flying_sweatdrops',\n", " 'crossed_arms',\n", " 'wavy_mouth',\n", " 'magical_girl',\n", " 'erection',\n", " 'abs',\n", " 'moon',\n", " 'half-closed_eyes',\n", " 'leg_up',\n", " 'from_below',\n", " 'red_dress',\n", " 'cleavage_cutout',\n", " 'sandals',\n", " 'table',\n", " 'happy',\n", " 'sunlight',\n", " 'oral',\n", " 'cover',\n", " 'squatting',\n", " 'single_hair_bun',\n", " 'cat',\n", " 'testicles',\n", " 'pink_background',\n", " 'sunglasses',\n", " 'scrunchie',\n", " 'white_footwear',\n", " 'dark-skinned_male',\n", " 'underwear_only',\n", " 'cum_on_body',\n", " 'trembling',\n", " 'bob_cut',\n", " 'ring',\n", " 'bdsm',\n", " 'school_swimsuit',\n", " 'mob_cap',\n", " 'wolf_ears',\n", " 'blazer',\n", " 'light_brown_hair',\n", " 'white_jacket',\n", " 'standing_on_one_leg',\n", " 'sleeping',\n", " 
'thick_eyebrows',\n", " 'backpack',\n", " 'white_skirt',\n", " 'demon_girl',\n", " 'frilled_dress',\n", " 'eyes_visible_through_hair',\n", " 'breast_grab',\n", " 'cardigan',\n", " 'knee_boots',\n", " 'suspenders',\n", " 'hat_ribbon',\n", " 'crossed_legs',\n", " 'lingerie',\n", " 'stuffed_animal',\n", " 'katana',\n", " 'hood_down',\n", " ';d',\n", " '3boys',\n", " 'bat_wings',\n", " 'horse_ears',\n", " 'helmet',\n", " 'cloudy_sky',\n", " 'cellphone',\n", " 'crying',\n", " 'antenna_hair',\n", " 'own_hands_together',\n", " 'tank_top',\n", " 'bottle',\n", " 'suit',\n", " 'grass',\n", " 'outstretched_arms',\n", " 'cross',\n", " 'bug',\n", " 'holding_food',\n", " 'fire',\n", " 'frilled_skirt',\n", " 'tiara',\n", " 'aged_down',\n", " 'polka_dot',\n", " 'feathers',\n", " 'breasts_out',\n", " 'crossover',\n", " 'crown',\n", " 'high_ponytail',\n", " 'looking_up',\n", " 'black_hairband',\n", " 'bent_over',\n", " 'undressing',\n", " 'blue_shirt',\n", " 'white_bow',\n", " '5girls',\n", " 'straddling',\n", " 'light_smile',\n", " 'knife',\n", " 'pectorals',\n", " 'x_hair_ornament',\n", " 'plant',\n", " 'couple',\n", " 'denim',\n", " 'on_stomach',\n", " 'wing_collar',\n", " '>_<',\n", " 'robot',\n", " 'white_flower',\n", " 'hair_bobbles',\n", " 'fellatio',\n", " 'outstretched_arm',\n", " 'sharp_teeth',\n", " 'blue_ribbon',\n", " 'lipstick',\n", " 'tan',\n", " 'girl_on_top',\n", " 'cat_girl',\n", " 'short_twintails',\n", " 'lifted_by_self',\n", " 'bondage',\n", " 'curtains',\n", " 'white_socks',\n", " 'letterboxed',\n", " 'animal_print',\n", " 'muscular_male',\n", " 'spiked_hair',\n", " 'pointing',\n", " 'pink_bow',\n", " 'juliet_sleeves',\n", " 'monster_girl',\n", " 'sex_from_behind',\n", " 'slit_pupils',\n", " 'polearm',\n", " 'all_fours',\n", " 'blue_jacket',\n", " 'sisters',\n", " '^^^',\n", " 'frilled_sleeves',\n", " 'hand_on_own_chest',\n", " 'red_necktie',\n", " 'blue_sailor_collar',\n", " 'crescent',\n", " '?',\n", " 'staff',\n", " 'black_background',\n", " 
'clenched_teeth',\n", " 'panty_pull',\n", " 'cherry_blossoms',\n", " 'head_wings',\n", " 'horse_girl',\n", " 'brooch',\n", " 'goggles',\n", " 'demon_horns',\n", " 'towel',\n", " 'blouse',\n", " 'shaded_face',\n", " 'red_flower',\n", " 'green_skirt',\n", " 'fox_girl',\n", " 'ground_vehicle',\n", " 'cover_page',\n", " 'black_bra',\n", " 'elf',\n", " 'bike_shorts',\n", " 'otoko_no_ko',\n", " 'wind',\n", " 'casual',\n", " 'black_socks',\n", " 'loafers',\n", " 't-shirt',\n", " 'motion_lines',\n", " 'shoulder_armor',\n", " 'gauntlets',\n", " 'no_pants',\n", " 'building',\n", " 'pink_panties',\n", " 'messy_hair',\n", " 'single_thighhigh',\n", " 'multiple_tails',\n", " 'kiss',\n", " 'wristband',\n", " 'group_sex',\n", " 'breast_press',\n", " 'between_breasts',\n", " 'surprised',\n", " 'striped_panties',\n", " 'hat_bow',\n", " 'gem',\n", " 'butterfly',\n", " 'red_footwear',\n", " 'red_shirt',\n", " 'sheath',\n", " 'sneakers',\n", " 'rabbit_tail',\n", " 'tassel',\n", " 'instrument',\n", " 'box',\n", " 'ear_piercing',\n", " 'drooling',\n", " 'fishnets',\n", " 'ribbon_trim',\n", " 'clenched_hand',\n", " 'sex_toy',\n", " 'red_bowtie',\n", " 'third_eye',\n", " 'skirt_set',\n", " 'child',\n", " 'hakama',\n", " 'pale_skin',\n", " 'portrait',\n", " 'musical_note',\n", " 'revealing_clothes',\n", " 'rope',\n", " 'star_(sky)',\n", " 'wet_clothes',\n", " 'steam',\n", " 'candy',\n", " 'pink_dress',\n", " 'genderswap',\n", " 'facial',\n", " 'demon_tail',\n", " 'dog_ears',\n", " 'anal',\n", " 'foreshortening',\n", " 'holding_gun',\n", " 'nature',\n", " 'covering',\n", " 'adapted_costume',\n", " 'side-tie_panties',\n", " 'black_nails',\n", " 'night_sky',\n", " 'christmas',\n", " 'breath',\n", " 'ejaculation',\n", " 'veil',\n", " 'scenery',\n", " 'armband',\n", " 'peaked_cap',\n", " 'waist_apron',\n", " 'lace_trim',\n", " 'convenient_censoring',\n", " 'white_apron',\n", " 'couch',\n", " 'arms_behind_head',\n", " 'china_dress',\n", " 'bandaid',\n", " 'holding_cup',\n", " 'black_leotard',\n", 
" 'male_pubic_hair',\n", " 'interlocked_fingers',\n", " 'mole_under_mouth',\n", " 'microphone',\n", " 'bridal_gauntlets',\n", " 'bara',\n", " 'strapless_dress',\n", " 'tokin_hat',\n", " 'yaoi',\n", " 'straight_hair',\n", " 'front-tie_top',\n", " 'bow_panties',\n", " 'lace',\n", " 'mecha',\n", " 'hakama_skirt',\n", " 'hand_fan',\n", " 'white_ribbon',\n", " 'glowing_eyes',\n", " 'anger_vein',\n", " '...',\n", " 'breasts_apart',\n", " 'no_headwear',\n", " 'hair_over_shoulder',\n", " 'clothes_writing',\n", " 'jingle_bell',\n", " 'baseball_cap',\n", " 'yellow_background',\n", " 'hair_flaps',\n", " 'string_bikini',\n", " 'feathered_wings',\n", " 'hooded_jacket',\n", " 'cum_on_breasts',\n", " 'bikini_top_only',\n", " 'red_headwear',\n", " 'twin_drills',\n", " 'facing_viewer',\n", " 'skin_tight',\n", " 'multiple_penises',\n", " 'semi-rimless_eyewear',\n", " 'red_nails',\n", " 'bright_pupils',\n", " 'black_necktie',\n", " 'web_address',\n", " ':<',\n", " 'angry',\n", " 'grey_shirt',\n", " 'cloak',\n", " 'eyewear_on_head',\n", " 'motor_vehicle',\n", " 'red_background',\n", " 'claws',\n", " 'side_braid',\n", " 'wolf_tail',\n", " 'pelvic_curtain',\n", " 'light_particles',\n", " 'light_purple_hair',\n", " 'multicolored_clothes',\n", " 'carrying',\n", " 'micro_bikini',\n", " 'knees_up',\n", " 'smartphone',\n", " 'corset',\n", " 'tentacles',\n", " 'index_finger_raised',\n", " 'clothing_aside',\n", " 'purple_dress',\n", " 'extra_ears',\n", " 'rifle',\n", " 'striped_thighhighs',\n", " 'white_border',\n", " 'mary_janes',\n", " 'beard',\n", " 'paizuri',\n", " 'vertical_stripes',\n", " 'red_jacket',\n", " ':p',\n", " 'red_neckerchief',\n", " 'short_hair_with_long_locks',\n", " 'scar_on_face',\n", " 'tareme',\n", " 'neck_bell',\n", " 'licking',\n", " 'furry',\n", " 'single_horn',\n", " 'strap_slip',\n", " 'finger_to_mouth',\n", " 'pom_pom_(clothes)',\n", " 'snow',\n", " 'french_braid',\n", " 'close-up',\n", " 'androgynous',\n", " '1other',\n", " 'areola_slip',\n", " 'forehead',\n", " 
'puffy_nipples',\n", " 'buckle',\n", " 'horse_tail',\n", " 'two-tone_background',\n", " 'full_moon',\n", " 'eye_contact',\n", " 'pink_flower',\n", " 'tsurime',\n", " 'yellow_bow',\n", " 'gift',\n", " 'seiza',\n", " 'upskirt',\n", " 'blue_bikini',\n", " 'pink_nails',\n", " 'santa_hat',\n", " 'genderswap_(mtf)',\n", " 'lens_flare',\n", " 'skin_fang',\n", " 'spikes',\n", " 'armlet',\n", " 'hand_on_own_face',\n", " 'desk',\n", " 'between_legs',\n", " 'brown_gloves',\n", " 'side_slit',\n", " 'handgun',\n", " 'camisole',\n", " 'wading',\n", " 'faceless',\n", " 'low_ponytail',\n", " 'restrained',\n", " 'pendant',\n", " 'plate',\n", " 'dual_persona',\n", " 'masturbation',\n", " 'highleg_leotard',\n", " 'spoken_heart',\n", " 'curvy',\n", " 'green_bow',\n", " 'maid_apron',\n", " 'alcohol',\n", " 'after_sex',\n", " 'grey_skirt',\n", " 'handjob',\n", " 'sleeves_rolled_up',\n", " 'red_gloves',\n", " 'o-ring',\n", " 'heavy_breathing',\n", " 'abyssal_ship',\n", " 'eyeshadow',\n", " 'ribbed_sweater',\n", " 'drinking_glass',\n", " 'hair_scrunchie',\n", " 'cowgirl_position',\n", " 'cross-laced_footwear',\n", " 'blue_headwear',\n", " 'broom',\n", " 'ball',\n", " 'puffy_long_sleeves',\n", " 'sleeves_past_fingers',\n", " 'clenched_hands',\n", " 'hood_up',\n", " 'cropped_legs',\n", " 'floating',\n", " 'wide_hips',\n", " 'forest',\n", " 'low-tied_long_hair',\n", " 'breast_hold',\n", " 'smoke',\n", " 'zipper',\n", " 'dress_lift',\n", " 'tray',\n", " 'personification',\n", " 'headwear_removed',\n", " 'high_heel_boots',\n", " 'partially_submerged',\n", " 'headset',\n", " 'halloween',\n", " 'hair_rings',\n", " 'legs_up',\n", " 'half_updo',\n", " 'doujin_cover',\n", " 'pink_skirt',\n", " 'starry_sky',\n", " 'colored_sclera',\n", " 'pencil_skirt',\n", " 'strapless_leotard',\n", " 'single_glove',\n", " 'machinery',\n", " 'clothed_sex',\n", " 'blue_nails',\n", " 'backlighting',\n", " 'freckles',\n", " 'tearing_up',\n", " 'reflection',\n", " 'tanlines',\n", " 'fish',\n", " 'sweater_vest',\n", " 
'holding_book',\n", " 'arm_behind_back',\n", " 'arm_at_side',\n", " 'santa_costume',\n", " 'large_pectorals',\n", " 'spot_color',\n", " 'flying',\n", " 'white_bra',\n", " 'asymmetrical_legwear',\n", " 'brown_background',\n", " 'panties_under_pantyhose',\n", " 'nontraditional_miko',\n", " 'red_bikini',\n", " 'happy_birthday',\n", " 'cropped_jacket',\n", " 'long_fingernails',\n", " '!',\n", " 'kemonomimi_mode',\n", " 'sailor_dress',\n", " 'clothed_female_nude_male',\n", " 'walking',\n", " 'fingering',\n", " 'science_fiction',\n", " 'rain',\n", " 'white_pantyhose',\n", " 'garter_belt',\n", " 'frilled_bikini',\n", " 'dual_wielding',\n", " '6+boys',\n", " 'pink_ribbon',\n", " 'cuffs',\n", " 'red-framed_eyewear',\n", " 'dragon_horns',\n", " 'epaulettes',\n", " 'black_wings',\n", " 'bubble',\n", " 'demon_wings',\n", " 'thong',\n", " 'legs_apart',\n", " 'teacup',\n", " 'condom',\n", " 'veins',\n", " 'crossdressing',\n", " 'ribbon-trimmed_sleeves',\n", " 'holding_phone',\n", " 'gym_uniform',\n", " 'short_ponytail',\n", " 'arm_behind_head',\n", " 'cake',\n", " 'out_of_frame',\n", " 'innertube',\n", " 'oni_horns',\n", " 'contrapposto',\n", " 'naughty_face',\n", " 'green_background',\n", " 'alternate_breast_size',\n", " 'purple_background',\n", " 'black-framed_eyewear',\n", " 'rape',\n", " 'beads',\n", " 'knee_up',\n", " 'hat_ornament',\n", " 'one-hour_drawing_challenge',\n", " 'fur_collar',\n", " 'blue_shorts',\n", " 'outside_border',\n", " 'thighband_pantyhose',\n", " 'meme',\n", " 'bowl',\n", " 'toenails',\n", " 'cumdrip',\n", " 'blue_flower',\n", " 'denim_shorts',\n", " 'curly_hair',\n", " 'track_jacket',\n", " 'black_sailor_collar',\n", " 'light_blush',\n", " 'school_bag',\n", " 'pocket',\n", " 'spread_pussy',\n", " 'toned',\n", " 'pink_shirt',\n", " 'doggystyle',\n", " 'white_sleeves',\n", " ':q',\n", " 'hand_in_own_hair',\n", " 'spoken_ellipsis',\n", " 'empty_eyes',\n", " 'purple_skirt',\n", " 'crying_with_eyes_open',\n", " 'goggles_on_head',\n", " 'green_dress',\n", " 
'4boys',\n", " 'bulge',\n", " 'sun_hat',\n", " 'cum_in_mouth',\n", " 'lolita_fashion',\n", " 'shiny_clothes',\n", " 'pauldrons',\n", " 'outline',\n", " 'buruma',\n", " \"hand_on_another's_head\",\n", " 'futanari',\n", " 'topless_male',\n", " 'under-rim_eyewear',\n", " 'frilled_apron',\n", " 'white_pupils',\n", " 'skull',\n", " 'jitome',\n", " 'gold_trim',\n", " 'long_legs',\n", " 'sunset',\n", " 'monster',\n", " 'frilled_shirt_collar',\n", " 'emphasis_lines',\n", " 'hands_on_hips',\n", " 'high-waist_skirt',\n", " 'new_year',\n", " 'shield',\n", " 'aged_up',\n", " 'animal_hands',\n", " 'mole_on_breast',\n", " 'spear',\n", " 'asymmetrical_hair',\n", " 'female_masturbation',\n", " 'v_arms',\n", " 'single_earring',\n", " 'running',\n", " 'dog',\n", " 'angel_wings',\n", " 'long_skirt',\n", " 'breasts_squeezed_together',\n", " 'competition_swimsuit',\n", " 'watch',\n", " 'dog_tail',\n", " 'black_belt',\n", " 'black_serafuku',\n", " 'faceless_male',\n", " 'legs_together',\n", " 'ice',\n", " 'white_skin',\n", " 'blue_footwear',\n", " 'o_o',\n", " '#ERROR!',\n", " ...]" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exclude = pd.read_csv('exclude_tags.csv')\n", "exclude\n", "undesired_tags_list = exclude[exclude['exclude'] == 1]['name'].tolist()\n", "undesired_tags_list" ] }, { "cell_type": "code", "execution_count": 55, "id": "8ed9d1f3", "metadata": {}, "outputs": [], "source": [ "class ImageLabeler:\n", " def __init__(\n", " self,\n", " model_path: Optional[PathLike] = None,\n", " general_threshold: float = 0.35,\n", " character_threshold: float = 0.35,\n", " undesired_tags: Optional[List[str]] = None,\n", " ):\n", " # save model path if provided\n", " self._model_path = Path(model_path) if model_path is not None else None\n", "\n", " # create some object attributes for convenience\n", " self.general_threshold = general_threshold\n", " self.character_threshold = character_threshold\n", " self.undesired_tags = undesired_tags 
or []\n", "\n", " # actually load the model\n", " logging.info(f\"Loading model from path: {self._model_path}\")\n", " self.model = rt.InferenceSession(\n", " str(model_path.joinpath(\"model.onnx\")),\n", " providers=[(\"CUDAExecutionProvider\", {}), \"CPUExecutionProvider\"],\n", " )\n", "\n", " # Get input dimensions\n", " _, self.height, self.width, _ = self.model.get_inputs()[0].shape\n", " logging.info(f\"Model loaded, input dimensions {self.height}x{self.width}\")\n", "\n", " # load labels\n", " self.labels = load_labels(self._model_path)\n", " self.labels.general = [i for i in self.labels.general if i not in undesired_tags]\n", " self.labels.character = [i for i in self.labels.character if i not in undesired_tags]\n", " logging.info(f\"Loaded labels from {self._model_path.joinpath('selected_tags.csv')}\")\n", "\n", " @property\n", " def input_size(self) -> Tuple[int, int]:\n", " return (self.height, self.width)\n", "\n", " @property\n", " def input_name(self) -> str:\n", " return self.model.get_inputs()[0].name if self.model is not None else None\n", "\n", " @property\n", " def output_name(self) -> str:\n", " return self.model.get_outputs()[0].name if self.model is not None else None\n", "\n", " def label_image(self, images: np.ndarray) -> ImageLabels:\n", " # Run the ONNX model\n", " probs = [self.model.run([self.output_name], {self.input_name: x}) for x in images]\n", " # Convert to labels\n", " results = []\n", " for prob in list(probs):\n", " labels = list(zip(self.labels.names, prob[0][0].astype(float)))\n", "\n", " # First 4 labels are actually ratings: pick one with argmax\n", " rating_labels = dict([labels[i] for i in self.labels.rating])\n", " rating = max(rating_labels, key=rating_labels.get)\n", "\n", " # General labels, pick any where prediction confidence > threshold\n", " gen_labels = [labels[i] for i in self.labels.general]\n", " gen_labels = dict([x for x in gen_labels if x[1] > self.general_threshold])\n", " gen_labels = 
dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))\n", "\n", " # Convert to a string suitable for use as a training caption\n", " caption = \", \".join([x for x in gen_labels])\n", "\n", " booru = caption.replace(\"_\", \" \").replace(\"(\", \"\\(\").replace(\")\", \"\\)\")\n", "\n", " # return output\n", " results.append(\n", " ImageLabels(\n", " caption=caption,\n", " booru=booru,\n", " rating=rating,\n", " general=gen_labels,\n", " character={}, # returning an empty dictionary for character labels\n", " ratings=rating_labels,\n", " )\n", " )\n", "\n", " return results\n", "\n", " def __call__(self, images: List[Image.Image]) -> ImageLabels:\n", " # if not a list, just label the image\n", " for x in images:\n", " yield self.label_image(x)" ] }, { "cell_type": "code", "execution_count": 56, "id": "b476a2f9", "metadata": {}, "outputs": [], "source": [ "def main(\n", " images_dir: str = \"/home/irakli/foxtagger/inputs\",\n", " base_model: str = \"convnextv2\",\n", " models_dir: str = \"/home/irakli/foxtagger/models\",\n", " force_download: bool = False,\n", " recursive: bool = True,\n", " undesired_tags: List[str] = undesired_tags_list,\n", " caption_extension: str = \".txt\",\n", " frequency_tags: bool = False,\n", " max_data_loader_n_workers: int = 4,\n", " remove_underscore: bool = True,\n", " thresh: float = 0.35,\n", " general_threshold: float = None,\n", " character_threshold: float = None,\n", " debug: bool = False,\n", "):\n", " base_model = base_model\n", " models_dir = Path(models_dir) if models_dir is not None else Path.cwd().joinpath(\"models\")\n", " images_dir = Path(images_dir)\n", " force_download = force_download or False\n", " # Specify the name of your model file\n", " model_filename = 'model.onnx'\n", "\n", " recursive = recursive\n", " undesired_tags = set(undesired_tags)\n", " caption_extension = str(caption_extension).lower()\n", " frequency_tags = frequency_tags\n", " max_data_loader_n_workers = 
max_data_loader_n_workers\n",
    "\n",
    "    remove_underscore = remove_underscore\n",
    "    # Compare against None instead of relying on truthiness so an explicit 0.0 threshold is honoured.\n",
    "    general_threshold = thresh if general_threshold is None else general_threshold\n",
    "    character_threshold = thresh if character_threshold is None else character_threshold\n",
    "    debug = debug\n",
    "\n",
    "    # undesired_tags is matched both against caption tags (spaces once remove_underscore is applied)\n",
    "    # and against raw model tag names (underscores), so accept both spellings of every entry.\n",
    "    undesired_forms = {t.replace(\"_\", \" \") for t in undesired_tags} | {t.replace(\" \", \"_\") for t in undesired_tags}\n",
    "\n",
    "    # turn base model into a repo id and model path\n",
    "    repo_id: str = get_model_repo(base_model)\n",
    "    model_dir = models_dir.joinpath(repo_id.split(\"/\")[-1])\n",
    "    model_path = model_dir / model_filename  # This is the path to the model file\n",
    "\n",
    "    # Download when the model file itself is missing (an empty or partial directory is not enough)\n",
    "    # or when force_download is set. The snapshot is written into model_dir, i.e. the same\n",
    "    # directory that is checked here and handed to the labeler below.\n",
    "    print(f\"Checking for {base_model}-based tagger in {model_dir}...\")\n",
    "    if force_download or not model_path.is_file():\n",
    "        print(f\"Downloading {base_model}-based tagger from '{repo_id}'\")\n",
    "        snapshot_download(\n",
    "            repo_id,\n",
    "            local_dir_use_symlinks=False,\n",
    "            local_dir=model_dir,\n",
    "            cache_dir=CACHE_DIR,\n",
    "            allow_patterns=[\"*.onnx\", \"*.csv\"],\n",
    "        )\n",
    "    else:\n",
    "        print(\"Found existing tagger model, skipping download.\")\n",
    "\n",
    "    # instantiate the dataset\n",
    "    print(f\"Loading images from {images_dir}...\", end=\" \")\n",
    "    candidates = images_dir.rglob(\"*.*\") if recursive else images_dir.glob(\"*.*\")\n",
    "    # sorted() makes the processing / CSV row order independent of filesystem listing order;\n",
    "    # is_file() drops directories whose names happen to end in an image extension.\n",
    "    image_paths = sorted(\n",
    "        p for p in candidates if p.suffix.lower() in IMAGE_EXTENSIONS and p.is_file()\n",
    "    )\n",
    "    print(f\"found {len(image_paths)} images to process.\")\n",
    "    dataset = ImageDataset(image_paths)\n",
    "\n",
    "    # Create the data loader\n",
    "    dataloader = DataLoader(\n",
    "        dataset,\n",
    "        batch_size=1,\n",
    "        shuffle=False,\n",
    "        num_workers=max_data_loader_n_workers,\n",
    "        collate_fn=collate_fn_remove_corrupted,\n",
    "        drop_last=False,\n",
    "    )\n",
    "\n",
    "    # Create the image labeler from the directory the model was checked for / downloaded into\n",
    "    labeler: ImageLabeler = ImageLabeler(\n",
    "        model_path=model_dir,\n",
    "        character_threshold=character_threshold,\n",
    "        general_threshold=general_threshold,\n",
    "        undesired_tags=undesired_tags,\n",
    "    )\n",
    "\n",
    "    # object to save tag frequencies\n",
    "    tag_freqs: Dict[str, int] = {}\n",
    "\n",
    "    # Specify the name of your CSV output file\n",
    "    csv_filename = 'output.csv'\n",
    "\n",
    "    # Explicit utf-8 (matching the caption files): tags are arbitrary text and the platform\n",
    "    # default encoding may not be able to represent them.\n",
    "    with open(csv_filename, 'w', newline='', encoding='utf-8') as csvfile:\n",
    "        fieldnames = ['filename', 'tags:probabilities']\n",
    "        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)\n",
    "\n",
    "        # Write the header\n",
    "        writer.writeheader()\n",
    "\n",
    "        # iterate\n",
    "        for batch in tqdm(dataloader, ncols=100):\n",
    "            # collate_fn_remove_corrupted removes unreadable images, so a batch may come back empty\n",
    "            if not batch:\n",
    "                continue\n",
    "            images = [x[0] for x in batch]\n",
    "            paths = [x[1] for x in batch]\n",
    "\n",
    "            # label the images\n",
    "            batch_labels = labeler.label_image(np.asarray(images))\n",
    "\n",
    "            for image_labels, image_path in zip(batch_labels, paths):\n",
    "                # save the labels\n",
    "                caption = image_labels.caption\n",
    "                if remove_underscore:\n",
    "                    caption = caption.replace(\"_\", \" \")\n",
    "\n",
    "                # filter out undesired tags; also drop the empty string that \"\".split(\", \") yields\n",
    "                tags = [tag for tag in caption.split(\", \") if tag and tag not in undesired_forms]\n",
    "                caption = \", \".join(tags)\n",
    "\n",
    "                # Get the relative path of the image file\n",
    "                relative_path = Path(image_path).relative_to(images_dir)\n",
    "                Path(image_path).with_suffix(caption_extension).write_text(caption + \"\\n\", encoding=\"utf-8\")\n",
    "\n",
    "                # Write the filename, tag and probability to the CSV file in a single row\n",
    "                general_tags_probs = ', '.join(\n",
    "                    f\"{tag}:{prob}\" for tag, prob in image_labels.general.items() if tag not in undesired_forms\n",
    "                )\n",
    "                writer.writerow({'filename': relative_path, 'tags:probabilities': general_tags_probs})\n",
    "\n",
    "                # save the tag frequencies (counts the filtered tags)\n",
    "                if frequency_tags:\n",
    "                    for tag in tags:\n",
    "                        tag_freqs[tag] = tag_freqs.get(tag, 0) + 1\n",
    "\n",
    "                # debug\n",
    "                if debug:\n",
    "                    print(\n",
    "                        \"\\n\".join([\n",
    "                            f\"{image_path}:\",\n",
    "                            f\" Character tags: {image_labels.character}\",\n",
    "                            f\" General tags: {image_labels.general}\",\n",
    "                        ])\n",
    "                    )\n",
    "\n",
    "    if frequency_tags:\n",
    "        
sorted_tags = sorted(tag_freqs.items(), key=lambda x: x[1], reverse=True)\n", " print(\"\\nTag frequencies:\")\n", " for tag, freq in sorted_tags:\n", " print(f\"{tag}: {freq}\")\n", "\n", " print(\"done!\")" ] }, { "cell_type": "code", "execution_count": 57, "id": "e0a42c4c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checking for convnextv2-based tagger in /home/irakli/foxtagger/models/wd-v1-4-convnextv2-tagger-v2...\n", "Found existing tagger model, skipping download.\n", "Loading images from /home/irakli/foxtagger/inputs... " ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:Loading model from path: /home/irakli/foxtagger/models\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "found 10856 images to process.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-05-16 00:58:07.823512985 [W:onnxruntime:Default, onnxruntime_pybind_state.cc:541 CreateExecutionProviderInstance] Failed to create CUDAExecutionProvider. Please reference https://onnxruntime.ai/docs/reference/execution-providers/CUDA-ExecutionProvider.html#requirements to ensure all dependencies are met.\n", "INFO:root:Model loaded, input dimensions 448x448\n", "INFO:root:Loaded labels from /home/irakli/foxtagger/models/selected_tags.csv\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a50ccecd857a49f79ac2fcd11cc8c8a0", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/10856 [00:00