File size: 7,592 Bytes
fcc16aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import PIL
from PIL import Image
import numpy as np
from matplotlib import pylab as P
import cv2

import torch
from torch.utils.data import TensorDataset
from torchvision import transforms

# dirpath_to_modules = './Visual-Explanation-Methods-PyTorch'
# sys.path.append(dirpath_to_modules)

from torchvex.base import ExplanationMethod
from torchvex.utils.normalization import clamp_quantile

def ShowImage(im, title='', ax=None):
    """Return a NumPy copy of ``im``.

    Nothing is drawn despite the name: ``title`` and ``ax`` are accepted only
    so the signature matches the other ``Show*`` helpers, and are ignored.
    """
    array_copy = np.array(im)
    return array_copy

def ShowGrayscaleImage(im, title='', ax=None):
    """Draw ``im`` in grayscale via pyplot and return the pyplot module.

    A new figure is opened only when ``ax`` is None. ``ax`` itself is never
    drawn on: all calls target pyplot's *current* axes (presumably the caller
    made ``ax`` current beforehand, e.g. with ``P.subplot`` - TODO confirm).
    Pixel values are shown on a fixed ``[0, 1]`` scale with the axes hidden.
    """
    gray_cmap = P.cm.gray
    if ax is None:
        # No target axes supplied: start a fresh figure to draw into.
        P.figure()
    P.axis('off')
    P.imshow(im, cmap=gray_cmap, vmin=0, vmax=1)
    P.title(title)
    return P

def ShowHeatMap(im, title='', ax=None, size=(224, 224)):
    """Colourise a saliency map with OpenCV's HOT colormap.

    Args:
        im: Saliency map as a NumPy array, e.g. ``(H, W)`` or ``(H, W, 1)``.
        title: Unused; kept for signature compatibility with the other
            ``Show*`` helpers.
        ax: Unused; kept for signature compatibility.
        size: ``(width, height)`` to resize the map to (``cv2.resize`` order).

    Returns:
        ``uint8`` array ``(height, width, 3)`` produced by
        ``cv2.applyColorMap`` (OpenCV emits colormaps in BGR channel order).
    """
    # Min-max normalise to [0, 1]. A constant map has zero range: leave it
    # all-zero rather than computing 0 / 0, which would yield NaNs that have
    # no meaningful uint8 value.
    im = im - im.min()
    peak = im.max()
    if peak > 0:
        im = im / peak
    im = np.uint8(im.clip(0, 1) * 255)

    # One resize suffices; resizing twice to the same size was redundant.
    image = cv2.resize(im, size)
    return cv2.applyColorMap(image, cv2.COLORMAP_HOT)
    
def ShowMaskedImage(saliency_map, image, title='', ax=None, size=(224, 224)):
    """Blend a colourised saliency map on top of an image.

    Args:
        saliency_map: NumPy array, e.g. ``(H, W)`` or ``(H, W, 1)``. It is
            min-max normalised to ``uint8`` and colourised with OpenCV's HOT
            colormap.
        image: ``uint8`` NumPy array ``(H, W, 3)``. Grayscale ``(H, W)``
            input is replicated to three channels and an ``(H, W, 4)`` input
            has its fourth (alpha) channel dropped, so the blend operands match.
        title: Unused; kept for signature compatibility.
        ax: Unused; kept for signature compatibility.
        size: ``(width, height)`` both inputs are resized to before blending.

    Returns:
        ``uint8`` array ``(height, width, 3)`` equal to
        ``0.4 * image + 0.6 * heatmap`` (via ``cv2.addWeighted``).
    """
    # Min-max normalise to [0, 1]; a constant map stays all-zero instead of
    # producing NaNs from 0 / 0.
    saliency_map = saliency_map - saliency_map.min()
    peak = saliency_map.max()
    if peak > 0:
        saliency_map = saliency_map / peak
    saliency_map = np.uint8(saliency_map.clip(0, 1) * 255)
    saliency_map = cv2.resize(saliency_map, size)

    image = cv2.resize(image, size)
    # cv2.addWeighted requires identically shaped operands, and the heat map
    # always has 3 channels.
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 4:
        image = np.ascontiguousarray(image[..., :3])

    # NOTE(review): applyColorMap returns BGR, while `image` is presumably RGB
    # when it comes from np.asarray(PIL image); the blend therefore mixes
    # channel orders. Confirm what the consumer of this array expects.
    color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_HOT)

    return cv2.addWeighted(image, 0.4, color_heatmap, 0.6, 0)

