|
|
|
|
|
|
|
""" |
|
JTP2 (Joint Tagger Project 2) Image Classification Script |
|
This script implements a multi-label classifier for furry images using the |
|
PILOT2 model. It processes images, generates tags, and saves the results. The |
|
model is based on a Vision Transformer architecture and uses a custom GatedHead |
|
for classification. |
|
|
|
Key features: |
|
- Image preprocessing and transformation |
|
- Model inference using PILOT2 |
|
- Tag generation with customizable threshold |
|
- Recursive, multi-process tagging of image directories

- Saving results as `.tags` text files alongside images (images that already have one are skipped)
|
|
|
Usage:

python jtp2.py <directory> [--threshold <float>] [--cpu] [--no_grad]
|
""" |
|
|
|
import os |
|
import json |
|
import argparse |
|
from PIL import Image |
|
import safetensors.torch |
|
import timm |
|
from timm.models import VisionTransformer |
|
import torch |
|
from torchvision.transforms import transforms |
|
from torchvision.transforms import InterpolationMode |
|
import torchvision.transforms.functional as TF |
|
import pillow_jxl |
|
from multiprocessing import Pool, cpu_count |
|
from itertools import islice |
|
import multiprocessing |
|
|
|
try: |
|
import pillow_avif |
|
except ImportError: |
|
pass |
|
|
|
# This script only runs inference, so autograd is switched off for the process.
torch.set_grad_enabled(False)

# Worker processes are started with 'spawn'; force=True overrides any start
# method that was already configured.
multiprocessing.set_start_method('spawn', force=True)

# Per-worker state. Every name starts as None and is filled in by
# initializer_worker() inside each pool process.
global_model = global_transform = global_tags = global_allowed_tags = global_threshold = None
|
|
|
def initializer_worker(threshold: float, cpu: bool, use_no_grad: bool):
    """
    Initializer for each worker process.

    Builds the PILOT2 classifier, loads its weights and the tag vocabulary,
    and stores them in module-level globals read by process_image_worker().
    The module-level ``transform`` and the tag threshold are stored there too.

    Args:
        threshold: Tags must score strictly above this to be written out.
        cpu: When True, keep the model on the CPU even if CUDA is available.
        use_no_grad: Disable autograd for this worker process. The model is
            never wrapped, so it keeps its full ``nn.Module`` API.
    """
    global global_model, global_transform, global_tags, global_allowed_tags, global_threshold

    global_threshold = threshold

    model = timm.create_model(
        "vit_so400m_patch14_siglip_384.webli",
        pretrained=False,
        num_classes=9083,
    )
    # Swap timm's linear head for GatedHead before loading the checkpoint.
    # min(weight.shape) is the head's input feature width.
    model.head = GatedHead(min(model.head.weight.shape), 9083)
    safetensors.torch.load_model(
        model, "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
    )

    if torch.cuda.is_available() and not cpu:
        model.cuda()
        # Use fp16 + channels_last only on devices with compute capability >= 7.
        if torch.cuda.get_device_capability()[0] >= 7:
            model.to(dtype=torch.float16, memory_format=torch.channels_last)

    if use_no_grad:
        # Turn autograd off process-wide instead of wrapping the module.
        # `torch.no_grad()(model)` returns a plain function, which has no
        # .eval()/.parameters(), so the worker used to crash right here.
        torch.set_grad_enabled(False)

    model.eval()

    with open("/home/kade/source/repos/JTP2/tags.json", "r", encoding="utf-8") as fh:
        tags = json.load(fh)
    # Display names: JSON keys (in file order) with underscores turned into spaces.
    allowed_tags = [tag.replace("_", " ") for tag in tags]

    global_model = model
    global_transform = transform
    global_tags = tags
    global_allowed_tags = allowed_tags

    print("[Initializer] Worker setup complete.")
