#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# caption/jtp2.py
"""
JTP2 (Joint Tagger Project 2) Image Classification Script

This script implements a multi-label classifier for furry images using the
PILOT2 model. It processes images, generates tags, and saves the results.
The model is based on a Vision Transformer architecture and uses a custom
GatedHead for classification.

Key features:
- Image preprocessing and transformation
- Model inference using PILOT2
- Tag generation with customizable threshold
- Batch processing of image directories
- Saving results as text files alongside images

Usage:
    python jtp2.py <directory> [--gen_threshold <float>] [--char_threshold <float>] [--cpu]
"""

import sys
import os

# Add the project root directory to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import json
import argparse
import multiprocessing
from multiprocessing import cpu_count
from pathlib import Path

from PIL import Image
import safetensors.torch
import timm
import torch
from torchvision.transforms import transforms
from torchvision.transforms import InterpolationMode
import pillow_jxl  # type: ignore  # noqa: F401 - registers JPEG XL support with Pillow

from utils.batch_processor import BatchProcessor, BatchOptions

torch.set_grad_enabled(False)

# Set start method to spawn for CUDA compatibility
multiprocessing.set_start_method('spawn', force=True)

# Define image transform
transform = transforms.Compose([
    transforms.Resize((384, 384), interpolation=InterpolationMode.LANCZOS),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


class GatedHead(torch.nn.Module):
    """Simple gated classification head: a linear layer scaled by a learnable gate."""

    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, num_classes)
        self.gate = torch.nn.Parameter(torch.ones(num_classes))

    def forward(self, x):
        return self.linear(x) * self.gate.unsqueeze(0)  # Adjust gating logic as needed

    # Expose the underlying Linear parameters so code expecting a plain
    # `head.weight` / `head.bias` keeps working.
    @property
    def weight(self):
        return self.linear.weight

    @weight.setter
    def weight(self, value):
        self.linear.weight = value

    @property
    def bias(self):
        return self.linear.bias

    @bias.setter
    def bias(self, value):
        self.linear.bias = value


class JTP2Processor(BatchProcessor[Path, None]):
    """JTP2 image processor implementation"""

    def __init__(self, opts: BatchOptions):
        super().__init__(opts)
        self.model = None
        self.tags = None
        self.allowed_tags = None
        self.transform = transform  # Use the module-level transform
        self.load_models()

    def load_models(self) -> None:
        """Load required models and resources"""
        self.model = timm.create_model(
            "vit_so400m_patch14_siglip_384.webli",
            pretrained=False,
            num_classes=18166  # Align with state_dict
        )

        # Replace the model's head with the custom GatedHead
        self.model.head = GatedHead(
            in_features=self.model.head.in_features,  # Access directly from the Linear layer
            num_classes=self.model.head.out_features
        )

        # Load the state_dict using load_file
        state_dict = safetensors.torch.load_file(
            "/home/kade/source/repos/JTP2/JTP_PILOT2-e3-vit_so400m_patch14_siglip_384.safetensors"
        )

        # Load the state dict with strict=False to ignore missing or unexpected keys
        missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
        if missing_keys:
            self.logger.warning(f"Missing keys when loading state_dict: {missing_keys}")
        if unexpected_keys:
            self.logger.warning(f"Unexpected keys when loading state_dict: {unexpected_keys}")

        self.model.to(self.device)
        if self.device == "cuda" and torch.cuda.get_device_capability()[0] >= 7:
            self.model = self.model.to(dtype=torch.float16, memory_format=torch.channels_last)
        self.model.eval()

        # Load tags
open("/home/kade/source/repos/JTP2/tags.json", "r", encoding="utf-8") as f: self.tags = json.load(f) self.allowed_tags = [tag.replace("_", " ") for tag in self.tags.keys()] def should_process_item(self, item: Path) -> bool: """Check if an image should be processed""" if not item.is_file(): return False if item.suffix.lower() not in self.opts.supported_extensions: return False output_path = item.with_suffix('.tags') return not (self.opts.skip_existing and output_path.exists()) def process_item(self, item: Path) -> None: """Process a single image""" self.logger.info(f"Processing {item.name}...") try: image = Image.open(item) img = image.convert('RGBA') tensor = self.transform(img).unsqueeze(0) if self.device == "cuda": tensor = tensor.cuda() if torch.cuda.get_device_capability()[0] >= 7: tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last) with torch.no_grad(): probits = self.model(tensor)[0].cpu() values, indices = probits.topk(250) # Process results tag_score = { self.allowed_tags[indices[i]]: values[i].item() for i in range(indices.size(0)) } filtered_tags = [ tag for tag, score in sorted(tag_score.items(), key=lambda x: x[1], reverse=True) if score > self.opts.gen_threshold # Use gen_threshold here ] # Save results output_path = item.with_suffix('.tags') with open(output_path, "w", encoding="utf-8") as f: f.write(", ".join(filtered_tags)) except Exception as e: self.logger.error(f"Error processing {item}: {e}") def __del__(self): """Clean up CUDA resources""" if torch.cuda.is_available(): torch.cuda.empty_cache() # Function to process a single image, required to be at the top level for pickling def process_image(img_path: Path): global global_processor if global_processor is None: raise ValueError("Global processor not initialized. Ensure initializer is set correctly.") global_processor.process_item(img_path) def main(): parser = argparse.ArgumentParser(description="Run inference on a directory of images.") parser.add_argument("directory", type=str, help="Target directory containing images.") parser.add_argument("--gen_threshold", type=float, default=0.2, help="Threshold for tag filtering.") parser.add_argument("--char_threshold", type=float, default=0.75, help="Character threshold for tag filtering.") parser.add_argument("--cpu", action="store_true", help="Force CPU inference instead of CUDA") args = parser.parse_args() batch_opts = BatchOptions( batch_size=16, num_workers=6 if not args.cpu else cpu_count() // 2, device="cpu" if args.cpu else "cuda", debug=False, skip_existing=True, recursive=True, supported_extensions={'.jpg', '.jpeg', '.png', '.webp', '.jxl'}, gen_threshold=args.gen_threshold, char_threshold=args.char_threshold ) processor = JTP2Processor(batch_opts) target_path = Path(args.directory) if not target_path.exists(): raise FileNotFoundError(f"Directory not found: {target_path}") image_paths = [] for ext in batch_opts.supported_extensions: image_paths.extend(target_path.rglob(f'*{ext}')) list(processor.process_all(iter(sorted(set(image_paths))))) if __name__ == "__main__": main()