Spaces:
Build error
Build error
import PIL | |
from PIL import Image | |
import numpy as np | |
from matplotlib import pylab as P | |
import cv2 | |
import torch | |
from torch.utils.data import TensorDataset | |
from torchvision import transforms | |
# dirpath_to_modules = './Visual-Explanation-Methods-PyTorch' | |
# sys.path.append(dirpath_to_modules) | |
from torchvex.base import ExplanationMethod | |
from torchvex.utils.normalization import clamp_quantile | |
def ShowImage(im, title='', ax=None):
    """Return ``im`` as a freshly allocated NumPy array.

    ``title`` and ``ax`` are accepted only so the signature matches the other
    ``Show*`` helpers in this module; they are ignored.
    """
    as_array = np.array(im)
    return as_array
def ShowGrayscaleImage(im, title='', ax=None):
    """Draw ``im`` in grayscale with pylab and return the pylab module.

    Values are displayed on a fixed [0, 1] scale with axes hidden. A new
    figure is created only when ``ax`` is None. When ``ax`` is given it is
    not drawn on directly: drawing still goes through pylab's current axes.
    """
    plt = P
    if ax is None:
        plt.figure()
    plt.axis('off')
    gray_cmap = plt.cm.gray
    plt.imshow(im, cmap=gray_cmap, vmin=0, vmax=1)
    plt.title(title)
    return plt
def ShowHeatMap(im, title='', ax=None, size=(224, 224)):
    """Render a saliency map as a colorized heat-map image.

    The map is shifted so its minimum is 0 and scaled so its maximum is 1.
    It is then converted to uint8, resized to ``size`` and colorized with
    OpenCV's HOT colormap.

    Args:
        im: NumPy saliency map, 2-D or with a single trailing channel.
        title: unused; kept for signature compatibility.
        ax: unused; kept for signature compatibility.
        size: ``(width, height)`` passed to ``cv2.resize``; defaults to 224x224.

    Returns:
        uint8 colorized image from ``cv2.applyColorMap``.
    """
    heat = im - im.min()
    peak = heat.max()
    # A constant map would otherwise divide 0 by 0 and yield NaNs, whose
    # uint8 cast is undefined. Leave it all-zero instead.
    if peak > 0:
        heat = heat / peak
    heat = np.uint8(heat.clip(0, 1) * 255)
    heat = cv2.resize(heat, size)
    # NOTE(review): OpenCV colormaps are produced in BGR channel order. If
    # the caller displays this as RGB, the colors will appear channel-swapped.
    return cv2.applyColorMap(heat, cv2.COLORMAP_HOT)
def ShowMaskedImage(saliency_map, image, title='', ax=None, size=(224, 224)):
    """Blend a colorized saliency map over an image and return the result.

    Args:
        saliency_map: NumPy saliency map, 2-D or with a single trailing
            channel. It is min-max normalised to [0, 1] before colorizing.
        image: NumPy image, expected to be uint8 so it can be blended with
            the uint8 colormap output. A 2-D (grayscale) image is replicated
            to 3 channels, and a 4-channel image has its last channel dropped.
        title: unused; kept for signature compatibility.
        ax: unused; kept for signature compatibility.
        size: ``(width, height)`` that both inputs are resized to before blending.

    Returns:
        The blended image ``0.4 * image + 0.6 * heat map``, as produced by
        ``cv2.addWeighted``.
    """
    heat = saliency_map - saliency_map.min()
    peak = heat.max()
    # Guard against 0/0 (NaN) for a constant saliency map.
    if peak > 0:
        heat = heat / peak
    heat = np.uint8(heat.clip(0, 1) * 255)
    heat = cv2.resize(heat, size)

    base = np.asarray(image)
    # addWeighted needs both operands to have the same channel count as the
    # 3-channel colormap output.
    if base.ndim == 2:
        base = np.stack([base] * 3, axis=-1)
    elif base.shape[-1] == 4:
        base = np.ascontiguousarray(base[..., :3])
    base = cv2.resize(base, size)

    # NOTE(review): applyColorMap returns BGR, while images coming from PIL
    # are RGB. Confirm whether the channel mix in the blend is intended.
    color_heatmap = cv2.applyColorMap(heat, cv2.COLORMAP_HOT)
    return cv2.addWeighted(base, 0.4, color_heatmap, 0.6, 0)
