import os
import sys

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import pylab
import torch
import numpy as np
import cv2

sys.path.append('ViT_DeiT')
from samples.CLS2IDX import CLS2IDX
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from baselines.ViT.ViT_explanation_generator import LRP
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

use_thresholding = False
def show_cam_on_image(img, mask):
    # Blend a JET-colormapped heatmap of `mask` onto `img`. Both inputs are
    # expected as float arrays in [0, 1]; the blend is renormalized to [0, 1].
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam
# initialize the pretrained ViT and its LRP attribution generator
model = vit_LRP(pretrained=True)
model.eval()
attribution_generator = LRP(model)
def generate_visualization(original_image, class_index=None):
    # Relevance over the 14x14 grid of ViT patches; class_index=None explains
    # the model's top predicted class.
    transformer_attribution = attribution_generator.generate_LRP(
        original_image.unsqueeze(0), method="transformer_attribution",
        index=class_index).detach()
    transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
    # Upsample 14x14 -> 224x224 to match the input resolution.
    transformer_attribution = torch.nn.functional.interpolate(
        transformer_attribution, scale_factor=16, mode='bilinear')
    transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
    # Min-max normalize the relevance map to [0, 1].
    transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (
        transformer_attribution.max() - transformer_attribution.min())
    if use_thresholding:
        # Optionally binarize the relevance map with Otsu's threshold.
        transformer_attribution = transformer_attribution * 255
        transformer_attribution = transformer_attribution.astype(np.uint8)
        ret, transformer_attribution = cv2.threshold(
            transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        transformer_attribution[transformer_attribution == 255] = 1
    # Min-max normalize the input image as well before overlaying the heatmap.
    image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
    image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (
        image_transformer_attribution.max() - image_transformer_attribution.min())
    vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis
def print_top_classes(original_image, **kwargs):
    # Return the top-5 predictions as formatted strings.
    predictions = model(original_image.unsqueeze(0))
    prob = torch.softmax(predictions, dim=1)
    class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()
    max_str_len = max(len(CLS2IDX[cls_idx]) for cls_idx in class_indices)
    output = []
    for cls_idx in class_indices:
        output_string = '{} : {}'.format(cls_idx, CLS2IDX[cls_idx])
        # pad class names so the value/prob columns line up
        output_string += ' ' * (max_str_len - len(CLS2IDX[cls_idx]))
        output_string += ' value = {:.3f} prob = {:.1f}%'.format(
            predictions[0, cls_idx], 100 * prob[0, cls_idx])
        output.append(output_string)
    return output
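
# Example usage: a minimal sketch, not part of the original file. The image
# path 'samples/catdog.png' is an assumption; substitute any RGB image.
if __name__ == '__main__':
    image = Image.open('samples/catdog.png').convert('RGB')  # hypothetical path
    image_tensor = transform(image)
    # top-5 ImageNet predictions for the image
    for line in print_top_classes(image_tensor):
        print(line)
    # relevance heatmap for the top predicted class (pass class_index to
    # explain a specific ImageNet class instead)
    vis = generate_visualization(image_tensor)
    plt.imshow(vis)
    plt.axis('off')
    plt.show()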