| """ |
| JTP2 (Joint Tagger Project 2) Image Classification Script |
| |
| This script implements a multi-label classifier for furry images using the |
| PILOT2 model. It processes images, generates tags, and saves the results. The |
| model is based on a Vision Transformer architecture and uses a custom GatedHead |
| for classification. |
| |
| Key features: |
| - Image preprocessing and transformation |
| - Model inference using PILOT2 |
| - Tag generation with customizable threshold |
| - Batch processing of image directories |
| - Saving results as text files alongside images |
| |
| Usage: |
| python jtp2.py <directory> [--threshold <float>] |
| """ |

import os
import json
import argparse

from PIL import Image
import safetensors.torch
import timm
import torch
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
import pillow_jxl  # noqa: F401 (side-effect import: registers JXL support with Pillow)


class Fit(torch.nn.Module):
    """
    A custom transform module for resizing and padding images.

    Args:
        bounds (tuple[int, int] | int): The target (height, width) for the image.
        interpolation (InterpolationMode): The interpolation method for resizing.
        grow (bool): Whether to allow upscaling of images.
        pad (float | None): The fill value to use if padding is applied; None
            disables padding.
    """

    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation: InterpolationMode = InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None
    ):
        super().__init__()
        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        """
        Applies the Fit transform to the input image.

        Args:
            img (Image.Image): The input PIL image.

        Returns:
            Image.Image: The resized (and optionally padded) PIL image.
        """
        wimg, himg = img.size
        hbound, wbound = self.bounds
        hscale = hbound / himg
        wscale = wbound / wimg
        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)
        # Scale uniformly so the whole image fits inside the bounds.
        scale = min(hscale, wscale)
        if scale == 1.0:
            return img
        hnew = min(round(himg * scale), hbound)
        wnew = min(round(wimg * scale), wbound)
        img = TF.resize(img, (hnew, wnew), self.interpolation)
        if self.pad is None:
            return img
        # Center the image inside the bounds, splitting the padding evenly.
        hpad = hbound - hnew
        wpad = wbound - wnew
        tpad = hpad // 2
        bpad = hpad - tpad
        lpad = wpad // 2
        rpad = wpad - lpad
        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        """
        Returns a string representation of the Fit module.

        Returns:
            str: A string describing the module's parameters.
        """
        return (
            f"{self.__class__.__name__}(bounds={self.bounds}, "
            f"interpolation={self.interpolation.value}, grow={self.grow}, "
            f"pad={self.pad})"
        )
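
# A minimal sketch of Fit's behavior (sizes assume a hypothetical 1024x768 input):
#
#   Fit((384, 384))(Image.new("RGB", (1024, 768))).size
#   # -> (384, 288): scaled by min(384/768, 384/1024) = 0.375, aspect preserved
#
# With `pad` set, the shorter dimension is then letterboxed out to the full
# (384, 384) bounds.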


class CompositeAlpha(torch.nn.Module):
    """
    A module for compositing images with alpha channels over a background color.

    Args:
        background (tuple[float, float, float] | float): The background color to
            use for compositing.
    """

    def __init__(self, background: tuple[float, float, float] | float):
        super().__init__()
        self.background = (
            (background, background, background)
            if isinstance(background, float)
            else background
        )
        # Shape (3, 1, 1) so it broadcasts over (..., 3, H, W) image tensors.
        self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """
        Applies alpha compositing to the input image tensor.

        Args:
            img (torch.Tensor): The input image tensor.

        Returns:
            torch.Tensor: The composited image tensor.
        """
        # Images without an alpha channel pass through unchanged.
        if img.shape[-3] == 3:
            return img
        alpha = img[..., 3, None, :, :]
        # Standard "over" compositing: premultiply by alpha, then blend in the
        # background weighted by (1 - alpha).
        img[..., :3, :, :] *= alpha
        background = self.background.expand(-1, img.shape[-2], img.shape[-1])
        img[..., :3, :, :] += (1.0 - alpha) * background
        return img[..., :3, :, :]

    def __repr__(self) -> str:
        """
        Returns a string representation of the CompositeAlpha module.

        Returns:
            str: A string describing the module's parameters.
        """
        return f"{self.__class__.__name__}(background={self.background})"


transform = transforms.Compose([
    Fit((384, 384)),
    transforms.ToTensor(),
    CompositeAlpha(0.5),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    transforms.CenterCrop((384, 384)),
])
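
# Note on ordering: Fit() leaves the shorter side under 384 rather than padding
# it, and CenterCrop then pads that side with zeros. Because Normalize maps
# pixel value 0.5 to 0, that zero fill equals the same mid-gray used for alpha
# compositing, so the letterboxing stays consistent.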


model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=9083
)


class GatedHead(torch.nn.Module):
    """
    A custom classifier head with a sigmoid gating mechanism.

    Args:
        num_features (int): The number of input features.
        num_classes (int): The number of output classes.
    """

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        # A single projection produces both the activations and their gates.
        self.linear = torch.nn.Linear(num_features, num_classes * 2)
        self.act = torch.nn.Sigmoid()
        self.gate = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the gated head to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: Per-class scores in [0, 1].
        """
        x = self.linear(x)
        # First half of the projection drives the activation, second half the gate.
        x = self.act(x[:, :self.num_classes]) * self.gate(x[:, self.num_classes:])
        return x
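
# In effect, for features f the head computes sigmoid(W_a f + b_a) * sigmoid(W_g f + b_g),
# with both halves taken from the single Linear projection. A quick shape check
# (1152 is this backbone's feature width):
#
#   GatedHead(1152, 9083)(torch.randn(1, 1152)).shape  # -> torch.Size([1, 9083])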


# min(...) picks the smaller of the original head's (out_features, in_features),
# i.e. the backbone's feature width.
model.head = GatedHead(min(model.head.weight.shape), 9083)
safetensors.torch.load_model(
    model, "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
)

if torch.cuda.is_available():
    model.cuda()
    # Half precision + channels_last only pays off on GPUs with tensor cores.
    if torch.cuda.get_device_capability()[0] >= 7:
        model.to(dtype=torch.float16, memory_format=torch.channels_last)

model.eval()


with open("/home/kade/source/repos/JTP2/tags.json", "r", encoding="utf-8") as file:
    tags = json.load(file)
# Tags are stored underscore-separated; present them with spaces instead.
allowed_tags = [tag.replace("_", " ") for tag in tags.keys()]

# Populated by run_classifier() and consumed by create_tags().
sorted_tag_score = {}


def run_classifier(image, threshold):
    """
    Runs the classifier on a single image and returns tags based on the threshold.

    Args:
        image (PIL.Image.Image): The input image.
        threshold (float): The probability threshold for including tags.

    Returns:
        tuple: A tuple containing the comma-separated tags and a dictionary of
            tag probabilities.
    """
    global sorted_tag_score
    img = image.convert("RGBA")
    tensor = transform(img).unsqueeze(0)
    if torch.cuda.is_available():
        tensor = tensor.cuda()
        # Match the dtype/layout the model was cast to above.
        if torch.cuda.get_device_capability()[0] >= 7:
            tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last)
    with torch.no_grad():
        probits = model(tensor)[0].cpu()
        values, indices = probits.topk(250)
    tag_score = {}
    for i in range(indices.size(0)):
        tag_score[allowed_tags[indices[i]]] = values[i].item()
    sorted_tag_score = dict(
        sorted(tag_score.items(), key=lambda item: item[1], reverse=True)
    )
    return create_tags(threshold)
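
# Example usage (hypothetical path):
#
#   tags, scores = run_classifier(Image.open("sample.png"), threshold=0.2)
#   # tags   -> "tag a, tag b, ..."
#   # scores -> {"tag a": 0.93, ...} for everything above the threshold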


def create_tags(threshold):
    """
    Creates a list of tags based on the current sorted_tag_score and the given
    threshold.

    Args:
        threshold (float): The probability threshold for including tags.

    Returns:
        tuple: A tuple containing the comma-separated tags and a dictionary of
            filtered tag probabilities.
    """
    global sorted_tag_score
    filtered_tag_score = {
        key: value for key, value in sorted_tag_score.items() if value > threshold
    }
    text_no_impl = ", ".join(filtered_tag_score.keys())
    return text_no_impl, filtered_tag_score


def process_directory(directory, threshold):
    """
    Processes all images in a directory and its subdirectories, generating tags
    for each image.

    Args:
        directory (str): The path to the directory containing images.
        threshold (float): The probability threshold for including tags.

    Returns:
        dict: A dictionary mapping image paths to their generated tags.
    """
    results = {}
    for root, _, files in os.walk(directory):
        for file in files:
            if file.lower().endswith((".jpg", ".jpeg", ".png", ".jxl")):
                image_path = os.path.join(root, file)
                # Close the file handle promptly; os.walk may visit many images.
                with Image.open(image_path) as image:
                    tags, _ = run_classifier(image, threshold)
                results[image_path] = tags

                # Save the tags next to the image, same name with a .txt extension.
                text_file_path = os.path.splitext(image_path)[0] + ".txt"
                with open(text_file_path, "w", encoding="utf-8") as text_file:
                    text_file.write(tags)
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run inference on a directory of images."
    )
    parser.add_argument("directory", type=str, help="Target directory containing images.")
    parser.add_argument(
        "--threshold", type=float, default=0.2, help="Threshold for tag filtering."
    )
    args = parser.parse_args()

    results = process_directory(args.directory, args.threshold)
    for image_path, tags in results.items():
        print(f"{image_path}: {tags}")