toolkit / jtp2_overwrite
raw
history blame
10 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
JTP2 (Joint Tagger Project 2) Image Classification Script
This script implements a multi-label classifier for furry images using the
PILOT2 model. It processes images, generates tags, and saves the results. The
model is based on a Vision Transformer architecture and uses a custom GatedHead
for classification.
Key features:
- Image preprocessing and transformation
- Model inference using PILOT2
- Tag generation with customizable threshold
- Batch processing of image directories
- Saving results as .txt files alongside images (overwriting any existing file of the same name)
Usage:
python jtp2.py <directory> [--threshold <float>]
"""
import os
import json
import argparse
from PIL import Image
import safetensors.torch
import timm
from timm.models import VisionTransformer
import torch
from torchvision.transforms import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
import pillow_jxl
class Fit(torch.nn.Module):
    """
    A custom transform module for resizing and padding images.

    The image is scaled by a single factor (so its aspect ratio is kept) until
    it fits inside ``bounds``. If ``pad`` is not None, the result is then padded
    symmetrically so that its size is exactly ``bounds``.

    Args:
        bounds (tuple[int, int] | int): The target size as ``(height, width)``;
            an int is used for both dimensions.
        interpolation (InterpolationMode): The interpolation method for resizing.
        grow (bool): Whether to allow upscaling of images.
        pad (float | None): The fill value used for padding, or None to skip
            padding entirely.
    """

    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation=InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None
    ):
        super().__init__()
        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image) -> Image:
        """
        Applies the Fit transform to the input image.

        Args:
            img (Image): The input PIL Image.

        Returns:
            Image: The resized (and, if ``pad`` is set, padded) PIL Image.
        """
        wimg, himg = img.size  # PIL reports (width, height)
        hbound, wbound = self.bounds
        hscale = hbound / himg
        wscale = wbound / wimg
        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)
        scale = min(hscale, wscale)
        if scale == 1.0:
            # Already within bounds: no resize needed, but the padding step
            # below must still run so the output is exactly `bounds`.
            hnew, wnew = himg, wimg
        else:
            # Clamp guards against round() overshooting the bound by a pixel.
            hnew = min(round(himg * scale), hbound)
            wnew = min(round(wimg * scale), wbound)
            img = TF.resize(img, (hnew, wnew), self.interpolation)
        if self.pad is None:
            return img
        hpad = hbound - hnew
        wpad = wbound - wnew
        # Split the padding as evenly as possible; any odd pixel goes to the
        # bottom/right edge.
        tpad = hpad // 2
        bpad = hpad - tpad
        lpad = wpad // 2
        rpad = wpad - lpad
        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        """
        Returns a string representation of the Fit module.

        Returns:
            str: A string describing the module's parameters.
        """
        return (
            f"{self.__class__.__name__}(bounds={self.bounds}, "
            f"interpolation={self.interpolation.value}, grow={self.grow}, "
            f"pad={self.pad})"
        )
class CompositeAlpha(torch.nn.Module):
"""
A module for compositing images with alpha channels over a background color.
Args:
background (tuple[float, float, float] | float): The background color to
use for compositing.
"""
def __init__(self, background: tuple[float, float, float] | float):
super().__init__()
self.background = (
(background, background, background)
if isinstance(background, float)
else background
)
self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)
def forward(self, img: torch.Tensor) -> torch.Tensor:
"""
Applies alpha compositing to the input image tensor.
Args:
img (torch.Tensor): The input image tensor.
Returns:
torch.Tensor: The composited image tensor.
"""
if img.shape[-3] == 3:
return img
alpha = img[..., 3, None, :, :]
img[..., :3, :, :] *= alpha
background = self.background.expand(-1, img.shape[-2], img.shape[-1])
if background.ndim == 1:
background = background[:, None, None]
elif background.ndim == 2:
background = background[None, :, :]
img[..., :3, :, :] += (1.0 - alpha) * background
return img[..., :3, :, :]
def __repr__(self) -> str:
"""
Returns a string representation of the CompositeAlpha module.
Returns:
str: A string describing the module's parameters.
"""
return f"{self.__class__.__name__}(background={self.background})"
# Preprocessing pipeline applied to each RGBA PIL image before inference.
transform = transforms.Compose([
    # Scale (aspect ratio kept) to fit inside 384x384; pad defaults to None,
    # so no padding is added here.
    Fit((384, 384)),
    # PIL image -> float tensor (torchvision's ToTensor scales to [0, 1]).
    transforms.ToTensor(),
    # Flatten transparency onto mid-grey (0.5) and drop the alpha channel.
    CompositeAlpha(0.5),
    # Map [0, 1] to [-1, 1]; mid-grey 0.5 becomes 0.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    # Bring the short side up to 384. NOTE(review): torchvision's CenterCrop
    # zero-pads inputs smaller than the crop size; after Normalize a 0 is the
    # same mid-grey used by CompositeAlpha — confirm this is the intent.
    transforms.CenterCrop((384, 384)),
])
# SigLIP SO400M ViT backbone with 9083 outputs. pretrained=False because the
# JTP PILOT2 weights are loaded from a safetensors file further below.
model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=9083
) # type: VisionTransformer
class GatedHead(torch.nn.Module):
    """
    A custom head module with gating mechanism for the classifier.

    A single linear layer produces ``2 * num_classes`` values per input. The
    first half is passed through a sigmoid ("activation") and multiplied by
    the sigmoid of the second half ("gate").

    Args:
        num_features (int): The number of input features.
        num_classes (int): The number of output classes.
    """

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        # Attribute names are part of the state_dict ("linear.weight",
        # "linear.bias") and must match the saved checkpoint.
        self.linear = torch.nn.Linear(num_features, num_classes * 2)
        self.act = torch.nn.Sigmoid()
        self.gate = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the gated head to the input tensor.

        Args:
            x (torch.Tensor): Features of shape (..., num_features); any number
                of leading dimensions is accepted.

        Returns:
            torch.Tensor: Gated probabilities of shape (..., num_classes).
        """
        x = self.linear(x)
        # Split along the last (feature) axis rather than dim 1, so inputs with
        # extra leading dims such as (batch, tokens, features) split correctly.
        n = self.num_classes
        return self.act(x[..., :n]) * self.gate(x[..., n:])
