toolkit / caption /jtp2.py
k4d3's picture
avif
c2699bb
raw
history blame
14.8 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
JTP2 (Joint Tagger Project 2) Image Classification Script
This script implements a multi-label classifier for furry images using the
PILOT2 model. It processes images, generates tags, and saves the results. The
model is based on a Vision Transformer architecture and uses a custom GatedHead
for classification.
Key features:
- Image preprocessing and transformation
- Model inference using PILOT2
- Tag generation with customizable threshold
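- Parallel processing of image directories with a pool of worker processes
- Saving results as `.tags` text files alongside images (images that already
  have one are skipped)
Usage:
    python jtp2.py <directory> [--threshold <float>] [--cpu] [--no_grad]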
"""
import os
import json
import argparse
from PIL import Image
import safetensors.torch
import timm
from timm.models import VisionTransformer
import torch
from torchvision.transforms import transforms
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
import pillow_jxl # type: ignore
from multiprocessing import Pool, cpu_count
from itertools import islice
import multiprocessing
try:
import pillow_avif
except ImportError:
pass
# This script only runs inference, so autograd is switched off for the whole
# process. Spawned workers re-import this module, so they get the same setting.
torch.set_grad_enabled(False)

# Workers are spawned rather than forked so each one can set up CUDA on its
# own. force=True overrides any start method that was chosen earlier.
multiprocessing.set_start_method('spawn', force=True)

# Per-worker state. initializer_worker() fills these in inside each Pool
# process, and process_image_worker() reads them. They stay None in the parent.
global_model = global_transform = global_tags = global_allowed_tags = global_threshold = None
def initializer_worker(
    threshold: float,
    cpu: bool,
    use_no_grad: bool,
    model_path: str = "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors",
    tags_path: str = "/home/kade/source/repos/JTP2/tags.json",
):
    """
    Initializer for each worker process.

    Builds the PILOT2 classifier, loads its weights and the tag list, and
    stores them in module-level globals that ``process_image_worker`` reads.

    Args:
        threshold: Minimum probability for a tag to be written out.
        cpu: If True, keep the model on the CPU even when CUDA is available.
        use_no_grad: Accepted for backward compatibility. It has no effect,
            because ``process_image_worker`` already runs inference inside
            ``torch.no_grad()``.
        model_path: Path to the ``.safetensors`` checkpoint.
        tags_path: Path to the tags JSON file. Its keys are the tag names,
            and the code assumes their order matches the model's output
            class indices.
    """
    global global_model, global_transform, global_tags, global_allowed_tags, global_threshold
    global_threshold = threshold

    model = timm.create_model(
        "vit_so400m_patch14_siglip_384.webli",
        pretrained=False,
        num_classes=9083,
    )
    # Swap the stock linear head for GatedHead before loading the checkpoint.
    # The smaller weight dimension of the stock head is its input width.
    model.head = GatedHead(min(model.head.weight.shape), 9083)
    safetensors.torch.load_model(model, model_path)

    # Set up CUDA if available and requested.
    if torch.cuda.is_available() and not cpu:
        model.cuda()
        # Half precision and channels_last only on compute capability >= 7.
        if torch.cuda.get_device_capability()[0] >= 7:
            model.to(dtype=torch.float16, memory_format=torch.channels_last)
    # Do NOT wrap the model with torch.no_grad()(model). That returns a plain
    # function, which has no .eval() or .parameters(), and it used to crash
    # every worker when --no_grad was passed.
    model.eval()

    with open(tags_path, "r", encoding="utf-8") as file:
        tags = json.load(file)
    # Display form of each tag: underscores become spaces.
    allowed_tags = [tag.replace("_", " ") for tag in tags]

    global_model = model
    global_transform = transform
    global_tags = tags
    global_allowed_tags = allowed_tags
    print("[Initializer] Worker setup complete.")
def process_image_worker(args):
    """
    Tag one image using the model loaded by ``initializer_worker``.

    Runs inside a Pool worker. Writes the tags whose probability exceeds the
    worker's threshold, comma-separated, to ``text_file_path``. Errors are
    printed and swallowed so one bad image does not abort the whole
    ``Pool.map`` run.

    Args:
        args: ``(image_path, text_file_path)`` tuple.
    """
    image_path, text_file_path = args

    # Skip if output file exists
    if os.path.exists(text_file_path):
        print(f"Skipping {image_path} - caption file already exists")
        return

    try:
        print(f"Processing {image_path}...")
        # Close the file handle as soon as the pixels are decoded.
        with Image.open(image_path) as image:
            img = image.convert('RGBA')
        tensor = global_transform(img).unsqueeze(0)

        # Send the input wherever initializer_worker put the model (CPU or
        # CUDA, float32 or float16) instead of re-querying the GPU per image.
        param = next(global_model.parameters())
        if param.is_cuda:
            tensor = tensor.cuda()
            if param.dtype == torch.float16:
                tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last)

        with torch.no_grad():
            probits = global_model(tensor)[0].cpu()

        # Only the 250 highest-scoring tags are considered at all.
        values, indices = probits.topk(250)
        tag_score = {}
        for score, index in zip(values.tolist(), indices.tolist()):
            tag_score[global_allowed_tags[index]] = score

        ranked = sorted(tag_score.items(), key=lambda item: item[1], reverse=True)
        tags = ", ".join(tag for tag, score in ranked if score > global_threshold)

        # Save tags
        with open(text_file_path, "w", encoding="utf-8") as text_file:
            text_file.write(tags)
    except Exception as e:  # best effort: report and move on to the next image
        print(f"Error processing {image_path}: {e}")
def process_directory(directory, threshold, cpu=False, use_no_grad=False):
    """
    Tag every image under ``directory`` using a pool of worker processes.

    The directory is searched recursively. Each ``<name>.<ext>`` image gets a
    ``<name>.tags`` file next to it. Images that already have one are skipped
    before any worker is started, so no model is loaded when there is nothing
    to do.

    The worker count is one per (partial) group of 16 pending images. It is
    capped at 6 for CUDA, because every worker loads its own copy of the
    model, and at half the CPU count (minimum 1) for CPU inference.

    Args:
        directory: Root directory to scan.
        threshold: Minimum probability for a tag to be written.
        cpu: Force CPU inference instead of CUDA.
        use_no_grad: Forwarded to ``initializer_worker``.
    """
    image_exts = ('.jpg', '.jpeg', '.png', '.webp', '.jxl', '.avif')

    # Collect image paths and corresponding output paths
    image_paths = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.lower().endswith(image_exts):
                image_path = os.path.join(root, file)
                text_file_path = os.path.splitext(image_path)[0] + ".tags"
                image_paths.append((image_path, text_file_path))

    if not image_paths:
        print(f"No images found in {directory}")
        return

    # Filter out finished images here, so we never spawn workers (each one
    # loading a full model) just to skip files.
    pending = []
    for image_path, text_file_path in image_paths:
        if os.path.exists(text_file_path):
            print(f"Skipping {image_path} - caption file already exists")
        else:
            pending.append((image_path, text_file_path))

    if not pending:
        print("All images already have caption files")
        return

    images_per_worker = 16
    worker_budget = (len(pending) + images_per_worker - 1) // images_per_worker
    if cpu:
        # cpu_count() // 2 is 0 on a single-core machine, and Pool(0) raises.
        num_processes = min(max(1, cpu_count() // 2), worker_budget)
    else:
        # Each CUDA worker holds its own model copy, so limit to 6 processes.
        num_processes = min(6, worker_budget)

    print(f"Found {len(image_paths)} images ({len(pending)} to process)")
    print(f"Processing using {num_processes} {'CPU' if cpu else 'CUDA'} processes...")

    with Pool(
        processes=num_processes,
        initializer=initializer_worker,
        initargs=(threshold, cpu, use_no_grad),
    ) as pool:
        pool.map(process_image_worker, pending)
def batch_iterator(iterable, batch_size):
    """
    Lazily split ``iterable`` into consecutive lists of ``batch_size`` items.

    The last list may be shorter than ``batch_size``. Nothing is yielded for
    an empty input.
    """
    source = iter(iterable)
    while True:
        chunk = list(islice(source, batch_size))
        if not chunk:
            return
        yield chunk
class Fit(torch.nn.Module):
    """
    A custom transform module for resizing and padding images.

    The image is scaled, keeping its aspect ratio, to fit inside ``bounds``.
    If ``pad`` is set, the result is then padded, centred, to exactly
    ``bounds``.

    Args:
        bounds (tuple[int, int] | int): The target ``(height, width)``; a
            single int means a square bound.
        interpolation (InterpolationMode): The interpolation method for resizing.
        grow (bool): Whether to allow upscaling of images.
        pad (float | None): The fill value used for padding, or ``None`` to
            skip padding. Without padding the output may be smaller than
            ``bounds`` in one dimension.
    """

    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation=InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None
    ):
        super().__init__()
        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        """
        Applies the Fit transform to the input image.

        Args:
            img (Image.Image): The input PIL Image.
        Returns:
            Image.Image: The resized (and, if ``pad`` is set, padded) image.
        """
        wimg, himg = img.size  # PIL reports (width, height)
        hbound, wbound = self.bounds

        hscale = hbound / himg
        wscale = wbound / wimg
        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)
        scale = min(hscale, wscale)

        if scale == 1.0:
            # Already fits, so skip resampling. Padding below may still be
            # needed on the shorter side.
            hnew, wnew = himg, wimg
        else:
            # Clamp rounding overshoot to the bounds, and keep at least one
            # pixel for extreme aspect ratios.
            hnew = max(1, min(round(himg * scale), hbound))
            wnew = max(1, min(round(wimg * scale), wbound))
            img = TF.resize(img, (hnew, wnew), self.interpolation)

        if self.pad is None:
            return img

        hpad = hbound - hnew
        wpad = wbound - wnew
        # Split the padding evenly so the image is centred. An odd leftover
        # pixel goes to the bottom/right.
        tpad = hpad // 2
        bpad = hpad - tpad
        lpad = wpad // 2
        rpad = wpad - lpad
        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        """
        Returns a string representation of the Fit module.

        Returns:
            str: A string describing the module's parameters.
        """
        return (
            f"{self.__class__.__name__}(bounds={self.bounds}, "
            f"interpolation={self.interpolation.value}, grow={self.grow}, "
            f"pad={self.pad})"
        )
class CompositeAlpha(torch.nn.Module):
"""
A module for compositing images with alpha channels over a background color.
Args:
background (tuple[float, float, float] | float): The background color to
use for compositing.
"""
def __init__(self, background: tuple[float, float, float] | float):
super().__init__()
self.background = (
(background, background, background)
if isinstance(background, float)
else background
)
self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)
def forward(self, img: torch.Tensor) -> torch.Tensor:
"""
Applies alpha compositing to the input image tensor.
Args:
img (torch.Tensor): The input image tensor.
Returns:
torch.Tensor: The composited image tensor.
"""
if img.shape[-3] == 3:
return img
alpha = img[..., 3, None, :, :]
img[..., :3, :, :] *= alpha
background = self.background.expand(-1, img.shape[-2], img.shape[-1])
if background.ndim == 1:
background = background[:, None, None]
elif background.ndim == 2:
background = background[None, :, :]
img[..., :3, :, :] += (1.0 - alpha) * background
return img[..., :3, :, :]
def __repr__(self) -> str:
"""
Returns a string representation of the CompositeAlpha module.
Returns:
str: A string describing the module's parameters.
"""
return f"{self.__class__.__name__}(background={self.background})"
# Preprocessing pipeline applied to each RGBA PIL image:
#   1. Fit: scale to fit inside 384x384 (no padding here).
#   2. ToTensor: convert to a float tensor.
#   3. CompositeAlpha: flatten transparency onto 0.5 grey.
#   4. Normalize: mean/std 0.5 per channel.
#   5. CenterCrop: bring the result to exactly 384x384.
transform = transforms.Compose(
    [
        Fit((384, 384)),
        transforms.ToTensor(),
        CompositeAlpha(0.5),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3, inplace=True),
        transforms.CenterCrop((384, 384)),
    ]
)

# SigLIP so400m backbone with the tagger's 9083 output classes. It is built
# without pretrained weights; the project checkpoint is loaded into it later
# in this module.
model: VisionTransformer = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=9083,
)
class GatedHead(torch.nn.Module):
    """
    A custom head module with gating mechanism for the classifier.

    One linear layer produces ``2 * num_classes`` outputs. The first half goes
    through ``act`` and the second half through ``gate`` (both sigmoids), and
    the two are multiplied element-wise, so every output lies in [0, 1].

    Args:
        num_features (int): The number of input features.
        num_classes (int): The number of output classes.
    """

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        # Keep these attribute names as they are: they determine the
        # state-dict keys (e.g. "head.linear.weight") the checkpoint uses.
        self.linear = torch.nn.Linear(num_features, num_classes * 2)
        self.act = torch.nn.Sigmoid()
        self.gate = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the gated head to the input tensor.

        Args:
            x (torch.Tensor): Features of shape ``(..., num_features)``.
        Returns:
            torch.Tensor: Gated scores of shape ``(..., num_classes)``.
        """
        x = self.linear(x)
        n = self.num_classes
        # Slice the last dimension so any number of leading dims works.
        return self.act(x[..., :n]) * self.gate(x[..., n:])
# Swap in the gated head before loading the checkpoint. The smaller weight
# dimension of the stock head is its input width.
model.head = GatedHead(min(model.head.weight.shape), 9083)

# With the "spawn" start method, each Pool worker re-executes this module's
# top level under the name "__mp_main__", then builds its own model in
# initializer_worker(). Skip weight loading and GPU placement there, so a
# worker does not hold a second, unused copy of the model.
_is_spawned_worker = __name__ == "__mp_main__"

if not _is_spawned_worker:
    safetensors.torch.load_model(
        model, "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
    )

# Command-line interface
parser = argparse.ArgumentParser(
    description="Run inference on a directory of images."
)
parser.add_argument("directory", type=str, help="Target directory containing images.")
parser.add_argument(
    "--threshold", type=float, default=0.2, help="Threshold for tag filtering."
)
parser.add_argument(
    "--cpu", action="store_true", help="Force CPU inference instead of CUDA"
)
parser.add_argument(
    "--no_grad", action="store_true", help="Enable torch.no_grad() for inference"
)
args = parser.parse_args()

# Device placement for the module-level model used by run_classifier().
if torch.cuda.is_available() and not args.cpu and not _is_spawned_worker:
    model.cuda()
    if torch.cuda.get_device_capability()[0] >= 7:
        model.to(dtype=torch.float16, memory_format=torch.channels_last)
model.eval()

with open("/home/kade/source/repos/JTP2/tags.json", "r", encoding="utf-8") as file:
    tags = json.load(file)  # type: dict
# Display form of each tag, indexed like the model's output classes.
allowed_tags = [tag.replace("_", " ") for tag in tags]

# Latest full score table produced by run_classifier(), read by create_tags().
sorted_tag_score = {}
def run_classifier(image, threshold):
    """
    Runs the classifier on a single image and returns tags based on the threshold.

    Uses the module-level ``model``. As a side effect it stores the full
    descending score table (top 250 tags) in the global
    ``sorted_tag_score``, which ``create_tags`` then filters.

    Args:
        image (PIL.Image): The input image.
        threshold (float): The probability threshold for including tags.
    Returns:
        tuple: A tuple containing the comma-separated tags and a dictionary of
        tag probabilities.
    """
    global sorted_tag_score
    img = image.convert('RGBA')
    tensor = transform(img).unsqueeze(0)

    # Match the input to wherever the model actually lives (CPU or CUDA,
    # float32 or float16) rather than re-deriving it from CLI flags.
    param = next(model.parameters())
    if param.is_cuda:
        tensor = tensor.cuda()
        if param.dtype == torch.float16:
            tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last)

    # Inference never needs gradients; no_grad gives the same values and
    # avoids building a graph if called from a thread where grad mode is on.
    with torch.no_grad():
        probits = model(tensor)[0].cpu()

    values, indices = probits.topk(250)
    tag_score = {
        allowed_tags[index]: score
        for score, index in zip(values.tolist(), indices.tolist())
    }
    sorted_tag_score = dict(
        sorted(tag_score.items(), key=lambda item: item[1], reverse=True)
    )
    return create_tags(threshold)
def create_tags(threshold):
    """
    Filter the global ``sorted_tag_score`` down to tags scoring above ``threshold``.

    Args:
        threshold (float): Tags must score strictly greater than this.
    Returns:
        tuple: ``(text, scores)``. ``text`` is the kept tag names joined with
        ", ", and ``scores`` maps each kept tag to its probability, in the
        same order as ``sorted_tag_score``.
    """
    kept = {}
    for tag, score in sorted_tag_score.items():
        if score > threshold:
            kept[tag] = score
    return ", ".join(kept), kept
if __name__ == "__main__":
    # Entry point: tag every image under the directory given on the command
    # line. Spawned Pool workers re-import this module without running this.
    process_directory(
        args.directory,
        args.threshold,
        cpu=args.cpu,
        use_no_grad=args.no_grad,
    )