# NOTE: Hugging Face file-viewer metadata captured along with this file (not code):
#   uploader: pamixsun | commit 0499403 ("Upload 2 files") | size: 3.45 kB | scan: no virus
import cv2
import torch
import numpy as np
from PIL import Image
from typing import List, Callable, Optional
from functools import partial
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
""" Model wrapper to return a tensor"""
class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """Adapter that makes a wrapped model return its ``logits`` directly.

    The wrapped model is expected to return an object exposing a
    ``logits`` attribute; this module forwards the input unchanged and
    hands back only that attribute.
    """

    def __init__(self, model):
        """Store *model* as the ``model`` attribute of this module."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on *x* and return the ``logits`` of its output."""
        outputs = self.model(x)
        return outputs.logits
class ClassActivationMap(object):
    """Produce Grad-CAM overlays for a Hugging Face Swin V2 image classifier.

    The classifier is wrapped so that its forward pass returns the logits
    tensor, and the CAM is computed on the output of
    ``model.swinv2.layernorm``.
    """

    def __init__(self, model, processor):
        """Wrap the classifier and remember the layer to explain.

        Args:
            model: Classifier exposing ``swinv2.layernorm`` and whose forward
                output has a ``logits`` attribute.
            processor: Image processor; it must expose ``size['height']`` and
                ``size['width']`` and be callable as
                ``processor(images=..., return_tensors="pt")``.
        """
        self.model = HuggingfaceToTensorModelWrapper(model)
        # The CAM classes take a *list* of target layers.
        self.target_layer = [model.swinv2.layernorm]
        self.processor = processor

    def swinT_reshape_transform_huggingface(self, tensor, width, height):
        """Turn a token sequence into a channels-first feature map.

        Args:
            tensor: Activations of shape ``(batch, height * width, channels)``.
            width: Number of token columns.
            height: Number of token rows.

        Returns:
            A view of *tensor* with shape ``(batch, channels, height, width)``.
        """
        grid = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
        # (batch, height, width, channels) -> (batch, channels, height, width)
        return grid.permute(0, 3, 1, 2)

    def run_grad_cam_on_image(self,
                              targets_for_gradcam: List[Callable],
                              reshape_transform: Optional[Callable],
                              input_tensor: torch.Tensor,
                              input_image: Image.Image,
                              method: Callable = GradCAM,
                              downscale: int = 1) -> np.ndarray:
        """Compute one CAM overlay per target and stack them side by side.

        Args:
            targets_for_gradcam: One CAM target per overlay to produce.
            reshape_transform: Callable that converts the target layer's
                activations into a ``(batch, channels, height, width)`` map,
                or ``None``.
            input_tensor: Pre-processed image tensor without a batch axis.
            input_image: Image the heatmap is blended onto; its pixel values
                are divided by 255 before blending.
            method: CAM class to instantiate (used as a context manager).
            downscale: Integer factor by which each overlay is shrunk before
                stacking; ``1`` keeps the original size.

        Returns:
            The overlays concatenated horizontally with ``np.hstack``.

        Raises:
            ValueError: If *targets_for_gradcam* is empty or *downscale* < 1.
        """
        if not targets_for_gradcam:
            raise ValueError("targets_for_gradcam must contain at least one target")
        if downscale < 1:
            raise ValueError("downscale must be a positive integer")

        # Normalise once; the same background is reused for every target.
        background = np.float32(input_image) / 255

        with method(model=self.model,
                    target_layers=self.target_layer,
                    reshape_transform=reshape_transform) as cam:
            # Replicate the tensor for each of the categories we want to
            # create Grad-CAM for, so all targets run in a single batch.
            repeated_tensor = input_tensor[None, :].repeat(
                len(targets_for_gradcam), 1, 1, 1)
            batch_results = cam(input_tensor=repeated_tensor,
                                targets=targets_for_gradcam)

            results = []
            for grayscale_cam in batch_results:
                visualization = show_cam_on_image(background,
                                                  grayscale_cam,
                                                  use_rgb=True)
                if downscale > 1:
                    # Shrink the overlay (e.g. to keep notebooks light);
                    # never let a dimension collapse to zero.
                    new_size = (max(1, visualization.shape[1] // downscale),
                                max(1, visualization.shape[0] // downscale))
                    visualization = cv2.resize(visualization, new_size)
                results.append(visualization)
            return np.hstack(results)

    def get_cam(self, image, category_id):
        """Return the Grad-CAM overlay of *image* for class *category_id*.

        Args:
            image: Image as an array accepted by ``PIL.Image.fromarray``.
            category_id: Index of the classifier output to explain.

        Returns:
            The overlay image produced by :meth:`run_grad_cam_on_image`.
        """
        size = self.processor.size
        # PIL's resize expects (width, height). Converting to RGB keeps the
        # background compatible with the 3-channel heatmap it is blended with.
        image = Image.fromarray(image).convert("RGB").resize(
            (size['width'], size['height']))
        # Drop only the batch axis (index 0) added by the processor.
        img_tensor = self.processor(images=image,
                                    return_tensors="pt")['pixel_values'][0]
        targets_for_gradcam = [ClassifierOutputTarget(category_id)]
        # NOTE(review): 32 is presumably the backbone's total spatial
        # downsampling factor at the target layer -- confirm for other configs.
        reshape_transform = partial(self.swinT_reshape_transform_huggingface,
                                    width=img_tensor.shape[2] // 32,
                                    height=img_tensor.shape[1] // 32)
        return self.run_grad_cam_on_image(input_tensor=img_tensor,
                                          input_image=image,
                                          targets_for_gradcam=targets_for_gradcam,
                                          reshape_transform=reshape_transform)