def LoadImage(file_path, size=(224, 224)):
    """Open an image file, resize it, and return its pixels as a NumPy array.

    The file is opened in a ``with`` block, so its handle is released as soon
    as the resized copy has been made. Previously the handle was left open
    until garbage collection.

    Args:
        file_path: anything accepted by ``PIL.Image.open`` (path or file object).
        size: target ``(width, height)``; defaults to 224x224.

    Returns:
        ``np.ndarray`` holding the resized image's pixel data.
    """
    with PIL.Image.open(file_path) as source:
        resized = source.resize(size)
    return np.asarray(resized)
def visualize_image_grayscale(image_3d, percentile=99):
    r"""Returns a 3D tensor as a grayscale 2D tensor.

    Sums ``abs(image_3d)`` over axis 2, then maps ``[min, percentile]`` onto
    ``[0, 1]``. Values above the percentile clip to 1.

    When the percentile equals the minimum (e.g. an all-zero input), the
    linear mapping would divide by zero. In that case values equal to the
    minimum map to 0 and anything larger maps to 1. This is the limit of the
    regular mapping and avoids NaN/inf.

    Args:
        image_3d: array with at least 3 axes; axis 2 is reduced.
        percentile: upper clipping percentile passed to ``np.percentile``.

    Returns:
        Array with axis 2 removed, values in [0, 1].
    """
    image_2d = np.sum(np.abs(image_3d), axis=2)
    vmax = np.percentile(image_2d, percentile)
    vmin = np.min(image_2d)
    if vmax <= vmin:
        return (image_2d > vmin).astype(np.float64)
    return np.clip((image_2d - vmin) / (vmax - vmin), 0, 1)
def visualize_image_diverging(image_3d, percentile=99):
    r"""Returns a 3D tensor as a 2D tensor with positive and negative values.

    Sums ``image_3d`` over axis 2 and takes ``span = |percentile|``. It then
    maps ``[-span, span]`` linearly onto ``[0, 1]``, so 0 lands on 0.5. The
    result is clipped to ``[-1, 1]``, so values below ``-span`` produce
    negative outputs (down to -1) while values above ``span`` clip at 1.
    This asymmetric clip range is kept as-is for compatibility.

    When ``span`` is 0, the mapping would divide by zero. Then positive
    values map to 1 and negative values map to -1, matching the regular
    mapping's limit. Zeros map to 0.5 (the centre) instead of NaN.

    Args:
        image_3d: array with at least 3 axes; axis 2 is reduced.
        percentile: percentile used to pick the symmetric span.

    Returns:
        Array with axis 2 removed, values in [-1, 1].
    """
    image_2d = np.sum(image_3d, axis=2)
    span = abs(np.percentile(image_2d, percentile))
    if span == 0:
        return np.where(image_2d > 0, 1.0, np.where(image_2d < 0, -1.0, 0.5))
    vmin = -span
    vmax = span
    return np.clip((image_2d - vmin) / (vmax - vmin), -1, 1)
