In [41]:
#import libraries

import argparse
import json
import logging
from dataclasses import asdict, dataclass
from os import PathLike, getenv
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import onnxruntime as rt
from huggingface_hub import snapshot_download
from pandas import read_csv
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
import csv 
import pandas as pd

In [42]:
# allowed extensions
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

In [43]:
# model input shape
IMAGE_SIZE = 448

In [44]:
# hf hub insists on putting things in the cache dir then hardlinking and unlinking
# which breaks across mount points, so we override it here unless an explicit path is given in args
HF_HOME = getenv("HF_HOME", Path.cwd().joinpath(".cache"))
CACHE_DIR = HF_HOME.joinpath("huggingface_hub")

In [45]:
class DictJsonMixin:
    def asdict(self, *args, **kwargs) -> Dict[str, Any]:
        return asdict(self, *args, **kwargs)

    def asjson(self, *args, **kwargs):
        return json.dumps(asdict(self, *args, **kwargs))

In [46]:
@dataclass
class LabelData(DictJsonMixin):
    """
    A class that represents label data.
    """
    names: List[str]
    rating: List[np.int64]
    general: List[np.int64]
    character: List[np.int64]

In [47]:
@dataclass
class ImageLabels(DictJsonMixin):
    """
    A class that represents image labels.
    """
    caption: str
    booru: str
    rating: str
    general: Dict[str, float]
    character: Dict[str, float]
    ratings: Dict[str, float]

In [48]:
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

In [49]:
def get_model_repo(base_model: str = "convnextv2") -> str:
    return f"SmilingWolf/wd-v1-4-{base_model}-tagger-v2"

In [50]:
def collate_fn_remove_corrupted(batch):
    """Collate function that allows to remove corrupted examples in the
    dataloader. It expects that the dataloader returns 'None' when that occurs.
    The 'None's in the batch are removed.
    """
    # Filter out all the Nones (corrupted examples)
    return [x for x in batch if x is not None]

In [51]:
def load_labels(model_path: Path) -> LabelData:
    path = model_path.joinpath("selected_tags.csv")
    df = read_csv(path)

    tag_data = LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(df["category"] == 9)[0]),
        general=list(np.where(df["category"] == 0)[0]),
        character=list(np.where(df["category"] == 4)[0]),
    )

    return tag_data