def LoadImage(file_path, size=(224, 224)):
    """Load an image file and return it resized as a NumPy array.

    Args:
        file_path: Path or file object accepted by ``PIL.Image.open``.
        size: Target ``(width, height)`` passed to ``Image.resize``.

    Returns:
        NumPy array of the resized image; its channel count follows the
        file's PIL mode (e.g. ``(H, W, 3)`` uint8 for RGB).
    """
    # The context manager closes the source file handle; resize() returns a
    # new, fully loaded in-memory image, so it stays valid afterwards.
    with PIL.Image.open(file_path) as source:
        resized = source.resize(size)
    return np.asarray(resized)


def visualize_image_grayscale(image_3d, percentile=99):
    r"""Collapse a 3D tensor into a 2D grayscale map in ``[0, 1]``.

    Absolute values are summed over axis 2. The result is rescaled so the
    minimum maps to 0 and the given ``percentile`` maps to 1, then clipped
    to ``[0, 1]``.

    Args:
        image_3d: Array with at least 3 axes; axis 2 is reduced.
        percentile: Percentile (0-100) of the summed map used as the upper
            bound of the scale.

    Returns:
        Float array with axis 2 removed. If the percentile equals the
        minimum (e.g. a constant input), the scale has zero width; entries
        equal to the minimum become 0 and larger entries 1, instead of NaN.
    """
    image_2d = np.sum(np.abs(image_3d), axis=2)

    vmax = np.percentile(image_2d, percentile)
    vmin = np.min(image_2d)
    span = vmax - vmin

    if span == 0:
        # Limit of the linear scaling as the span shrinks to zero; avoids 0/0.
        return np.where(image_2d > vmin, 1.0, 0.0)

    return np.clip((image_2d - vmin) / span, 0, 1)

def visualize_image_diverging(image_3d, percentile=99):
    r"""Collapse a 3D tensor into a signed 2D map centred on 0.5.

    Values are summed over axis 2. With ``span = |percentile of the sum|``,
    the range ``[-span, span]`` is mapped linearly onto ``[0, 1]`` (so 0 maps
    to 0.5), and the result is clipped to ``[-1, 1]``.

    Args:
        image_3d: Array with at least 3 axes; axis 2 is reduced.
        percentile: Percentile (0-100) of the summed map whose magnitude
            sets the span.

    Returns:
        Float array with axis 2 removed. When ``span`` is 0, the limit of the
        linear map is returned instead of NaN: -1 for negative entries, 1 for
        positive entries, 0.5 for zeros.
    """
    image_2d = np.sum(image_3d, axis=2)

    span = abs(np.percentile(image_2d, percentile))
    if span == 0:
        return np.where(image_2d == 0, 0.5, np.sign(image_2d))

    vmin = -span
    vmax = span
    # NOTE(review): the clip bounds are [-1, 1] although the mapping targets
    # [0, 1]; values below -span therefore come out negative. Kept as-is.
    return np.clip((image_2d - vmin) / (vmax - vmin), -1, 1)


