# explain-ViT/visualization.py
import sys

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
import cv2

# The explanation utilities are bundled under the ViT_DeiT directory.
sys.path.append('ViT_DeiT')
from samples.CLS2IDX import CLS2IDX
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_explanation_generator import LRP
# ImageNet-style preprocessing: resize, center-crop to the 224x224 ViT input
# size, and normalize each channel to roughly [-1, 1].
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
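
# A minimal usage sketch for the pipeline above, assuming an image file at a
# placeholder path (any RGB image works; the path is not part of this module):
#
#   image = Image.open('samples/catdog.png').convert('RGB')
#   image_tensor = transform(image)  # float tensor of shape (3, 224, 224)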
# If True, binarize the relevance map with Otsu thresholding before overlaying.
use_thresholding = False
def show_cam_on_image(img, mask):
    # Overlay a [0, 1] relevance mask on a [0, 1] RGB image as a JET heatmap.
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)  # renormalize the blend back to [0, 1]
    return cam
# Initialize the pretrained ViT-Base/16 (LRP-enabled variant) and wrap it in
# the LRP explanation generator.
model = vit_LRP(pretrained=True)
model.eval()
attribution_generator = LRP(model)
def generate_visualization(original_image, class_index=None):
    # Relevance map from the transformer-attribution (LRP) method; when
    # class_index is None, the model's predicted class is explained.
    transformer_attribution = attribution_generator.generate_LRP(
        original_image.unsqueeze(0), method="transformer_attribution", index=class_index
    ).detach()
    # 14x14 patch grid -> upsample by the 16-pixel patch size to 224x224.
    transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
    transformer_attribution = torch.nn.functional.interpolate(
        transformer_attribution, scale_factor=16, mode='bilinear'
    )
    transformer_attribution = transformer_attribution.reshape(224, 224).cpu().numpy()
    # Min-max normalize the relevance map to [0, 1].
    transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (
        transformer_attribution.max() - transformer_attribution.min()
    )
    if use_thresholding:
        # Binarize with Otsu's method: relevant pixels become 1, the rest 0.
        transformer_attribution = (transformer_attribution * 255).astype(np.uint8)
        _, transformer_attribution = cv2.threshold(
            transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )
        transformer_attribution[transformer_attribution == 255] = 1
    # Rescale the (normalized) input image back to [0, 1] for overlaying.
    image_transformer_attribution = original_image.permute(1, 2, 0).cpu().numpy()
    image_transformer_attribution = (
        image_transformer_attribution - image_transformer_attribution.min()
    ) / (image_transformer_attribution.max() - image_transformer_attribution.min())
    vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
    return vis
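
# Note (an interpretation of the cv2 calls above): cv2.applyColorMap emits BGR,
# and the final channel swap means the returned overlay is intended for direct
# viewing, e.g. with matplotlib. A sketch, reusing image_tensor from the
# preprocessing example (class id 243 is only an illustrative ImageNet index):
#
#   vis = generate_visualization(image_tensor)  # explain the predicted class
#   vis_cls = generate_visualization(image_tensor, class_index=243)
#   plt.imshow(vis); plt.axis('off'); plt.show()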
def print_top_classes(original_image, **kwargs):
    # Run the model and return its top-5 predictions as formatted strings
    # (despite the name, this returns the lines; the caller prints them).
    predictions = model(original_image.unsqueeze(0))
    prob = torch.softmax(predictions, dim=1)
    class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
    output = []
    for cls_idx in class_indices:
        output_string = '{} : {}'.format(cls_idx, CLS2IDX[cls_idx])
        output_string += ' value = {:.3f} prob = {:.1f}%'.format(
            predictions[0, cls_idx], 100 * prob[0, cls_idx]
        )
        output.append(output_string)
    return output
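
# An end-to-end sketch, guarded so it only runs when this file is executed
# directly. The image path is a placeholder assumption, not a file this
# module guarantees to exist.
if __name__ == '__main__':
    image = Image.open('samples/catdog.png').convert('RGB')
    image_tensor = transform(image)

    # Show the top-5 predictions, then the LRP heatmap for the predicted class.
    for line in print_top_classes(image_tensor):
        print(line)

    vis = generate_visualization(image_tensor)
    plt.imshow(vis)
    plt.axis('off')
    plt.show()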