In [52]:
def preprocess_image(image: Image.Image, size_px: int = IMAGE_SIZE, upscale=True) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.
    """
    # make tuple for PIL
    size = (size_px, size_px)

    # scale up or down (maintaining aspect ratio) as needed
    if image.width > size_px or image.height > size_px:
        image.thumbnail(size, Image.Resampling.LANCZOS)
    elif upscale is True:
        ratio = size_px / max(image.width, image.height)
        scale_to = (int(image.width * ratio), int(image.height * ratio))
        image = image.resize(scale_to, Image.LANCZOS)

    # work out where to paste the image to make it square
    delta_h = (size_px - image.height) // 2
    delta_w = (size_px - image.width) // 2

    # paste image onto square white canvas, centered
    image = image.convert("RGBA")
    canvas = Image.new("RGBA", size, (255, 255, 255))
    canvas.paste(image, box=(delta_w, delta_h), mask=image)

    # convert to 24-bit BGR for OpenCV and return
    canvas = canvas.convert("RGB").convert("BGR;24")
    return canvas

In [53]:
class ImageDataset(Dataset):
    def __init__(self, image_paths: List[Path], size_px: int = IMAGE_SIZE, upscale: bool = True):
        self.size_px = size_px
        self.upscale = upscale
        self.images = [p for p in image_paths if p.suffix.lower() in IMAGE_EXTENSIONS]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path: Path = self.images[idx]
        try:
            image = Image.open(image_path)
            image = preprocess_image(image, self.size_px, self.upscale)
            image = np.asarray(image)
            image = image.astype(np.float32)
            image = np.expand_dims(image, axis=0)
        except Exception as e:
            logging.exception(f"Could not load image from {image_path}, error: {e}")
            return None
        return image, image_path

In [54]:
exclude = pd.read_csv('exclude_tags.csv')
exclude
undesired_tags_list = exclude[exclude['exclude'] == 1]['name'].tolist()
undesired_tags_list

['1girl',
 'long_hair',
 'breasts',
 'blush',
 'smile',
 'short_hair',
 'open_mouth',
 'bangs',
 'blue_eyes',
 'skirt',
 'blonde_hair',
 'large_breasts',
 'brown_hair',
 'shirt',
 'black_hair',
 'hair_ornament',
 'red_eyes',
 'thighhighs',
 'gloves',
 'long_sleeves',
 '1boy',
 'hat',
 'dress',
 'bow',
 'ribbon',
 'navel',
 'holding',
 '2girls',
 'animal_ears',
 'cleavage',
 'hair_between_eyes',
 'bare_shoulders',
 'twintails',
 'brown_eyes',
 'medium_breasts',
 'sitting',
 'very_long_hair',
 'closed_mouth',
 'underwear',
 'nipples',
 'school_uniform',
 'green_eyes',
 'blue_hair',
 'standing',
 'purple_eyes',
 'collarbone',
 'panties',
 'jacket',
 'tail',
 'swimsuit',
 'hair_ribbon',
 'yellow_eyes',
 'white_shirt',
 'ponytail',
 'weapon',
 'pink_hair',
 'purple_hair',
 'ass',
 'braid',
 'flower',
 'ahoge',
 'white_hair',
 'short_sleeves',
 ':d',
 'hetero',
 'hair_bow',
 'grey_hair',
 'male_focus',
 'heart',
 'pantyhose',
 'sidelocks',
 'bikini',
 'thighs',
 'red_hair',
 'multicolored_ha

In [55]:
class ImageLabeler:
    def __init__(
        self,
        model_path: Optional[PathLike] = None,
        general_threshold: float = 0.35,
        character_threshold: float = 0.35,
        undesired_tags: Optional[List[str]] = None,
    ):
        # save model path if provided
        self._model_path = Path(model_path) if model_path is not None else None

        # create some object attributes for convenience
        self.general_threshold = general_threshold
        self.character_threshold = character_threshold
        self.undesired_tags = undesired_tags or []

        # actually load the model
        logging.info(f"Loading model from path: {self._model_path}")
        self.model = rt.InferenceSession(
            str(model_path.joinpath("model.onnx")),
            providers=[("CUDAExecutionProvider", {}), "CPUExecutionProvider"],
        )

        # Get input dimensions
        _, self.height, self.width, _ = self.model.get_inputs()[0].shape
        logging.info(f"Model loaded, input dimensions {self.height}x{self.width}")

        # load labels
        self.labels = load_labels(self._model_path)
        self.labels.general = [i for i in self.labels.general if i not in undesired_tags]
        self.labels.character = [i for i in self.labels.character if i not in undesired_tags]
        logging.info(f"Loaded labels from {self._model_path.joinpath('selected_tags.csv')}")

    @property
    def input_size(self) -> Tuple[int, int]:
        return (self.height, self.width)

    @property
    def input_name(self) -> str:
        return self.model.get_inputs()[0].name if self.model is not None else None

    @property
    def output_name(self) -> str:
        return self.model.get_outputs()[0].name if self.model is not None else None

    def label_image(self, images: np.ndarray) -> ImageLabels:
        # Run the ONNX model
        probs = [self.model.run([self.output_name], {self.input_name: x}) for x in images]
        # Convert to labels
        results = []
        for prob in list(probs):
            labels = list(zip(self.labels.names, prob[0][0].astype(float)))

            # First 4 labels are actually ratings: pick one with argmax
            rating_labels = dict([labels[i] for i in self.labels.rating])
            rating = max(rating_labels, key=rating_labels.get)

            # General labels, pick any where prediction confidence > threshold
            gen_labels = [labels[i] for i in self.labels.general]
            gen_labels = dict([x for x in gen_labels if x[1] > self.general_threshold])
            gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))

            # Convert to a string suitable for use as a training caption
            caption = ", ".join([x for x in gen_labels])

            booru = caption.replace("_", " ").replace("(", "\(").replace(")", "\)")

            # return output
            results.append(
                ImageLabels(
                    caption=caption,
                    booru=booru,
                    rating=rating,
                    general=gen_labels,
                    character={},  # returning an empty dictionary for character labels
                    ratings=rating_labels,
                )
            )

        return results

    def __call__(self, images: List[Image.Image]) -> ImageLabels:
        # if not a list, just label the image
        for x in images:
            yield self.label_image(x)

In [56]:
def main(
    images_dir: str = "/home/irakli/foxtagger/inputs",
    base_model: str = "convnextv2",
    models_dir: str = "/home/irakli/foxtagger/models",
    force_download: bool = False,
    recursive: bool = True,
    undesired_tags: List[str] = undesired_tags_list,
    caption_extension: str = ".txt",
    frequency_tags: bool = False,
    max_data_loader_n_workers: int = 4,
    remove_underscore: bool = True,
    thresh: float = 0.35,
    general_threshold: float = None,
    character_threshold: float = None,
    debug: bool = False,
):
    base_model = base_model
    models_dir = Path(models_dir) if models_dir is not None else Path.cwd().joinpath("models")
    images_dir = Path(images_dir)
    force_download = force_download or False
    # Specify the name of your model file
    model_filename = 'model.onnx'

    recursive = recursive
    undesired_tags = set(undesired_tags)
    caption_extension = str(caption_extension).lower()
    frequency_tags = frequency_tags
    max_data_loader_n_workers = max_data_loader_n_workers

    remove_underscore = remove_underscore
    general_threshold = general_threshold or thresh
    character_threshold = character_threshold or thresh
    debug = debug

    # turn base model into a repo id and model path
    repo_id: str = get_model_repo(base_model)
    model_dir = models_dir.joinpath(repo_id.split("/")[-1])
    model_path = model_dir / model_filename  # This is the path to the model file
    

    # download the model if it doesn't exist, or if force_download is True
    print(f"Checking for {base_model}-based tagger in {model_dir}...")
    if not model_dir.is_dir() or force_download is True:
        print(f"Downloading {base_model}-based tagger from '{repo_id}'")
        snapshot_download(
            repo_id,
            local_dir_use_symlinks=False,
            local_dir=models_dir,
            cache_dir=CACHE_DIR,
            allow_patterns=["*.onnx", "*.csv"],
        )
    else:
        print("Found existing tagger model, skipping download.")

    # instantiate the dataset
    print(f"Loading images from {images_dir}...", end=" ")
    if recursive:
        image_paths = list(Path(images_dir).rglob("*.*"))
    else:
        image_paths = list(Path(images_dir).glob("*.*"))
    image_paths = [p for p in image_paths if p.suffix.lower() in IMAGE_EXTENSIONS]
    print(f"found {len(image_paths)} images to process.")
    dataset = ImageDataset(image_paths)

    # Create the data loader
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=max_data_loader_n_workers,
        collate_fn=collate_fn_remove_corrupted,
        drop_last=False,
    )

    # Create the image labeler
    labeler: ImageLabeler = ImageLabeler(
        model_path=models_dir,
        character_threshold=character_threshold,
        general_threshold=general_threshold,
        undesired_tags=undesired_tags,
    )

    # object to save tag frequencies
    tag_freqs = {}
    
    # Specify the name of your CSV output file
    csv_filename = 'output.csv'
    
    with open(csv_filename, 'w', newline='') as csvfile:
        fieldnames = ['filename', 'tags:probabilities']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        # Write the header
        writer.writeheader()

        # iterate
        for batch in tqdm(dataloader, ncols=100):
            images = [x[0] for x in batch]
            paths = [x[1] for x in batch]

            # label the images
            batch_labels = labeler.label_image(np.asarray(images))

            for image_labels, image_path in zip(batch_labels, paths):
                # save the labels
                caption = image_labels.caption
                if remove_underscore is True:
                    caption = caption.replace("_", " ")
                    
                # filter out undesired tags
                tags = caption.split(", ")
                tags = [tag for tag in tags if tag not in undesired_tags]
                caption = ", ".join(tags)
                
                # Get the relative path of the image file
                relative_path = Path(image_path).relative_to(images_dir)
                Path(image_path).with_suffix(caption_extension).write_text(caption + "\n", encoding="utf-8")
                
                # Write the filename, tag and probability to the CSV file in a single row
                general_tags_probs = ', '.join([f"{tag}:{prob}" for tag, prob in image_labels.general.items() if tag not in undesired_tags])
                writer.writerow({'filename': relative_path, 'tags:probabilities': general_tags_probs})


                # save the tag frequencies
                if frequency_tags is True:
                    for tag in tags:   # here we use filtered tags
                        if tag not in tag_freqs:
                            tag_freqs[tag] = 0
                        tag_freqs[tag] += 1

                # debug
                if debug is True:
                    print(
                        "\n".join([
                            f"{image_path}:",
                            f"  Character tags: {image_labels.character}",
                            f"    General tags: {image_labels.general}",
                        ])
                    )

        if frequency_tags:
            sorted_tags = sorted(tag_freqs.items(), key=lambda x: x[1], reverse=True)
            print("\nTag frequencies:")
            for tag, freq in sorted_tags:
                print(f"{tag}: {freq}")

    print("done!")

In [57]:
main()

Checking for convnextv2-based tagger in /home/irakli/foxtagger/models/wd-v1-4-convnextv2-tagger-v2...
Found existing tagger model, skipping download.
Loading images from /home/irakli/foxtagger/inputs... 

INFO:root:Loading model from path: /home/irakli/foxtagger/models


found 10856 images to process.


2023-05-16 00:58:07.823512985 [W:onnxruntime:Default, onnxruntime_pybind_state.cc:541 CreateExecutionProviderInstance] Failed to create CUDAExecutionProvider. Please reference https://onnxruntime.ai/docs/reference/execution-providers/CUDA-ExecutionProvider.html#requirements to ensure all dependencies are met.
INFO:root:Model loaded, input dimensions 448x448
INFO:root:Loaded labels from /home/irakli/foxtagger/models/selected_tags.csv


  0%|                                                                     | 0/10856 [00:00<?, ?it/s]

done!
