# autoglaucoma / cam.py
# pamixsun's picture
# Upload 2 files
# 0499403
import cv2
import torch
import numpy as np
from PIL import Image
from typing import List, Callable, Optional
from functools import partial
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
""" Model wrapper to return a tensor"""
class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """Adapter whose forward pass returns the wrapped model's ``logits``.

    The wrapped model's forward output is an object carrying a ``logits``
    attribute; this module unwraps it so callers receive that value directly.
    """

    def __init__(self, model):
        """Keep a reference to ``model`` as a registered submodule."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run ``x`` through the wrapped model and return ``.logits``."""
        outputs = self.model(x)
        return outputs.logits
class ClassActivationMap(object):
    """Build Grad-CAM overlays for a Hugging Face Swin-V2 style classifier.

    The model must expose ``model.swinv2.layernorm`` (used as the CAM target
    layer) and return an object with a ``logits`` attribute from its forward
    pass.
    """

    def __init__(self, model, processor):
        """Wrap the model and remember the target layer and processor.

        Args:
            model: Classifier exposing ``swinv2.layernorm``. It is wrapped in
                ``HuggingfaceToTensorModelWrapper`` so its forward pass yields
                the logits directly.
            processor: Image processor callable as
                ``processor(images=..., return_tensors="pt")`` and carrying a
                ``size`` mapping with ``'height'`` and ``'width'`` keys.
        """
        self.model = HuggingfaceToTensorModelWrapper(model)
        self.target_layer = [model.swinv2.layernorm]
        self.processor = processor

    def swinT_reshape_transform_huggingface(self, tensor, width, height):
        """Reshape a token sequence into a channels-first feature map.

        Args:
            tensor: Activations indexed as ``(batch, tokens, channels)`` with
                ``tokens == height * width``.
            width: Number of token columns in the spatial grid.
            height: Number of token rows in the spatial grid.

        Returns:
            A view of ``tensor`` laid out as ``(batch, channels, height, width)``.
        """
        grid = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
        # (B, H, W, C) -> (B, C, H, W); same result as the two chained
        # transposes (2, 3) then (1, 2).
        return grid.permute(0, 3, 1, 2)

    def run_grad_cam_on_image(self,
                              targets_for_gradcam: List[Callable],
                              reshape_transform: Optional[Callable],
                              input_tensor: torch.Tensor,
                              input_image: Image.Image,
                              method: Callable = GradCAM):
        """Compute one CAM overlay per target and stack them side by side.

        Args:
            targets_for_gradcam: CAM targets, one overlay is produced for each.
            reshape_transform: Callable handed to the CAM method to convert
                the target layer's activations into a spatial map, or ``None``.
            input_tensor: Un-batched, preprocessed image tensor (three
                dimensions; a batch axis is added here).
            input_image: Image the heatmaps are drawn on. It must have the
                same height and width as ``input_tensor``.
            method: CAM class to instantiate; must work as a context manager
                and accept ``model``, ``target_layers`` and
                ``reshape_transform`` keyword arguments.

        Returns:
            The overlays concatenated horizontally with ``np.hstack``.

        Raises:
            ValueError: If ``targets_for_gradcam`` is empty.
        """
        if not targets_for_gradcam:
            raise ValueError("targets_for_gradcam must contain at least one target")

        with method(model=self.model,
                    target_layers=self.target_layer,
                    reshape_transform=reshape_transform) as cam:
            # Replicate the tensor once per category we want a Grad-CAM for.
            batch = input_tensor.unsqueeze(0).repeat(len(targets_for_gradcam), 1, 1, 1)
            heatmaps = cam(input_tensor=batch, targets=targets_for_gradcam)

        # The background is the same for every target, so scale it once.
        background = np.float32(input_image) / 255
        overlays = [
            show_cam_on_image(background, heatmap, use_rgb=True)
            for heatmap in heatmaps
        ]
        return np.hstack(overlays)

    def get_cam(self, image, category_id):
        """Return the Grad-CAM overlay of ``image`` for one output class.

        Args:
            image: Array accepted by ``PIL.Image.fromarray``.
            category_id: Index of the classifier output to explain.

        Returns:
            The overlay produced by ``run_grad_cam_on_image`` for the single
            requested category.
        """
        size = self.processor.size
        # PIL's resize expects (width, height), in that order.
        resized = Image.fromarray(image).resize((size['width'], size['height']))

        pixel_values = self.processor(images=resized, return_tensors="pt")['pixel_values']
        # Drop only the batch axis; a bare squeeze() would also remove any
        # other size-1 dimension.
        img_tensor = pixel_values.squeeze(0)

        targets_for_gradcam = [ClassifierOutputTarget(category_id)]
        # img_tensor is indexed (channels, height, width). The token grid of
        # the target layer is assumed to be the input resolution divided by 32.
        reshape_transform = partial(self.swinT_reshape_transform_huggingface,
                                    width=img_tensor.shape[2] // 32,
                                    height=img_tensor.shape[1] // 32)

        return self.run_grad_cam_on_image(input_tensor=img_tensor,
                                          input_image=resized,
                                          targets_for_gradcam=targets_for_gradcam,
                                          reshape_transform=reshape_transform)