File size: 5,931 Bytes
a5240f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b67e23
 
a5240f9
 
 
 
 
 
 
 
6b67e23
 
a5240f9
 
 
 
b128442
a5240f9
 
 
b128442
a5240f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cd7494
a5240f9
 
59cea8b
 
 
76b7891
a5240f9
b128442
a5240f9
 
 
 
 
 
0272244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b128442
0272244
 
 
 
 
 
 
 
 
 
 
b128442
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from transformers import ViTFeatureExtractor, ViTForImageClassification
import warnings
from torchvision import transforms
from datasets import load_dataset
from pytorch_grad_cam import run_dff_on_image, GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import numpy as np
import cv2 as cv
import torch
from typing import List, Callable, Optional
import logging
from face_grab import FaceGrabber

# original borrowed from https://github.com/jacobgil/pytorch-grad-cam/blob/master/tutorials/HuggingFace.ipynb
# thanks @jacobgil
# further mods beyond this commit by @simonSlamka

# Suppress every warning category for the whole process.
warnings.simplefilter("ignore")

# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level="INFO")



class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """Expose a Hugging Face classifier as a module whose forward returns bare logits.

    ``GradCam.run_grad_cam_on_image`` hands this wrapper to the CAM method so that
    calling the model yields the ``.logits`` tensor rather than the model's output object.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return only its ``logits`` attribute."""
        model_output = self.model(x)
        return model_output.logits