|
|
|
def process_image_worker(args):
    """
    Tag one image with the model that initializer_worker() loaded into this process.

    Args:
        args: ``(image_path, text_file_path)`` pair, as built by
            process_directory().

    Writes tags to ``text_file_path`` as a comma-separated list, highest score
    first. Only tags scoring strictly above the worker's threshold are kept.
    Images whose tag file already exists are skipped. Errors are printed
    rather than raised, so one unreadable image does not abort the pool.
    """
    image_path, text_file_path = args

    # An existing tag file means this image was already tagged; leave it alone.
    if os.path.exists(text_file_path):
        print(f"Skipping {image_path} - caption file already exists")
        return

    try:
        print(f"Processing {image_path}...")
        # Close the source file as soon as the pixels have been converted.
        with Image.open(image_path) as image:
            img = image.convert('RGBA')
        tensor = global_transform(img).unsqueeze(0)

        # Follow the model's placement. Use CUDA if its weights live there, and
        # fp16 + channels_last on capability >= 7 devices, as initializer_worker does.
        if next(global_model.parameters()).is_cuda:
            tensor = tensor.cuda()
            if torch.cuda.get_device_capability()[0] >= 7:
                tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last)

        with torch.no_grad():
            probits = global_model(tensor)[0].cpu()

        # Consider at most the 250 best-scoring tags, and never more than the
        # model has outputs.
        values, indices = probits.topk(min(250, probits.numel()))
        tag_score = {
            global_allowed_tags[index]: value.item()
            for index, value in zip(indices, values)
        }
        ranked = sorted(tag_score.items(), key=lambda item: item[1], reverse=True)
        tags = ", ".join(tag for tag, score in ranked if score > global_threshold)

        with open(text_file_path, "w", encoding="utf-8") as text_file:
            text_file.write(tags)

    except Exception as e:
        # Best-effort boundary: report the failure and let the pool move on.
        print(f"Error processing {image_path}: {e}")
|
|
|
def _num_worker_processes(cpu: bool, num_batches: int) -> int:
    """Return how many pool workers to start for ``num_batches`` batches (always >= 1)."""
    # CPU mode uses half the cores; CUDA mode caps at 6 workers, each of which
    # builds its own model in initializer_worker(). Never start more workers
    # than there are batches. The max(1, ...) guard keeps single-core machines
    # from asking Pool for zero processes.
    limit = cpu_count() // 2 if cpu else 6
    return max(1, min(limit, num_batches))


def process_directory(directory, threshold, cpu=False, use_no_grad=False):
    """
    Processes all images in a directory using multiple worker processes.

    Walks ``directory`` recursively. Each supported image is tagged by a pool
    of workers, and the tags are written next to the image with a ``.tags``
    extension. Images are dispatched to workers in batches of 16. CUDA runs
    use at most 6 workers; CPU runs use up to half the CPU cores.

    Args:
        directory: Root directory to search for images.
        threshold: Tags must score strictly above this to be written.
        cpu: Force CPU inference instead of CUDA.
        use_no_grad: Forwarded to initializer_worker().
    """
    batch_size = 16
    extensions = ('.jpg', '.jpeg', '.png', '.webp', '.jxl', '.avif')

    image_paths = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.lower().endswith(extensions):
                image_path = os.path.join(root, file)
                text_file_path = os.path.splitext(image_path)[0] + ".tags"
                image_paths.append((image_path, text_file_path))

    if not image_paths:
        print(f"No images found in {directory}")
        return

    num_batches = -(-len(image_paths) // batch_size)  # ceiling division
    num_processes = _num_worker_processes(cpu, num_batches)

    print(f"Found {len(image_paths)} images in {num_batches} batches")
    print(f"Processing using {num_processes} {'CPU' if cpu else 'CUDA'} processes...")

    with Pool(
        processes=num_processes,
        initializer=initializer_worker,
        initargs=(threshold, cpu, use_no_grad)
    ) as pool:
        # chunksize=batch_size makes each task handed to a worker one batch of images.
        pool.map(process_image_worker, image_paths, chunksize=batch_size)
|
|
|
def batch_iterator(iterable, batch_size):
    """
    Split ``iterable`` into consecutive lists of up to ``batch_size`` items.

    The final list may be shorter. Nothing is yielded for an empty iterable.
    """
    source = iter(iterable)
    while True:
        chunk = list(islice(source, batch_size))
        if not chunk:
            return
        yield chunk
|
|
|
|
|
class Fit(torch.nn.Module):
    """
    Resize a PIL image to fit inside ``bounds`` while keeping its aspect ratio.

    Optionally pads the result out to exactly ``bounds``.

    Args:
        bounds (tuple[int, int] | int): Target ``(height, width)``; a single int
            means a square of that size.
        interpolation (InterpolationMode): The interpolation method for resizing.
        grow (bool): Whether to allow upscaling of images.
        pad (float | None): Fill value for padding; ``None`` disables padding.
    """

    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation=InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None
    ):
        super().__init__()
        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        """
        Applies the Fit transform to the input image.

        Args:
            img (Image.Image): The input PIL Image.

        Returns:
            Image.Image: The resized image. When ``pad`` is set, it is centred
            and padded to exactly ``bounds``.
        """
        wimg, himg = img.size  # PIL reports (width, height)
        hbound, wbound = self.bounds

        hscale = hbound / himg
        wscale = wbound / wimg
        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)
        scale = min(hscale, wscale)

        if scale == 1.0:
            # Already the right size; skip resampling but still allow padding below.
            hnew, wnew = himg, wimg
        else:
            # max(1, ...) stops extreme aspect ratios from rounding a side to zero.
            hnew = max(1, min(round(himg * scale), hbound))
            wnew = max(1, min(round(wimg * scale), wbound))
            img = TF.resize(img, (hnew, wnew), self.interpolation)

        if self.pad is None:
            return img

        # Split the padding so the image is centred; an odd pixel goes bottom/right.
        hpad = hbound - hnew
        wpad = wbound - wnew
        tpad = hpad // 2
        bpad = hpad - tpad
        lpad = wpad // 2
        rpad = wpad - lpad
        # TF.pad takes (left, top, right, bottom).
        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        """
        Returns a string representation of the Fit module.

        Returns:
            str: A string describing the module's parameters.
        """
        return (
            f"{self.__class__.__name__}(bounds={self.bounds}, "
            f"interpolation={self.interpolation.value}, grow={self.grow}, "
            f"pad={self.pad})"
        )
