import cv2
import torch
import numpy as np
from PIL import Image
from typing import List, Callable, Optional
from functools import partial
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image


class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """Model wrapper that makes ``forward`` return a plain tensor.

    Hugging Face classification models return an output object; Grad-CAM
    targets such as ``ClassifierOutputTarget`` index directly into the model
    output, so this wrapper returns only its ``.logits``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its ``logits`` tensor."""
        return self.model(x).logits


class ClassActivationMap:
    """Build Grad-CAM overlays for a Hugging Face Swin-V2 image classifier.

    Args:
        model: Hugging Face classification model. Unless ``target_layer`` is
            given, it must expose ``model.swinv2.layernorm``, which is used as
            the Grad-CAM target layer.
        processor: Hugging Face image processor matching ``model``. Its
            ``size`` dict (keys ``'height'`` and ``'width'``) sets the
            resolution images are resized to in ``get_cam``.
        target_layer: optional layer to hook instead of
            ``model.swinv2.layernorm``.
    """

    # Ratio between input pixels and the token grid of the target layer.
    # Presumably patch size 4 x three 2x patch-merging stages = 32 for the
    # standard Swin-V2 configs -- TODO confirm for non-standard checkpoints.
    DOWNSAMPLE_FACTOR = 32

    def __init__(self, model, processor, target_layer=None):
        self.model = HuggingfaceToTensorModelWrapper(model)
        if target_layer is None:
            target_layer = model.swinv2.layernorm
        # pytorch-grad-cam expects a list of target layers.
        self.target_layer = [target_layer]
        self.processor = processor

    def swinT_reshape_transform_huggingface(self, tensor, width, height):
        """Turn a token sequence into a channels-first activation map.

        Reshapes ``tensor`` to ``(batch, height, width, channels)``, using
        ``tensor.size(0)`` as batch and ``tensor.size(2)`` as channels. It
        then permutes the result to ``(batch, channels, height, width)`` so
        Grad-CAM can treat it like a CNN feature map.

        Assumes ``tensor`` is ``(batch, height * width, channels)``, the
        layout of the Swin layernorm output; TODO confirm for other layers.
        """
        result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
        return result.permute(0, 3, 1, 2)

    def run_grad_cam_on_image(self,
                              targets_for_gradcam: List[Callable],
                              reshape_transform: Optional[Callable],
                              input_tensor: torch.Tensor,
                              input_image: Image.Image,
                              method: Callable = GradCAM,
                              downscale: int = 1) -> np.ndarray:
        """Compute one CAM overlay per target and stack them side by side.

        Args:
            targets_for_gradcam: Grad-CAM targets; one overlay is produced for
                each entry.
            reshape_transform: maps target-layer activations to a
                ``(batch, channels, height, width)`` map, or ``None``.
            input_tensor: preprocessed image of shape ``(C, H, W)`` (no batch
                dim); it is replicated once per target.
            input_image: image the heatmaps are drawn on. Its pixel values are
                divided by 255, so presumably 8-bit RGB matching the CAM
                resolution.
            method: CAM class from pytorch-grad-cam, ``GradCAM`` by default.
            downscale: integer factor to shrink each overlay by (e.g. to keep
                notebook outputs small); ``1`` keeps full size.

        Returns:
            The overlays concatenated horizontally with ``np.hstack``.

        Raises:
            ValueError: if ``targets_for_gradcam`` is empty.
        """
        if not targets_for_gradcam:
            raise ValueError("targets_for_gradcam must contain at least one target")

        # show_cam_on_image wants a float image in [0, 1]; convert it once
        # rather than once per target.
        base_image = np.float32(input_image) / 255

        with method(model=self.model,
                    target_layers=self.target_layer,
                    reshape_transform=reshape_transform) as cam:
            # Replicate the image once per target so every CAM comes out of a
            # single batched forward/backward pass.
            repeated_tensor = input_tensor.unsqueeze(0).repeat(
                len(targets_for_gradcam), 1, 1, 1)
            batch_results = cam(input_tensor=repeated_tensor,
                                targets=targets_for_gradcam)

        results = []
        for grayscale_cam in batch_results:
            visualization = show_cam_on_image(base_image, grayscale_cam, use_rgb=True)
            if downscale != 1:
                # cv2.resize takes dsize as (width, height).
                visualization = cv2.resize(
                    visualization,
                    (visualization.shape[1] // downscale,
                     visualization.shape[0] // downscale))
            results.append(visualization)
        return np.hstack(results)

    def get_cam(self, image, category_id):
        """Return a Grad-CAM overlay of ``image`` for class ``category_id``.

        Args:
            image: image as an array accepted by ``PIL.Image.fromarray``, or
                an existing ``PIL.Image.Image``. It is converted to RGB and
                resized to the processor's ``size``.
            category_id: index of the output logit to explain.

        Returns:
            The overlay array produced by ``run_grad_cam_on_image``.
        """
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # PIL's resize takes (width, height). RGB conversion keeps grayscale
        # or RGBA inputs compatible with show_cam_on_image's 3-channel heatmap.
        target_size = (self.processor.size['width'], self.processor.size['height'])
        image = image.convert("RGB").resize(target_size)

        # Drop only the batch dim; squeeze() could also drop other size-1 dims.
        img_tensor = self.processor(images=image, return_tensors="pt")['pixel_values'][0]

        targets_for_gradcam = [ClassifierOutputTarget(category_id)]
        # img_tensor is (C, H, W); the token grid is the image size divided by
        # the model's total downsampling.
        reshape_transform = partial(
            self.swinT_reshape_transform_huggingface,
            width=img_tensor.shape[-1] // self.DOWNSAMPLE_FACTOR,
            height=img_tensor.shape[-2] // self.DOWNSAMPLE_FACTOR)
        return self.run_grad_cam_on_image(input_tensor=img_tensor,
                                          input_image=image,
                                          targets_for_gradcam=targets_for_gradcam,
                                          reshape_transform=reshape_transform)