class SimpleGradient(ExplanationMethod):
    """Vanilla gradient saliency: gradient of the target-class output w.r.t. the inputs."""

    def __init__(self, model, create_graph=False,
                 preprocess=None, postprocess=None):
        super().__init__(model, preprocess, postprocess)
        # Forwarded to torch.autograd.grad; True keeps the graph of the
        # returned gradient.
        self.create_graph = create_graph

    def predict(self, x):
        """Run the wrapped model on ``x`` and return its raw output."""
        return self.model(x)

    def process(self, inputs, target):
        """Return d(sum of target-class outputs) / d(inputs).

        Args:
            inputs: batched input tensor (batch on dim 0). It is switched to
                ``requires_grad=True`` in place.
            target: class-index tensor whose dim 0 matches the batch. Any
                trailing dims (``target.shape[1:]``) are carried into the
                one-hot mask.

        Returns:
            Gradient tensor with the same shape as ``inputs``.
        """
        self.model.zero_grad()
        inputs.requires_grad_(True)
        out = self.model(inputs)
        # Models that return an output object (e.g. HuggingFace) expose
        # ``.logits``. isinstance also accepts Tensor subclasses.
        if not isinstance(out, torch.Tensor):
            out = out.logits
        # The one-hot is scattered along dim 1, so the class count must come
        # from dim 1 of the output. For (N, num_classes) logits this is the
        # same as size(-1).
        num_classes = out.size(1)
        onehot = torch.zeros(
            inputs.size(0), num_classes, *target.shape[1:],
            dtype=inputs.dtype, device=inputs.device,
        )
        # scatter_ requires an int64 index on the same device as `onehot`.
        index = target.to(device=inputs.device, dtype=torch.long)
        onehot.scatter_(1, index.unsqueeze(1), 1)
        grad, = torch.autograd.grad(
            (out * onehot).sum(), inputs, create_graph=self.create_graph
        )
        return grad
class SmoothGradient(ExplanationMethod):
    """SmoothGrad: average of vanilla gradients over noisy copies of the input.

    Args:
        model: model to explain.
        stdev_spread: noise std as a fraction of each sample's (max - min)
            value range.
        num_samples: number of noisy copies averaged.
        magnitude: if True, gradients are squared before averaging.
        batch_size: noisy copies evaluated per forward pass; -1 means all
            ``num_samples`` at once.
        create_graph: forwarded to the inner ``SimpleGradient``.
        preprocess, postprocess: forwarded to ``ExplanationMethod``.
    """

    def __init__(self, model, stdev_spread=0.15, num_samples=25,
                 magnitude=True, batch_size=-1,
                 create_graph=False, preprocess=None, postprocess=None):
        super().__init__(model, preprocess, postprocess)
        self.stdev_spread = stdev_spread
        self.nsample = num_samples
        self.create_graph = create_graph
        self.magnitude = magnitude
        self.batch_size = batch_size
        if self.batch_size == -1:
            self.batch_size = self.nsample
        self._simgrad = SimpleGradient(model, create_graph)

    def process(self, inputs, target):
        """Return the SmoothGrad map for ``inputs`` (same shape as ``inputs``).

        Args:
            inputs: batched input tensor (batch on dim 0, any rank >= 2).
            target: class-index tensor whose dim 0 matches the batch.
        """
        self.model.zero_grad()
        maxima = inputs.flatten(1).max(-1)[0]
        minima = inputs.flatten(1).min(-1)[0]
        # One noise std per sample, broadcast over all of that sample's values.
        stdev = self.stdev_spread * (maxima - minima).cpu()
        trailing_ones = [1] * (inputs.dim() - 1)
        stdev = stdev.view(inputs.size(0), *trailing_ones).expand_as(inputs)
        stdev = stdev.unsqueeze(0).expand(self.nsample, *inputs.shape)
        noise = torch.normal(0, stdev)

        # Pair every noisy copy with the (unchanged) targets of the batch.
        targets = target.unsqueeze(0).cpu().expand(self.nsample, *target.shape)
        noiseloader = torch.utils.data.DataLoader(
            TensorDataset(noise, targets), batch_size=self.batch_size
        )

        total_gradients = torch.zeros_like(inputs)
        for noise_batch, target_batch in noiseloader:
            noisy_inputs = inputs.unsqueeze(0) + noise_batch.to(inputs.device)
            noisy_inputs = noisy_inputs.view(-1, *inputs.shape[1:])
            gradients = self._simgrad(
                noisy_inputs, target_batch.reshape(-1, *target.shape[1:])
            )
            # Use -1 rather than self.batch_size: the final batch is smaller
            # when num_samples is not a multiple of batch_size.
            gradients = gradients.view(-1, *inputs.shape)
            if self.magnitude:
                gradients = gradients.pow(2)
            total_gradients = total_gradients + gradients.sum(0)
        return total_gradients / self.nsample
