|
import sys |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
|
|
from imagenet_class_indices import CLS2IDX |
|
|
|
sys.path.append("Transformer-Explainability") |
|
|
|
|
|
from baselines.ViT.ViT_explanation_generator import LRP, Baselines |
|
from baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP |
|
from baselines.ViT.ViT_new import vit_base_patch16_224 as vit |
|
|
|
|
|
|
|
def show_cam_on_image(img, mask):
    """Blend a JET-coloured rendering of *mask* into *img*.

    Args:
        img: image array; the caller in this file passes an (H, W, C) array
            that has been min-max scaled to [0, 1].
        mask: per-pixel saliency of shape (H, W), expected in [0, 1] so that
            ``255 * mask`` fits into uint8.

    Returns:
        float32 array: image plus heatmap (heatmap scaled to [0, 1]), divided
        by its overall maximum so the brightest value becomes 1.0.

    NOTE(review): ``cv2.applyColorMap`` produces BGR channel order, while the
    image is presumably RGB - confirm the intended colour order.
    """
    # COLORMAP_JET works on an 8-bit input, hence the uint8 conversion.
    colored = cv2.applyColorMap(np.uint8(mask * 255), cv2.COLORMAP_JET)
    overlay = np.float32(img) + np.float32(colored) / 255
    # Peak-normalise the blended result.
    return overlay / overlay.max()
|
|
|
|
|
|
|
# Two pretrained ViT classifiers (vit_base_patch16_224 per their import names)
# are built once, at import time:
#   * the ViT_LRP variant, wrapped by the LRP explanation generator;
#   * the ViT_new variant, wrapped by the Baselines generator, which provides
#     the attention-GradCAM and rollout baselines used below.
# NOTE(review): pretrained=True presumably downloads/loads weights, so importing
# this module may be slow or require network access - confirm.
# nn.Module.eval() returns the module itself, so it is chained here.
model = vit_LRP(pretrained=True).eval()
model_baseline = vit(pretrained=True).eval()

attribution_generator = LRP(model)
baselines_generator = Baselines(model_baseline)
|
|
|
|
|
def _min_max_normalize(arr):
    """Linearly rescale a NumPy array onto [0, 1].

    A constant array (max == min) would otherwise divide by zero and produce
    NaNs, which do not convert meaningfully to uint8 later on; such an array
    is mapped to all zeros instead.
    """
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        return np.zeros_like(arr)
    return (arr - lo) / span


def _compute_attribution(original_image, class_index, method, use_lrp):
    """Run the selected explanation method on one image; return a detached tensor."""
    batch = original_image.unsqueeze(0)  # generators expect a batch dimension
    if use_lrp:
        result = attribution_generator.generate_LRP(
            batch, method=method, index=class_index
        )
    elif method == "gradcam":
        result = baselines_generator.generate_cam_attn(batch, index=class_index)
    else:
        # Any other non-LRP method falls back to attention rollout.
        result = baselines_generator.generate_rollout(batch)
    return result.detach()


def generate_visualization(
    original_image, class_index=None, method="transformer_attribution", LRP=True
):
    """Render an explanation heatmap blended over *original_image*.

    Args:
        original_image: image tensor laid out as (C, H, W) (it is batched with
            ``unsqueeze(0)`` and displayed via ``permute(1, 2, 0)``).
            Presumably a normalised RGB tensor - confirm against callers.
        class_index: forwarded as ``index`` to the explanation generator
            (presumably None means "use the predicted class" - confirm).
        method: with ``LRP=True`` it is forwarded to ``generate_LRP``;
            with ``LRP=False``, ``"gradcam"`` selects attention GradCAM and
            anything else selects rollout.  ``"full"`` means the attribution
            is already one value per pixel, so it is not upsampled.
        LRP: True for the LRP generator, False for the baselines.  (This
            parameter shadows the imported ``LRP`` class inside this function;
            the name is kept for caller compatibility.)

    Returns:
        uint8 array of shape (H, W, 3) after ``cv2.COLOR_RGB2BGR`` conversion.

    NOTE(review): the JET heatmap from ``cv2.applyColorMap`` is BGR while the
    image is presumably RGB, and the final conversion swaps channels of the
    blended result - confirm the intended colour order for display.
    """
    height, width = original_image.shape[-2:]
    attribution = _compute_attribution(original_image, class_index, method, LRP)

    if method != "full":
        # Token-level attribution: one value per patch on a square grid
        # (14 x 14 for a 224px image with 16px patches).  Upsample to the
        # image size; for 14 -> 224 this equals the former scale_factor=16.
        grid = int(round(attribution.numel() ** 0.5))
        attribution = attribution.reshape(1, 1, grid, grid)
        attribution = torch.nn.functional.interpolate(
            attribution, size=(height, width), mode="bilinear", align_corners=False
        )
    attribution = _min_max_normalize(
        attribution.reshape(height, width).cpu().numpy()
    )

    # (C, H, W) -> (H, W, C) for image-space blending.
    image = _min_max_normalize(
        original_image.permute(1, 2, 0).detach().cpu().numpy()
    )

    vis = np.uint8(255 * show_cam_on_image(image, attribution))
    return cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
|
|
|
|
|
def print_top_classes(predictions, top_k=5, **kwargs):
    """Print the highest-scoring classes for the first row of *predictions*.

    Args:
        predictions: 2-D score tensor indexed as ``predictions[0, class]``;
            softmax is taken over dim 1.
        top_k: how many classes to list (default 5, the former fixed value).
            It is clamped to the number of classes, so small score vectors no
            longer make ``topk`` raise.
        **kwargs: accepted for caller compatibility and ignored.

    Each printed line shows the class index, its ``CLS2IDX`` name (padded so
    the columns align), the raw score ("value") and the softmax probability
    in percent.
    """
    logits = predictions.detach()
    probs = torch.softmax(logits, dim=1)
    k = min(top_k, logits.shape[1])
    class_indices = logits.topk(k, dim=1)[1][0].tolist()
    names = [CLS2IDX[idx] for idx in class_indices]
    width = max((len(name) for name in names), default=0)

    print("Top {} classes:".format(k))
    for idx, name in zip(class_indices, names):
        output_string = "\t{} : {}".format(idx, name)
        output_string += " " * (width - len(name)) + "\t\t"
        output_string += "value = {:.3f}\t prob = {:.1f}%".format(
            logits[0, idx].item(), (100 * probs[0, idx]).item()
        )
        print(output_string)
|
|