# XAI.py, uploaded by Duckin (commit b92d56c, 3.3 kB).
# This header was recovered from a web page view of the file; "raw", "history" and "blame" were page links.
import torch
import numpy as np
import cv2
import torchvision.transforms as transforms
# Maps a model output index to its human-readable class label.
# NOTE: despite the name, keys are indices and values are label strings (index -> class).
CLS2IDX = dict(enumerate([
    "benign",
    "ductal carcinoma",
    "lobular carcinoma",
    "mucinous carcinoma",
    "papillary carcinoma",
]))
def tf(img):
    """Resize *img* to 224x224, convert it to a tensor and normalize each channel.

    The per-channel mean/std constants below are the standard ImageNet values,
    so the input is expected to have three channels. *img* must be something
    ``transforms.Resize`` and ``transforms.ToTensor`` accept (presumably a PIL
    image; TODO confirm with callers).
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    return preprocess(img)
def show_cam_on_image(img, mask):
    """Overlay *mask* as a JET colour heatmap on *img*, rescaled so the peak value is 1.

    Args:
        img: float image array. It is assumed to be (H, W, 3) with values in [0, 1]
            (TODO confirm).
        mask: array with the same H x W. Values are expected in [0, 1] because they
            are multiplied by 255 and cast to uint8 before colour-mapping.

    Returns:
        float32 array holding heatmap + image, divided by its maximum value.

    Note: OpenCV colour maps are BGR-ordered. They are added to *img* channel by
    channel without reordering.
    """
    colored_mask = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    blended = np.float32(colored_mask) / 255 + np.float32(img)
    peak = np.max(blended)
    return blended / peak
def _min_max_normalize(array):
    """Linearly rescale a NumPy array so its values span [0, 1].

    A constant array (max == min) is mapped to all zeros. The plain formula
    ``(x - min) / (max - min)`` would divide zero by zero and yield NaNs, which
    become garbage once cast to uint8 downstream.
    """
    low = array.min()
    span = array.max() - low
    if span == 0:
        return np.zeros_like(array)
    return (array - low) / span


def _attribution_to_mask(attribution, image_size):
    """Upsample a flat per-patch relevance tensor to an (H, W) NumPy mask in [0, 1].

    The relevance values are laid out on a square grid whose side is derived from
    their count, then bilinearly resized to ``image_size`` (H, W). If the count is
    not a perfect square, ``reshape`` raises ``RuntimeError``, as the old
    hard-coded reshape did.
    """
    side = int(round(attribution.numel() ** 0.5))
    grid = attribution.reshape(1, 1, side, side)
    height, width = int(image_size[0]), int(image_size[1])
    upsampled = torch.nn.functional.interpolate(grid, size=(height, width), mode='bilinear')
    mask = upsampled.reshape(height, width).detach().cpu().numpy()
    return _min_max_normalize(mask)


def generate_visualization(original_image, attribution_generator, use_thresholding=False, class_index=None):
    """Render the transformer-attribution relevance of *original_image* as a heatmap overlay.

    Args:
        original_image: channels-first (C, H, W) tensor. It is permuted to
            channels-last for blending. Presumably the output of ``tf``; TODO confirm.
        attribution_generator: object exposing
            ``generate_LRP(batch, method=..., index=...)``, which returns one
            relevance value per image patch.
        use_thresholding: if True, binarize the relevance mask with Otsu's
            threshold (values become 0/1) before overlaying it.
        class_index: class to explain. It is forwarded unchanged to
            ``generate_LRP`` as ``index``.

    Returns:
        uint8 array from ``show_cam_on_image`` scaled to 0-255, after a final
        RGB->BGR channel swap. It is (H, W, 3) for a 3-channel image.

    Raises:
        RuntimeError: if the relevance count is not a perfect square.
    """
    relevance = attribution_generator.generate_LRP(
        original_image.unsqueeze(0),
        method="transformer_attribution",
        index=class_index,
    ).detach()
    # The grid side and output size used to be hard-coded (14x14 patches upsampled
    # x16 to 224x224). Deriving both from the data gives identical output for that
    # case and keeps the mask the same size as the image for other inputs.
    mask = _attribution_to_mask(relevance, original_image.shape[-2:])
    if use_thresholding:
        # Otsu works on 8-bit data, so binarize on the 0-255 scale and then map
        # 255 -> 1 so the mask stays in {0, 1}.
        mask = (mask * 255).astype(np.uint8)
        _, mask = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        mask[mask == 255] = 1
    # Channels-first tensor -> channels-last array, rescaled to [0, 1] for blending.
    image = _min_max_normalize(original_image.permute(1, 2, 0).detach().cpu().numpy())
    vis = np.uint8(255 * show_cam_on_image(image, mask))
    return cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
def print_top_classes(predictions, *, top_k=5, class_names=None, **kwargs):
    """Return the top-k classes of the first sample in *predictions* with their probabilities.

    Despite its name, this function prints nothing. It builds and returns a dict.

    Args:
        predictions: logits tensor with classes along dim 1. Only row 0 is reported.
        top_k: maximum number of classes to return. It is clamped to the number of
            classes, so small models don't make ``topk`` raise.
        class_names: mapping or sequence from class index to label. Defaults to
            ``CLS2IDX``.
        **kwargs: accepted for backward compatibility and ignored.

    Returns:
        dict mapping label -> softmax probability in percent, rounded to two
        decimals, ordered from highest to lowest score.
    """
    if class_names is None:
        class_names = CLS2IDX
    probs = torch.softmax(predictions, dim=1)
    k = min(top_k, predictions.shape[1])
    _, top_indices = predictions.topk(k, dim=1)
    # Scale to percent in the tensor's own dtype before formatting, then round
    # through a 2-decimal string.
    return {
        class_names[idx]: float('{:.2f}'.format((100 * probs[0, idx]).item()))
        for idx in top_indices[0].tolist()
    }