import sys import os import PIL from PIL import Image import numpy as np from matplotlib import pylab as P import cv2 import torch from torch.utils.data import TensorDataset from torchvision import transforms dirpath_to_modules = './Visual-Explanation-Methods-PyTorch' sys.path.append(dirpath_to_modules) from torchvex.base import ExplanationMethod from torchvex.utils.normalization import clamp_quantile def ShowImage(im, title='', ax=None): image = np.array(im) return image def ShowGrayscaleImage(im, title='', ax=None): if ax is None: P.figure() P.axis('off') P.imshow(im, cmap=P.cm.gray, vmin=0, vmax=1) P.title(title) return P def ShowHeatMap(im, title='', ax=None): im = im - im.min() im = im / im.max() im = im.clip(0,1) im = np.uint8(im * 255) im = cv2.resize(im, (224,224)) image = cv2.resize(im, (224, 224)) # Apply JET colormap color_heatmap = cv2.applyColorMap(image, cv2.COLORMAP_HOT) # P.imshow(im, cmap='inferno') # P.title(title) return color_heatmap def ShowMaskedImage(saliency_map, image, title='', ax=None): """ Save saliency map on image. Args: image: Tensor of size (H,W,3) saliency_map: Tensor of size (H,W,1) """ # if ax is None: # P.figure() # P.axis('off') saliency_map = saliency_map - saliency_map.min() saliency_map = saliency_map / saliency_map.max() saliency_map = saliency_map.clip(0,1) saliency_map = np.uint8(saliency_map * 255) saliency_map = cv2.resize(saliency_map, (224,224)) image = cv2.resize(image, (224, 224)) # Apply JET colormap color_heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_HOT) # Blend image with heatmap img_with_heatmap = cv2.addWeighted(image, 0.4, color_heatmap, 0.6, 0) # P.imshow(img_with_heatmap) # P.title(title) return img_with_heatmap def LoadImage(file_path): im = PIL.Image.open(file_path) im = im.resize((224, 224)) im = np.asarray(im) return im def visualize_image_grayscale(image_3d, percentile=99): r"""Returns a 3D tensor as a grayscale 2D tensor. This method sums a 3D tensor across the absolute value of axis=2, and then clips values at a given percentile. """ image_2d = np.sum(np.abs(image_3d), axis=2) vmax = np.percentile(image_2d, percentile) vmin = np.min(image_2d) return np.clip((image_2d - vmin) / (vmax - vmin), 0, 1) def visualize_image_diverging(image_3d, percentile=99): r"""Returns a 3D tensor as a 2D tensor with positive and negative values. """ image_2d = np.sum(image_3d, axis=2) span = abs(np.percentile(image_2d, percentile)) vmin = -span vmax = span return np.clip((image_2d - vmin) / (vmax - vmin), -1, 1) class SimpleGradient(ExplanationMethod): def __init__(self, model, create_graph=False, preprocess=None, postprocess=None): super().__init__(model, preprocess, postprocess) self.create_graph = create_graph def predict(self, x): return self.model(x) @torch.enable_grad() def process(self, inputs, target): self.model.zero_grad() inputs.requires_grad_(True) out = self.model(inputs) out = out if type(out) == torch.Tensor else out.logits num_classes = out.size(-1) onehot = torch.zeros(inputs.size(0), num_classes, *target.shape[1:]) onehot = onehot.to(dtype=inputs.dtype, device=inputs.device) onehot.scatter_(1, target.unsqueeze(1), 1) grad, = torch.autograd.grad( (out*onehot).sum(), inputs, create_graph=self.create_graph ) return grad class SmoothGradient(ExplanationMethod): def __init__(self, model, stdev_spread=0.15, num_samples=25, magnitude=True, batch_size=-1, create_graph=False, preprocess=None, postprocess=None): super().__init__(model, preprocess, postprocess) self.stdev_spread = stdev_spread self.nsample = num_samples self.create_graph = create_graph self.magnitude = magnitude self.batch_size = batch_size if self.batch_size == -1: self.batch_size = self.nsample self._simgrad = SimpleGradient(model, create_graph) def process(self, inputs, target): self.model.zero_grad() maxima = inputs.flatten(1).max(-1)[0] minima = inputs.flatten(1).min(-1)[0] stdev = self.stdev_spread * (maxima - minima).cpu() stdev = stdev.view(inputs.size(0), 1, 1, 1).expand_as(inputs) stdev = stdev.unsqueeze(0).expand(self.nsample, *[-1]*4) noise = torch.normal(0, stdev) target_expanded = target.unsqueeze(0).cpu() target_expanded = target_expanded.expand(noise.size(0), -1) noiseloader = torch.utils.data.DataLoader( TensorDataset(noise, target_expanded), batch_size=self.batch_size ) total_gradients = torch.zeros_like(inputs) for noise, t_exp in noiseloader: inputs_w_noise = inputs.unsqueeze(0) + noise.to(inputs.device) inputs_w_noise = inputs_w_noise.view(-1, *inputs.shape[1:]) gradients = self._simgrad(inputs_w_noise, t_exp.view(-1)) gradients = gradients.view(self.batch_size, *inputs.shape) if self.magnitude: gradients = gradients.pow(2) total_gradients = total_gradients + gradients.sum(0) smoothed_gradient = total_gradients / self.nsample return smoothed_gradient def feed_forward(model_name, image, model=None, feature_extractor=None): if model_name in ['ConvNeXt', 'ResNet']: inputs = feature_extractor(image, return_tensors="pt") logits = model(**inputs).logits prediction_class = logits.argmax(-1).item() else: transform_images = transforms.Compose([ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) input_tensor = transform_images(image) inputs = input_tensor.unsqueeze(0) output = model(inputs) prediction_class = output.argmax(-1).item() #prediction_label = model.config.id2label[prediction_class] return inputs, prediction_class def clip_gradient(gradient): gradient = gradient.abs().sum(1, keepdim=True) return clamp_quantile(gradient, q=0.99) def fig2img(fig): """Convert a Matplotlib figure to a PIL Image and return it""" import io buf = io.BytesIO() fig.savefig(buf) buf.seek(0) img = Image.open(buf) return img def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None, num_samples=25): inputs, prediction_class = feed_forward(model_name, image, model, feature_extractor) smoothgrad_gen = SmoothGradient( model, num_samples=num_samples, stdev_spread=0.1, magnitude=False, postprocess=clip_gradient) if type(inputs) != torch.Tensor: inputs = inputs['pixel_values'] smoothgrad_mask = smoothgrad_gen(inputs, prediction_class) smoothgrad_mask = smoothgrad_mask[0].numpy() smoothgrad_mask = np.transpose(smoothgrad_mask, (1, 2, 0)) image = np.asarray(image) # ori_image = ShowImage(image) heat_map_image = ShowHeatMap(smoothgrad_mask) masked_image = ShowMaskedImage(smoothgrad_mask, image) return heat_map_image, masked_image