# Replace timm's plain Linear head with the gated head used by JTP PILOT2.
# NOTE(review): min(weight.shape) picks in_features, assuming the backbone's
# feature width is smaller than the 9083 classes — true for SO400M.
model.head = GatedHead(min(model.head.weight.shape), 9083)

# Checkpoint and tag-list locations. The defaults are the original author's
# paths; set JTP2_MODEL_PATH / JTP2_TAGS_PATH to run on another machine.
MODEL_PATH = os.environ.get(
    "JTP2_MODEL_PATH",
    "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors",
)
TAGS_PATH = os.environ.get(
    "JTP2_TAGS_PATH", "/home/kade/source/repos/JTP2/tags.json"
)

safetensors.torch.load_model(model, MODEL_PATH)
if torch.cuda.is_available():
    model.cuda()
    if torch.cuda.get_device_capability()[0] >= 7:  # tensor cores
        model.to(dtype=torch.float16, memory_format=torch.channels_last)
model.eval()

with open(TAGS_PATH, "r", encoding="utf-8") as file:
    tags = json.load(file)  # type: dict

# Display names (underscores -> spaces), indexed by model output position.
# NOTE(review): relies on tags.json listing tags in the model's output order.
allowed_tags = [tag.replace("_", " ") for tag in tags]

