| """ |
| Explainable AI (XAI) Inference for Nude Multi-Label Classification |
| ================================================================== |
| |
| This script performs inference using a trained Swin Transformer model for |
| multi-label classification of nude images. It also integrates Class Activation |
| Mapping (CAM) to provide visual explanations for the model's predictions. |
| |
| Author: Ramaguru Radhakrishnan |
| Date: March 2025 |
| """ |
|
|
| import torch |
| import torchvision.transforms as transforms |
| from PIL import Image |
| import json |
| from model import SwinTransformerMultiLabel |
| from torchcam.methods import SmoothGradCAMpp |
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
|
|
| |
# Number of output labels the classifier is built with. It is compared against
# the label vocabulary read from labels.json further down in this script.
NUM_CLASSES = 18

# Instantiate the network; its weights are overwritten from the checkpoint below.
model = SwinTransformerMultiLabel(num_classes=NUM_CLASSES)

# Load the checkpoint onto the CPU so no GPU is required for inference.
# NOTE(security): torch.load unpickles the file, so only point this at
# checkpoints from a trusted source.
checkpoint_path = "../models/multi_nude_detector.pth"
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model_dict = model.state_dict()

# Keep only the checkpoint entries whose name AND shape match the current
# model, so a checkpoint with a differently sized layer does not abort loading.
filtered_checkpoint = {
    k: v
    for k, v in checkpoint.items()
    if k in model_dict and getattr(v, "shape", None) == model_dict[k].shape
}

# Running inference on a model that received no trained weights at all is
# meaningless, so fail loudly instead of silently predicting from random init.
if not filtered_checkpoint:
    raise RuntimeError(
        f"No tensor in '{checkpoint_path}' matches the model's parameters; "
        "the checkpoint may be wrapped (e.g. under a 'state_dict' key) or "
        "belong to a different architecture."
    )

# Make partial loads visible: these are the entries the filter dropped and the
# model tensors that therefore keep their freshly initialised values.
skipped_keys = [k for k in checkpoint if k not in filtered_checkpoint]
uninitialised_keys = [k for k in model_dict if k not in filtered_checkpoint]
if skipped_keys or uninitialised_keys:
    print(
        f"Warning: loaded {len(filtered_checkpoint)}/{len(model_dict)} model tensors; "
        f"skipped checkpoint keys: {skipped_keys}; "
        f"model keys left at initial values: {uninitialised_keys}"
    )

model_dict.update(filtered_checkpoint)
model.load_state_dict(model_dict, strict=False)

# Switch to evaluation mode before running inference.
model.eval()
|
|
| |
# Pre-processing pipeline applied to every input image, built step by step:
# fixed 224x224 resize, conversion to a tensor, then per-channel
# normalisation with the mean/std values commonly used for ImageNet models.
_resize_step = transforms.Resize((224, 224))
_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([_resize_step, _tensor_step, _normalize_step])
|
|
| |
# Build the label vocabulary: labels.json is read as a mapping whose values are
# collections of tags; the vocabulary is the sorted set of all distinct tags.
# JSON is UTF-8 by specification, so decode explicitly rather than relying on
# the platform's default encoding.
with open("../data/labels.json", "r", encoding="utf-8") as f:
    label_map = json.load(f)
classes = sorted({tag for tags in label_map.values() for tag in tags})

# The position of a tag in `classes` is used as the model's output index, so
# the vocabulary size has to agree with the size the model was built with.
if len(classes) != NUM_CLASSES:
    raise ValueError(f"❌ Mismatch: Model expects {NUM_CLASSES} classes, but labels.json has {len(classes)} labels!")
|
|
| |
# Image to explain.
# NOTE(review): this is an absolute, machine-specific path; adjust it before
# running the script on another machine.
img_path = "C:\\Users\\RamaguruRadhakrishna\\Videos\\STAR-main\\STAR-main\\data\\images\\442_.jpeg"

# Force three colour channels, apply the pre-processing pipeline and add a
# leading batch dimension so the model receives a batch of one image.
image = Image.open(img_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0)

# Plain prediction pass: gradients are not needed here, so disable autograd.
with torch.no_grad():
    output = model(input_tensor)
print(f"🔹 Model Output Shape: {output.shape}")
|
|
| |
# Threshold the per-class scores of the single image in the batch.
# NOTE(review): the 0.5 cut-off presumably assumes the model already outputs
# probabilities (e.g. after a sigmoid) — confirm against the model definition.
#
# Only indices that have a name in `classes` are considered, and the labels are
# derived from the very same index list, so the two results can never disagree
# and a later `classes[class_idx]` lookup cannot go out of range.
num_scored = min(len(classes), output.shape[1])
scores = output[0].tolist()
predicted_indices = [i for i, score in enumerate(scores[:num_scored]) if score > 0.5]
predicted_labels = [classes[i] for i in predicted_indices]

print("✅ Predicted Tags:", predicted_labels)
|
|
| |
| |
| |
|
|
| |
# Debug aid: dump the model's repr, followed by the dotted name of every
# sub-module (one per line), to help pick a target layer for CAM.
print(model)

# NOTE(review): the leading character of the banner looks like a mis-decoded
# emoji; it is reproduced unchanged because the original cannot be determined.
print("π Model Architecture:\n")
print("\n".join(layer_name for layer_name, _ in model.named_modules()))
|
|
| |
| |
# Dotted name of the sub-module whose activations the CAM is computed from.
valid_target_layer = "features.7.3"

# Validate the name up front so a typo yields a readable error that lists the
# available layers. The name -> module mapping is built once and reused.
available_layers = dict(model.named_modules())
if valid_target_layer not in available_layers:
    raise ValueError(f"❌ Layer '{valid_target_layer}' not found in model. Choose from:\n{list(available_layers)}")

# Attach the CAM extractor to the chosen layer of the model.
cam_extractor = SmoothGradCAMpp(model, target_layer=valid_target_layer)

print("✅ SmoothGradCAMpp initialized successfully!")
|
|
| |
# Run the model again, this time with autograd enabled and after the CAM
# extractor has been attached to the model.
output = model(input_tensor)

if not predicted_indices:
    print("No tag scored above the threshold; no heatmap to display.")

for class_idx in predicted_indices:
    # Newer torchcam releases return a list with one map per target layer,
    # older ones return the tensor directly; accept both.
    cam = cam_extractor(class_idx, output)
    if isinstance(cam, (list, tuple)):
        cam = cam[0]
    cam = cam.squeeze().detach().cpu().numpy()

    # Min-max normalise to [0, 1]. A constant map has zero range, which would
    # otherwise divide by zero and fill the heatmap with NaNs.
    cam_min = np.min(cam)
    cam_range = np.max(cam) - cam_min
    if cam_range > 0:
        cam = (cam - cam_min) / cam_range
    else:
        cam = np.zeros_like(cam)

    # Upsample the coarse activation map to the original image resolution.
    cam_resized = np.array(Image.fromarray(cam * 255).resize(image.size))

    # Draw the heatmap semi-transparently on top of the input image.
    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.imshow(cam_resized, cmap='jet', alpha=0.5)
    plt.axis("off")
    plt.title(f"Explainability Heatmap for '{classes[class_idx]}'")
    plt.show()
|
|