import cv2
import torch

import numpy as np
from PIL import Image
from typing import List, Callable, Optional
from functools import partial

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image


""" Model wrapper to return a tensor"""
class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super(HuggingfaceToTensorModelWrapper, self).__init__()
        self.model = model

    def forward(self, x):
        return self.model(x).logits


class ClassActivationMap:
    def __init__(self, model, processor):
        self.model = HuggingfaceToTensorModelWrapper(model)
        # Use the final layer norm of the Swin V2 backbone as the Grad-CAM target layer.
        target_layer = model.swinv2.layernorm
        self.target_layer = [target_layer]
        self.processor = processor

    def swinT_reshape_transform_huggingface(self, tensor, width, height):
        # Swin outputs tokens of shape (batch, height * width, channels);
        # restore the spatial layout and move channels first: (batch, channels, height, width).
        result = tensor.reshape(tensor.size(0),
                                height,
                                width,
                                tensor.size(2))
        result = result.transpose(2, 3).transpose(1, 2)
        return result

    def run_grad_cam_on_image(self,
                              targets_for_gradcam: List[Callable],
                              reshape_transform: Optional[Callable],
                              input_tensor: torch.Tensor,
                              input_image: Image.Image,
                              method: Callable = GradCAM):
        with method(model=self.model,
                    target_layers=self.target_layer,
                    reshape_transform=reshape_transform) as cam:

            # Replicate the tensor for each of the categories we want to create Grad-CAM for:
            repeated_tensor = input_tensor[None, :].repeat(len(targets_for_gradcam), 1, 1, 1)

            batch_results = cam(input_tensor=repeated_tensor,
                                targets=targets_for_gradcam)
            results = []
            for grayscale_cam in batch_results:
                visualization = show_cam_on_image(np.float32(input_image) / 255,
                                                  grayscale_cam,
                                                  use_rgb=True)
                # Optionally downscale the visualization; a divisor of 1 keeps the original size.
                visualization = cv2.resize(visualization,
                                           (visualization.shape[1] // 1, visualization.shape[0] // 1))
                results.append(visualization)
            return np.hstack(results)
        
    def get_cam(self, image, category_id):
        # PIL's resize expects (width, height).
        image = Image.fromarray(image).resize((self.processor.size['width'], self.processor.size['height']))
        img_tensor = self.processor(images=image, return_tensors="pt")['pixel_values'].squeeze()
        targets_for_gradcam = [ClassifierOutputTarget(category_id)]
        # Swin V2 downsamples the input by a factor of 32 (patch size 4, three 2x patch merges).
        reshape_transform = partial(self.swinT_reshape_transform_huggingface,
                                    width=img_tensor.shape[2] // 32,
                                    height=img_tensor.shape[1] // 32)
        cam = self.run_grad_cam_on_image(input_tensor=img_tensor,
                                         input_image=image,
                                         targets_for_gradcam=targets_for_gradcam,
                                         reshape_transform=reshape_transform)

        return cam
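

# Usage sketch (not part of the original module): the checkpoint name and the
# sample image path below are assumptions chosen for illustration. Any Swin V2
# image-classification checkpoint with a matching processor should work the
# same way, since ClassActivationMap only relies on model.swinv2.layernorm and
# processor.size.
if __name__ == "__main__":
    from transformers import AutoImageProcessor, AutoModelForImageClassification

    checkpoint = "microsoft/swinv2-tiny-patch4-window8-256"  # assumed checkpoint
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = AutoModelForImageClassification.from_pretrained(checkpoint)

    cam_runner = ClassActivationMap(model, processor)
    image = np.array(Image.open("sample.jpg").convert("RGB"))  # assumed input image
    heatmap = cam_runner.get_cam(image, category_id=0)
    Image.fromarray(heatmap).save("cam_visualization.png")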