import torch
import numpy as np
import cv2
import torchvision.transforms as transforms

# Maps the model's output index to its breast-tissue class label
# (note: despite its name, this dictionary goes from index to class name).
CLS2IDX = {0: "benign",
           1: "ductal carcinoma",
           2: "lobular carcinoma",
           3: "mucinous carcinoma",
           4: "papillary carcinoma"}


def tf(img):
    # Standard ImageNet preprocessing: resize to the ViT input size,
    # convert to a tensor, and normalize with the ImageNet mean/std.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])
    return transform(img)
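
# Example usage of tf (a minimal sketch; "sample_slide.png" is a hypothetical path):
#
#   from PIL import Image
#   img = Image.open("sample_slide.png").convert("RGB")
#   x = tf(img)   # torch.Tensor of shape (3, 224, 224), ImageNet-normalized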


def show_cam_on_image(img, mask):
    # Turn the attribution mask into a JET heatmap, overlay it on the input
    # image, and rescale so the result stays in [0, 1].
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam
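
# Quick sanity check of the overlay on random data (illustrative only; both
# inputs are expected as float arrays in [0, 1]):
#
#   dummy_img = np.random.rand(224, 224, 3).astype(np.float32)   # image in [0, 1]
#   dummy_mask = np.random.rand(224, 224).astype(np.float32)     # attribution in [0, 1]
#   overlay = show_cam_on_image(dummy_img, dummy_mask)           # float image in [0, 1]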


def generate_visualization(original_image, attribution_generator, use_thresholding=False, class_index=None):
    # Relevance map over the 14x14 grid of ViT patches, upsampled to the
    # 224x224 input resolution and min-max normalized to [0, 1].
    transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0), method="transformer_attribution", index=class_index).detach()
    transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
    transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
    transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
    transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())

    if use_thresholding:
        # Optionally binarize the attribution map with Otsu thresholding.
        transformer_attribution = transformer_attribution * 255
        transformer_attribution = transformer_attribution.astype(np.uint8)
        ret, transformer_attribution = cv2.threshold(transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        transformer_attribution[transformer_attribution == 255] = 1

    # Min-max scale the (normalized) input image back into [0, 1] for display,
    # then overlay the attribution heatmap on it.
    image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()
    image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())
    vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return vis
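
# Saving the overlay (a sketch; the file name is illustrative). Since the result
# is converted to BGR channel order above, it can be written directly with OpenCV:
#
#   vis = generate_visualization(x, attribution_generator, class_index=1)  # explain "ductal carcinoma"
#   cv2.imwrite("attribution_overlay.png", vis)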


def print_top_classes(predictions, **kwargs):
    # Return the top-5 predicted classes with their softmax probabilities (in %),
    # keyed by class name.
    prob = torch.softmax(predictions, dim=1)
    class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()

    dt = dict()
    for cls_idx in class_indices:
        dt.update({CLS2IDX[cls_idx]: float('{:.2f}'.format(100 * prob[0, cls_idx]))})
    return dt
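
# End-to-end sketch (assumes a fine-tuned ViT classifier `model` and an LRP-style
# explanation generator `attribution_generator`; both are defined elsewhere, and
# these names, like the image path, are assumptions):
#
#   from PIL import Image
#   img = Image.open("sample_slide.png").convert("RGB")
#   x = tf(img)
#   logits = model(x.unsqueeze(0))
#   print(print_top_classes(logits))                                        # top-5 class probabilities
#   vis = generate_visualization(x, attribution_generator, class_index=1)   # explain "ductal carcinoma"
#   cv2.imwrite("overlay.png", vis)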