|
|
|
|
|
class CompositeAlpha(torch.nn.Module): |
|
""" |
|
A module for compositing images with alpha channels over a background color. |
|
Args: |
|
background (tuple[float, float, float] | float): The background color to |
|
use for compositing. |
|
""" |
|
def __init__(self, background: tuple[float, float, float] | float): |
|
super().__init__() |
|
self.background = ( |
|
(background, background, background) |
|
if isinstance(background, float) |
|
else background |
|
) |
|
self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2) |
|
|
|
def forward(self, img: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Applies alpha compositing to the input image tensor. |
|
Args: |
|
img (torch.Tensor): The input image tensor. |
|
Returns: |
|
torch.Tensor: The composited image tensor. |
|
""" |
|
if img.shape[-3] == 3: |
|
return img |
|
alpha = img[..., 3, None, :, :] |
|
img[..., :3, :, :] *= alpha |
|
background = self.background.expand(-1, img.shape[-2], img.shape[-1]) |
|
if background.ndim == 1: |
|
background = background[:, None, None] |
|
elif background.ndim == 2: |
|
background = background[None, :, :] |
|
img[..., :3, :, :] += (1.0 - alpha) * background |
|
return img[..., :3, :, :] |
|
|
|
def __repr__(self) -> str: |
|
""" |
|
Returns a string representation of the CompositeAlpha module. |
|
Returns: |
|
str: A string describing the module's parameters. |
|
""" |
|
return f"{self.__class__.__name__}(background={self.background})" |
|
|
|
|
|
# Shared preprocessing. Steps, in order:
# - fit inside 384x384, keeping the aspect ratio;
# - convert to a tensor;
# - flatten any alpha channel onto 0.5 grey;
# - normalise with mean/std 0.5;
# - centre-crop to exactly 384x384.
transform = transforms.Compose(
    [
        Fit((384, 384)),
        transforms.ToTensor(),
        CompositeAlpha(0.5),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3, inplace=True),
        transforms.CenterCrop((384, 384)),
    ]
)

# Module-level backbone used by run_classifier(). Its head is replaced with
# GatedHead and the checkpoint is loaded later in the file.
model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    num_classes=9083,
    pretrained=False,
)
|
|
|
|
|
class GatedHead(torch.nn.Module):
    """
    Classification head that multiplies per-class sigmoid scores by per-class sigmoid gates.

    One linear layer produces ``2 * num_classes`` values. The first half are
    class logits and the second half are gate logits.

    Args:
        num_features (int): Width of the incoming feature vector.
        num_classes (int): Number of output classes.
    """

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self.linear = torch.nn.Linear(num_features, 2 * num_classes)
        self.act = torch.nn.Sigmoid()
        self.gate = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Score ``x`` with the gated head.

        Args:
            x (torch.Tensor): Features, sliced along dim 1 after projection.

        Returns:
            torch.Tensor: ``sigmoid(class_logits) * sigmoid(gate_logits)``.
        """
        projected = self.linear(x)
        split = self.num_classes
        class_logits = projected[:, :split]
        gate_logits = projected[:, split:]
        return self.act(class_logits) * self.gate(gate_logits)
|
|
|
|
|
|
|
# NOTE(review): process_directory() never uses this module-level `model`;
# only run_classifier() does. The file selects the 'spawn' start method, so
# every pool worker re-imports this script. That means everything below
# (weight loading, CUDA placement, argument parsing) also runs once per
# worker, on top of the separate model initializer_worker() builds there.
# Consider moving this setup behind the __main__ guard or a lazy loader once
# run_classifier()'s callers are known.

# Replace timm's classifier with GatedHead, then load the checkpoint.
# Presumably the checkpoint stores GatedHead weights, hence the swap first — confirm.
model.head = GatedHead(min(model.head.weight.shape), 9083)
safetensors.torch.load_model(
    model, "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
)

# Command-line interface, parsed at import time. run_classifier() reads
# args.cpu and args.no_grad; the __main__ guard passes the options on to
# process_directory().
parser = argparse.ArgumentParser(
    description="Run inference on a directory of images."
)
parser.add_argument("directory", type=str, help="Target directory containing images.")
parser.add_argument(
    "--threshold", type=float, default=0.2, help="Threshold for tag filtering."
)
parser.add_argument(
    "--cpu", action="store_true", help="Force CPU inference instead of CUDA"
)
parser.add_argument(
    "--no_grad", action="store_true", help="Enable torch.no_grad() for inference"
)
args = parser.parse_args()

# Place the module-level model on CUDA unless --cpu was given. Switch to
# fp16 + channels_last on devices with compute capability >= 7.
if torch.cuda.is_available() and not args.cpu:
    model.cuda()
    if torch.cuda.get_device_capability()[0] >= 7:
        model.to(dtype=torch.float16, memory_format=torch.channels_last)
model.eval()

# Tag vocabulary: keys of tags.json with underscores turned into spaces.
# run_classifier() indexes this list with the model's topk indices, so the
# key order is presumably the model's output order — confirm.
with open("/home/kade/source/repos/JTP2/tags.json", "r", encoding="utf-8") as file:
    tags = json.load(file)
allowed_tags = list(tags.keys())
for idx, tag in enumerate(allowed_tags):
    allowed_tags[idx] = tag.replace("_", " ")
# Latest run_classifier() result ({tag: score}, highest first), read by create_tags().
sorted_tag_score = {}
|
|
|
|
|
def run_classifier(image, threshold):
    """
    Tag one PIL image with the module-level model.

    Args:
        image (PIL.Image): Image to classify. It is converted to RGBA before
            the module-level ``transform`` is applied.
        threshold (float): Tags scoring strictly above this are returned.

    Returns:
        tuple: ``(comma-separated tag string, {tag: score})`` as produced by
        create_tags().

    Side effects:
        Replaces the module-level ``sorted_tag_score`` with this image's
        top-250 tag scores, highest first.
    """
    global sorted_tag_score

    batch = transform(image.convert('RGBA')).unsqueeze(0)
    if torch.cuda.is_available() and not args.cpu:
        batch = batch.cuda()
        if torch.cuda.get_device_capability()[0] >= 7:
            batch = batch.to(dtype=torch.float16, memory_format=torch.channels_last)

    if not args.no_grad:
        probs = model(batch)[0].cpu()
    else:
        with torch.no_grad():
            probs = model(batch)[0].cpu()

    top_values, top_indices = probs.topk(250)
    scores = {
        allowed_tags[index]: value.item()
        for index, value in zip(top_indices, top_values)
    }
    sorted_tag_score = dict(
        sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
    )
    return create_tags(threshold)
|
|
|
def create_tags(threshold):
    """
    Select tags from the module-level ``sorted_tag_score`` that score strictly above ``threshold``.

    Args:
        threshold (float): Minimum (exclusive) score for a tag to be kept.

    Returns:
        tuple: ``(text, scores)``. ``text`` is the kept tag names joined with
        ", " in ``sorted_tag_score`` order; ``scores`` maps each kept tag to
        its score.
    """
    kept = {}
    for name, score in sorted_tag_score.items():
        if score > threshold:
            kept[name] = score
    return ", ".join(kept), kept
|
|
|
if __name__ == "__main__":
    # Script entry point: tag every image under the directory given on the command line.
    process_directory(
        args.directory,
        threshold=args.threshold,
        cpu=args.cpu,
        use_no_grad=args.no_grad,
    )
|
|
|
|
|
|