class GradCam:
    """Helpers for explaining a Hugging Face ViT image classifier with pytorch-grad-cam.

    The class holds no state; every method receives the model it operates on.
    """

    def category_name_to_index(self, model, category_name):
        """Return the class index whose label in ``model.config.id2label`` equals ``category_name``.

        Raises:
            KeyError: if no class carries that label (the message lists the known labels).
        """
        name_to_index = {label: index for index, label in model.config.id2label.items()}
        try:
            return name_to_index[category_name]
        except KeyError:
            raise KeyError(
                f"Unknown category {category_name!r}; available labels: {list(name_to_index)}"
            ) from None

    def run_grad_cam_on_image(self, model: torch.nn.Module,
                              target_layer: torch.nn.Module,
                              targets_for_gradcam: List[Callable],
                              reshape_transform: Optional[Callable],
                              input_tensor: torch.Tensor,
                              input_image: Image.Image,
                              method: Callable = GradCAM,
                              threshold: float = 0.5) -> np.ndarray:
        """Render one CAM overlay per target and return them concatenated side by side.

        Args:
            model: Hugging Face classifier; it is wrapped so the CAM method receives logits.
            target_layer: layer the CAM is computed from.
            targets_for_gradcam: one target per overlay (e.g. ``ClassifierOutputTarget``).
            reshape_transform: optional callable forwarded unchanged to ``method``.
            input_tensor: one image tensor without a batch dimension; it is given a batch
                dimension and repeated once per target, so a 3-D tensor is expected.
            input_image: the same image as ``input_tensor``. It is divided by 255 before
                blending, so 0-255 pixel values are assumed.
            method: CAM class to instantiate (defaults to ``GradCAM``).
            threshold: CAM values below this are zeroed before blending.

        Returns:
            The half-size overlays stacked horizontally, in ``targets_for_gradcam`` order.

        Raises:
            ValueError: if ``targets_for_gradcam`` is empty.
        """
        if not targets_for_gradcam:
            raise ValueError("targets_for_gradcam must contain at least one target")

        # The blend base is identical for every target, so convert it only once.
        base_image = np.float32(input_image) / 255

        with method(model=HuggingfaceToTensorModelWrapper(model),
                    target_layers=[target_layer],
                    reshape_transform=reshape_transform) as cam:
            # One copy of the image per target so all CAMs come from a single batched call.
            repeated_tensor = input_tensor[None, :].repeat(len(targets_for_gradcam), 1, 1, 1)
            batch_results = cam(input_tensor=repeated_tensor,
                                targets=targets_for_gradcam)

            results = []
            for grayscale_cam in batch_results:
                grayscale_cam[grayscale_cam < threshold] = 0
                visualization = show_cam_on_image(base_image, grayscale_cam, use_rgb=True)
                # Halve each overlay so the combined strip stays a manageable size.
                visualization = cv.resize(visualization,
                                          (visualization.shape[1] // 2, visualization.shape[0] // 2))
                results.append(visualization)
        return np.hstack(results)

    def _top_k_categories(self, model, img_tensor, k):
        """Return the ``k`` most probable classes for one image as label/score dicts."""
        if k <= 0:
            return []
        # Inference only: no autograd graph is needed for a prediction.
        with torch.no_grad():
            logits = model(img_tensor.unsqueeze(0)).logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
        k = min(k, probabilities.shape[0])
        scores, indices = torch.topk(probabilities, k)
        return [
            {"label": model.config.id2label[int(index)], "score": float(score)}
            for score, index in zip(scores, indices)
        ]

    def get_top_category(self, model, img_tensor, top_k=5):
        """Return the single most probable class for ``img_tensor``.

        Args:
            model: classifier whose output has ``.logits`` and whose config has ``id2label``.
            img_tensor: one image tensor without a batch dimension (a batch of 1 is formed).
            top_k: ignored; kept so existing callers keep working. Use
                ``print_top_categories`` to obtain more than one class.

        Returns:
            ``[{"label": <class name>, "score": <softmax probability>}]``, a one-element list.
        """
        return self._top_k_categories(model, img_tensor, 1)

    def print_top_categories(self, model, img_tensor, top_k=5):
        """Log the ``top_k`` most probable classes for ``img_tensor`` and return them.

        Returns:
            A list of ``{"label": ..., "score": ...}`` dicts ordered by descending
            probability, holding at most ``top_k`` entries (fewer if the model has fewer classes).
        """
        categories = self._top_k_categories(model, img_tensor, top_k)
        for rank, category in enumerate(categories, start=1):
            logging.info("Top %d: %s (score %.4f)", rank, category["label"], category["score"])
        return categories

    def reshape_transform_vit_huggingface(self, x):
        """Turn ViT token activations into a ``(batch, channels, side, side)`` map.

        The first token is dropped (presumably the [CLS] token) and the remaining tokens
        are laid out on a square grid. The side length is inferred from the token count;
        the previously hard-coded 14 x 14 grid is the 196-token case.

        Raises:
            ValueError: if the remaining token count is not a perfect square.
        """
        patch_tokens = x[:, 1:, :]
        batch, num_tokens, channels = patch_tokens.shape
        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            raise ValueError(f"cannot arrange {num_tokens} patch tokens on a square grid")
        grid = patch_tokens.reshape(batch, side, side, channels)
        # (batch, h, w, channels) -> (batch, channels, h, w)
        return grid.permute(0, 3, 1, 2)



if __name__ == "__main__":
    import sys

    # Optional first CLI argument selects the input image; the original file is the default.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "Feature-Image-74.jpg"

    faceGrabber = FaceGrabber()
    gradCam = GradCam()

    image = Image.open(image_path).convert("RGB")
    # Use the face crop when grab_faces returns one; otherwise keep the whole image.
    face = faceGrabber.grab_faces(np.array(image))
    if face is not None:
        image = Image.fromarray(face)

    model = ViTForImageClassification.from_pretrained("ongkn/attraction-classifier")
    targets_for_gradcam = [
        ClassifierOutputTarget(gradCam.category_name_to_index(model, "pos")),
        ClassifierOutputTarget(gradCam.category_name_to_index(model, "neg")),
    ]
    target_layer_dff = model.vit.layernorm
    target_layer_gradcam = model.vit.encoder.layer[-2].output
    image_resized = image.resize((224, 224))
    # NOTE(review): only ToTensor() is applied here (no mean/std normalization);
    # confirm this matches the preprocessing the classifier was trained with.
    tensor_resized = transforms.ToTensor()(image_resized)

    # Report the prediction before the blocking display windows open.
    for prediction in gradCam.get_top_category(model, tensor_resized):
        logging.info("Top category: %s (score %.4f)", prediction["label"], prediction["score"])

    dff_image = run_dff_on_image(model=model,
                                 target_layer=target_layer_dff,
                                 classifier=model.classifier,
                                 img_pil=image_resized,
                                 img_tensor=tensor_resized,
                                 reshape_transform=gradCam.reshape_transform_vit_huggingface,
                                 n_components=5,
                                 top_k=10,
                                 threshold=0,
                                 output_size=None)
    cv.namedWindow("DFF Image", cv.WINDOW_KEEPRATIO)
    # cv.imshow expects BGR; the DFF strip is presumably RGB, so swap the channels.
    cv.imshow("DFF Image", cv.cvtColor(dff_image, cv.COLOR_RGB2BGR))
    cv.resizeWindow("DFF Image", 2500, 700)

    grad_cam_image = gradCam.run_grad_cam_on_image(model=model,
                                                   target_layer=target_layer_gradcam,
                                                   targets_for_gradcam=targets_for_gradcam,
                                                   input_tensor=tensor_resized,
                                                   input_image=image_resized,
                                                   reshape_transform=gradCam.reshape_transform_vit_huggingface,
                                                   threshold=0)
    cv.namedWindow("Grad-CAM Image", cv.WINDOW_KEEPRATIO)
    # The overlays are built with use_rgb=True, so convert to BGR for cv.imshow.
    cv.imshow("Grad-CAM Image", cv.cvtColor(grad_cam_image, cv.COLOR_RGB2BGR))
    cv.resizeWindow("Grad-CAM Image", 2000, 1250)
    cv.waitKey(0)
    cv.destroyAllWindows()