class SimpleGradient(ExplanationMethod):
    """Vanilla-gradient saliency: gradient of the target logit w.r.t. the input.

    Args:
        model: Callable returning either a logits tensor whose last dimension
            indexes classes, or an object exposing such a tensor as ``.logits``.
        create_graph: Forwarded to ``torch.autograd.grad``; when True the
            returned gradient is itself differentiable.
        preprocess: Forwarded to ``ExplanationMethod``.
        postprocess: Forwarded to ``ExplanationMethod``.
    """

    def __init__(self, model, create_graph=False,
                 preprocess=None, postprocess=None):
        super().__init__(model, preprocess, postprocess)
        self.create_graph = create_graph

    def predict(self, x):
        """Return the raw model output for ``x``."""
        return self.model(x)

    @torch.enable_grad()
    def process(self, inputs, target):
        """Return the gradient of the selected logits with respect to ``inputs``.

        Args:
            inputs: Input batch. Its ``requires_grad`` flag is switched on in
                place so autograd can differentiate with respect to it.
            target: Integer class indices whose first dimension is the batch.

        Returns:
            Gradient tensor with the same shape as ``inputs``.
        """
        self.model.zero_grad()
        inputs.requires_grad_(True)

        out = self.model(inputs)
        # Hugging Face-style models wrap the logits in an output object;
        # isinstance also accepts Tensor subclasses.
        if not isinstance(out, torch.Tensor):
            out = out.logits

        num_classes = out.size(-1)
        onehot = torch.zeros(
            inputs.size(0), num_classes, *target.shape[1:],
            dtype=inputs.dtype, device=inputs.device,
        )
        # scatter_ needs an int64 index on the same device as `onehot`; the
        # targets may arrive on the CPU even when the model runs elsewhere.
        index = target.to(device=onehot.device, dtype=torch.long).unsqueeze(1)
        onehot.scatter_(1, index, 1)

        # Masking and summing yields every sample's own target-logit gradient
        # in a single autograd call, assuming samples do not interact in the
        # forward pass (e.g. no train-mode BatchNorm) - TODO confirm per model.
        grad, = torch.autograd.grad(
            (out * onehot).sum(), inputs, create_graph=self.create_graph
        )

        return grad


class SmoothGradient(ExplanationMethod):
    """SmoothGrad: average input gradients over Gaussian-noised copies.

    Args:
        model: Model wrapped by an internal ``SimpleGradient``.
        stdev_spread: Noise standard deviation as a fraction of each sample's
            value range (max - min over all its elements).
        num_samples: Number of noisy copies drawn per input sample.
        magnitude: If True, average squared gradients instead of raw ones.
        batch_size: Noisy copies pushed through the model per step; ``-1``
            means all ``num_samples`` at once.
        create_graph: Forwarded to the internal ``SimpleGradient``.
        preprocess: Forwarded to ``ExplanationMethod``.
        postprocess: Forwarded to ``ExplanationMethod``.
    """

    def __init__(self, model, stdev_spread=0.15, num_samples=25,
                 magnitude=True, batch_size=-1,
                 create_graph=False, preprocess=None, postprocess=None):
        super().__init__(model, preprocess, postprocess)
        self.stdev_spread = stdev_spread
        self.nsample = num_samples
        self.create_graph = create_graph
        self.magnitude = magnitude
        self.batch_size = batch_size
        if self.batch_size == -1:
            self.batch_size = self.nsample

        self._simgrad = SimpleGradient(model, create_graph)

    def process(self, inputs, target):
        """Return the SmoothGrad map for ``inputs`` (same shape as ``inputs``).

        Args:
            inputs: Batch tensor whose first dimension is the batch; any
                number of trailing dimensions is supported.
            target: Per-sample class indices, shape ``(batch,)``.
        """
        self.model.zero_grad()

        flat = inputs.flatten(1)
        value_range = (flat.max(-1)[0] - flat.min(-1)[0]).cpu()

        # Broadcast each sample's std over its elements, then over the draws.
        per_sample_shape = (inputs.size(0),) + (1,) * (inputs.dim() - 1)
        stdev = (self.stdev_spread * value_range).view(per_sample_shape)
        stdev = stdev.expand_as(inputs)
        stdev = stdev.unsqueeze(0).expand(self.nsample, *[-1] * inputs.dim())
        noise = torch.normal(0, stdev)

        # Every noise draw is paired with the unchanged targets of the batch.
        target_expanded = target.unsqueeze(0).cpu()
        target_expanded = target_expanded.expand(noise.size(0), -1)

        noiseloader = torch.utils.data.DataLoader(
            TensorDataset(noise, target_expanded), batch_size=self.batch_size
        )

        total_gradients = torch.zeros_like(inputs)
        for noise_chunk, target_chunk in noiseloader:
            inputs_w_noise = inputs.unsqueeze(0) + noise_chunk.to(inputs.device)
            inputs_w_noise = inputs_w_noise.view(-1, *inputs.shape[1:])
            gradients = self._simgrad(inputs_w_noise, target_chunk.view(-1))
            # The final chunk holds fewer than batch_size draws when
            # batch_size does not divide num_samples, so infer the leading dim.
            gradients = gradients.view(-1, *inputs.shape)
            if self.magnitude:
                gradients = gradients.pow(2)
            total_gradients = total_gradients + gradients.sum(0)

        smoothed_gradient = total_gradients / self.nsample
        return smoothed_gradient