def feed_forward(model_name, image, model=None, feature_extractor=None):
    """Preprocess ``image`` for ``model`` and return ``(inputs, predicted class index)``.

    For ``'ConvNeXt'`` and ``'ResNet'``, ``feature_extractor`` builds the model
    inputs (called with ``return_tensors="pt"``) and the model output's
    ``.logits`` are used. Any other model name goes through a torchvision
    pipeline: Resize(224), then ToTensor, then ImageNet mean/std Normalize.
    The result is given a leading batch dimension.

    The forward pass runs under ``torch.no_grad()`` because only the argmax is
    needed here. The returned ``inputs`` are created outside that context, so
    they remain usable for later gradient computation.

    Returns:
        ``(inputs, prediction_class)``. ``inputs`` is whatever the feature
        extractor returned, or a batched tensor. ``prediction_class`` is a
        Python int.
    """
    if model_name in ['ConvNeXt', 'ResNet']:
        inputs = feature_extractor(image, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
    else:
        # NOTE: Resize(224) scales the shorter side to 224; it does not crop
        # to a square.
        transform_images = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        inputs = transform_images(image).unsqueeze(0)
        with torch.no_grad():
            output = model(inputs)
        # Accept models that wrap their logits in an output object as well.
        logits = output if isinstance(output, torch.Tensor) else output.logits
    prediction_class = logits.argmax(-1).item()
    return inputs, prediction_class
def clip_gradient(gradient):
    """Collapse dim 1 into a summed absolute value, then clamp at the 0.99 quantile.

    Dim 1 is kept with size 1. The clamping itself is delegated to
    ``clamp_quantile`` from torchvex.
    """
    channel_energy = torch.sum(torch.abs(gradient), dim=1, keepdim=True)
    return clamp_quantile(channel_energy, q=0.99)
def fig2img(fig):
    """Rasterize a Matplotlib figure and return it as a PIL Image.

    The figure is written with ``fig.savefig`` into an in-memory buffer. The
    buffer is rewound and handed to ``Image.open``. It is deliberately not
    closed here, because the returned image may still read from it.
    """
    from io import BytesIO
    stream = BytesIO()
    fig.savefig(stream)
    stream.seek(0)
    return Image.open(stream)
def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None,
                             num_samples=25, return_mask=False):
    """Compute a SmoothGrad saliency map for ``image`` and render it two ways.

    The model's own top-1 prediction (from ``feed_forward``) is explained.
    The SmoothGrad settings are: noise spread 0.1, signed gradients
    (``magnitude=False``), and ``clip_gradient`` as postprocessing.

    Args:
        image: input image; also converted with ``np.asarray`` for the overlay.
        model_name: selects the preprocessing path in ``feed_forward``.
        model: model to explain.
        feature_extractor: used by ``feed_forward`` for HuggingFace-style models.
        num_samples: number of noisy copies SmoothGrad averages.
        return_mask: if True, also return the raw channel-last mask.

    Returns:
        ``(heat_map_image, masked_image)``, plus ``smoothgrad_mask`` when
        ``return_mask`` is True.
    """
    inputs, prediction_class = feed_forward(model_name, image, model, feature_extractor)
    smoothgrad_gen = SmoothGradient(
        model, num_samples=num_samples, stdev_spread=0.1,
        magnitude=False, postprocess=clip_gradient)
    # Feature extractors return a dict-like batch rather than a bare tensor.
    if not isinstance(inputs, torch.Tensor):
        inputs = inputs['pixel_values']
    smoothgrad_mask = smoothgrad_gen(inputs, prediction_class)
    # detach/cpu so .numpy() also works when the model runs on GPU or the
    # result still carries autograd history.
    smoothgrad_mask = smoothgrad_mask[0].detach().cpu().numpy()
    # Channel-first -> channel-last for the OpenCV-based renderers.
    smoothgrad_mask = np.transpose(smoothgrad_mask, (1, 2, 0))
    image = np.asarray(image)
    heat_map_image = ShowHeatMap(smoothgrad_mask)
    masked_image = ShowMaskedImage(smoothgrad_mask, image)
    if return_mask:
        return heat_map_image, masked_image, smoothgrad_mask
    return heat_map_image, masked_image