# Scores from the most recent run_classifier() call; read by create_tags().
sorted_tag_score = {}
def run_classifier(image, threshold, top_k=250):
    """
    Runs the classifier on a single image and returns tags based on the threshold.

    Side effect: replaces the module-level ``sorted_tag_score`` with the
    scores of the ``top_k`` highest-scoring tags (descending), which
    ``create_tags`` then filters.

    Args:
        image (PIL.Image): The input image.
        threshold (float): The probability threshold for including tags.
        top_k (int): How many of the highest-scoring tags to keep before
            thresholding. Clamped to the number of model outputs.

    Returns:
        tuple: A tuple containing the comma-separated tags and a dictionary of
        tag probabilities.
    """
    global sorted_tag_score
    img = image.convert('RGBA')
    tensor = transform(img).unsqueeze(0)
    if torch.cuda.is_available():
        tensor = tensor.cuda()
        if torch.cuda.get_device_capability()[0] >= 7:  # tensor cores
            # Must match the dtype/memory format the model was converted to.
            tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last)
    with torch.no_grad():
        probits = model(tensor)[0].cpu()
    # topk raises if k exceeds the number of outputs, so clamp it.
    k = min(top_k, probits.numel())
    values, indices = probits.topk(k)
    tag_score = {}
    for index, value in zip(indices.tolist(), values.tolist()):
        tag_score[allowed_tags[index]] = value
    sorted_tag_score = dict(
        sorted(tag_score.items(), key=lambda item: item[1], reverse=True)
    )
    return create_tags(threshold)
def create_tags(threshold):
    """
    Build the tag string from the most recent classifier scores.

    Reads the module-level ``sorted_tag_score`` mapping (tag -> probability),
    which ``run_classifier`` fills, and keeps only entries whose probability
    is strictly greater than ``threshold``, preserving their existing order.

    Args:
        threshold (float): Exclusive lower bound on tag probability.

    Returns:
        tuple: ``(text, scores)`` where ``text`` is the kept tag names joined
        with ``", "`` and ``scores`` is a dict of kept tag -> probability.
    """
    kept = {}
    for tag_name, probability in sorted_tag_score.items():
        if probability > threshold:
            kept[tag_name] = probability
    return ", ".join(kept), kept
def process_directory(directory, threshold, extensions=('.jpg', '.jpeg', '.png', '.jxl')):
    """
    Processes all images in a directory and its subdirectories, generating tags
    for each image.

    For every matching image, the comma-separated tags are written to a
    ``.txt`` file with the same base name next to the image, overwriting any
    existing file. NOTE: images sharing a base name (e.g. ``a.jpg`` and
    ``a.png``) write to the same ``a.txt``; the last one processed wins.

    Args:
        directory (str): The path to the directory containing images.
        threshold (float): The probability threshold for including tags.
        extensions (Iterable[str]): File suffixes to treat as images; matched
            case-insensitively.

    Returns:
        dict: A dictionary mapping image paths to their generated tags.
    """
    # Hoisted and lower-cased once; str.endswith needs a tuple.
    suffixes = tuple(ext.lower() for ext in extensions)
    results = {}
    for root, _, files in os.walk(directory):
        for file in files:
            if not file.lower().endswith(suffixes):
                continue
            image_path = os.path.join(root, file)
            # Close the underlying file handle promptly; otherwise every image
            # in a large tree keeps a descriptor open until GC.
            with Image.open(image_path) as image:
                tags, _ = run_classifier(image, threshold)
            results[image_path] = tags
            # Save tags to a text file with the same name as the image
            text_file_path = os.path.splitext(image_path)[0] + ".txt"
            with open(text_file_path, "w", encoding="utf-8") as text_file:
                text_file.write(tags)
    return results
if __name__ == "__main__":
    # Command-line entry point: tag every image under a directory and echo
    # the resulting tag strings.
    cli = argparse.ArgumentParser(
        description="Run inference on a directory of images."
    )
    cli.add_argument(
        "directory", type=str, help="Target directory containing images."
    )
    cli.add_argument(
        "--threshold",
        type=float,
        default=0.2,
        help="Threshold for tag filtering.",
    )
    options = cli.parse_args()
    tagged = process_directory(options.directory, options.threshold)
    for path, tag_text in tagged.items():
        print(f"{path}: {tag_text}")