def feed_forward(model_name, image, model=None, feature_extractor=None):
    """Preprocess ``image`` and run one inference pass to get the top class.

    Args:
        model_name: ``'ConvNeXt'`` or ``'ResNet'`` selects the Hugging Face
            path (``feature_extractor`` then ``model(**inputs).logits``). Any
            other value uses torchvision preprocessing and ``model(tensor)``.
        image: Input image; presumably a PIL image, since both the
            torchvision transforms and HF feature extractors accept one -
            TODO confirm.
        model: Classifier to run.
        feature_extractor: Hugging Face feature extractor; used only on the
            Hugging Face path.

    Returns:
        ``(inputs, prediction_class)``. ``inputs`` is the feature extractor's
        output on the Hugging Face path, otherwise a ``(1, C, H, W)`` tensor
        (``Resize(224)`` scales the shorter side to 224). ``prediction_class``
        is the argmax class index as a Python int.
    """
    # Only the prediction is needed, so no autograd graph is built here.
    with torch.no_grad():
        if model_name in ['ConvNeXt', 'ResNet']:
            inputs = feature_extractor(image, return_tensors="pt")
            logits = model(**inputs).logits
        else:
            transform_images = transforms.Compose([
                transforms.Resize(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
            inputs = transform_images(image).unsqueeze(0)
            output = model(inputs)
            # Also accept models that wrap the logits in an output object.
            logits = output if isinstance(output, torch.Tensor) else output.logits
    prediction_class = logits.argmax(-1).item()
    return inputs, prediction_class

def clip_gradient(gradient):
    """Reduce a gradient to a single-channel magnitude map and clamp it.

    ``|gradient|`` is summed over dim 1 (kept as a size-1 dim), and the
    result is passed to ``clamp_quantile`` with ``q=0.99``.
    """
    channel_magnitude = gradient.abs().sum(dim=1, keepdim=True)
    return clamp_quantile(channel_magnitude, q=0.99)

def fig2img(fig):
    """Render a Matplotlib figure into an in-memory PIL Image.

    The figure is written with ``fig.savefig`` into a ``BytesIO`` buffer,
    which is rewound and handed to ``Image.open``. The buffer is left open
    because the returned image reads from it.
    """
    from io import BytesIO

    stream = BytesIO()
    fig.savefig(stream)
    stream.seek(0)
    return Image.open(stream)

def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None, num_samples=25, return_mask=False):
    """Compute a SmoothGrad saliency visualisation for the model's top class.

    Args:
        image: Input image. It is passed to ``feed_forward``, then converted
            with ``np.asarray`` for the overlay (presumably a PIL image -
            TODO confirm).
        model_name: Selects the preprocessing path in ``feed_forward``.
        model: Classifier to explain.
        feature_extractor: Hugging Face feature extractor (HF path only).
        num_samples: Number of noisy samples averaged by SmoothGrad.
        return_mask: If True, also return the raw channels-last mask array.

    Returns:
        ``(heat_map_image, masked_image)``, or
        ``(heat_map_image, masked_image, smoothgrad_mask)`` when
        ``return_mask`` is True.
    """
    inputs, prediction_class = feed_forward(model_name, image, model, feature_extractor)

    smoothgrad_gen = SmoothGradient(
        model, num_samples=num_samples, stdev_spread=0.1,
        magnitude=False, postprocess=clip_gradient)

    # Hugging Face feature extractors return a mapping; the model input
    # tensor is stored under 'pixel_values'.
    if not isinstance(inputs, torch.Tensor):
        inputs = inputs['pixel_values']

    smoothgrad_mask = smoothgrad_gen(inputs, prediction_class)
    # First sample of the batch. detach()/cpu() let .numpy() work even if the
    # result carries autograd history or lives on an accelerator.
    smoothgrad_mask = smoothgrad_mask[0].detach().cpu().numpy()
    smoothgrad_mask = np.transpose(smoothgrad_mask, (1, 2, 0))  # CHW -> HWC

    image = np.asarray(image)
    heat_map_image = ShowHeatMap(smoothgrad_mask)
    masked_image = ShowMaskedImage(smoothgrad_mask, image)

    if return_mask:
        return heat_map_image, masked_image, smoothgrad_mask
    return heat_map_image, masked_image