AioMedica2 / src /vis_graphcam.py
chris1nexus
First commit
54660f7
raw history blame
No virus
8.63 kB
from PIL import Image
from matplotlib.pyplot import imshow, show
import matplotlib.pyplot as plt
from torchvision import models, transforms
from torch.autograd import Variable
from torch.nn import functional as F
import torch
import torch.nn as nn
from torch import topk
import numpy as np
import os
import skimage.transform
import cv2
import math
import openslide
import argparse
import pickle
def show_cam_on_image(img, mask):
    """Blend a JET-colored rendering of ``mask`` onto ``img``.

    ``mask`` is scaled by 255 and cast to ``uint8`` before the color map is
    applied, so it is presumably expected to hold values in [0, 1] -- TODO
    confirm against callers. ``img`` must be broadcast-compatible with the
    3-channel heat map produced by OpenCV.

    Returns:
        A float32 array: the sum of the heat map (rescaled to [0, 1]) and
        ``img``, divided by its own maximum value.
    """
    colored = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    # Renormalize so the brightest value of the blended image is exactly 1.
    return overlay / np.max(overlay)
def cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s):
    """Paint per-patch activation values onto an image-sized float mask.

    Args:
        gray: Array whose shape defines the shape of the returned mask.
        patches: Patch file names of the form ``'<x>_<y>.<ext>'``; ``x`` and
            ``y`` are integer grid coordinates (column, row).
        cam_matrix: Indexable so that ``cam_matrix[i][0]`` is the activation
            of the i-th entry of ``patches``.
        w: Unused; kept for interface compatibility.
        h: Unused; kept for interface compatibility.
        w_s: Width, in mask pixels, of one grid cell (may be fractional).
        h_s: Height, in mask pixels, of one grid cell (may be fractional).

    Returns:
        A float32 array shaped like ``gray``; it is zero everywhere except in
        the rectangles covered by the listed patches.
    """
    heat = np.zeros_like(gray, dtype=np.float32)
    for idx, name in enumerate(patches):
        # Everything before the first '.' encodes the grid position "x_y".
        col, row = (int(part) for part in name.partition('.')[0].split('_'))
        top, bottom = int(row * h_s), int((row + 1) * h_s)
        left, right = int(col * w_s), int((col + 1) * w_s)
        heat[top:bottom, left:right].fill(cam_matrix[idx][0])
    return heat
def _read_last_sample(path_file):
    """Return ``(file_name, label)`` parsed from the last line of ``path_file``.

    The line must contain exactly two tab-separated fields; surrounding
    whitespace (including the trailing newline) is stripped from the label.
    """
    with open(path_file, 'r') as sample_file:
        last_line = sample_file.readlines()[-1]
    file_name, label = last_line.split('\t')
    return file_name, label.strip()


def _read_patch_names(c_idx_path):
    """Return one ``'<x>_<y>.jpeg'`` name per line of the coordinate file.

    Each line must contain exactly two tab-separated fields (x, then y).
    The order of the lines is preserved because callers index activation
    rows by position in the returned list.
    """
    patches = []
    with open(c_idx_path, 'r') as patch_info:
        for line in patch_info:
            x, y = line.strip('\n').split('\t')
            patches.append('{}_{}.jpeg'.format(x, y))
    return patches


def main(args):
    """Render per-class GraphCAM heat maps over a slide thumbnail.

    The sample to visualize is the one named on the last line of
    ``args.path_file``. Tensors are read from the relative ``graphcam/``
    directory (``prob.pt``, ``s_matrix_ori.pt`` and one ``cam_<i>.pt`` per
    class) and PNG files are written to ``graphcam_vis/``: one overlay per
    class, one image with the thumbnail and all overlays concatenated, and
    the plain thumbnail.

    Args:
        args: Namespace providing ``dataset_metadata_path`` (directory that
            contains ``label_map.pkl``), ``path_file``, ``path_WSI`` and
            ``path_graph``.
    """
    # NOTE(review): unpickling can execute arbitrary code; only point
    # --dataset_metadata_path at trusted data.
    label_map_path = os.path.join(args.dataset_metadata_path, 'label_map.pkl')
    with open(label_map_path, 'rb') as label_map_file:
        label_map = pickle.load(label_map_file)
    # Class ids are used below as 0..n_class-1 when loading cam_<i>.pt and
    # naming outputs -- presumably label_map uses exactly those ids; verify.
    label_name_from_id = {
        label_id: label_name for label_name, label_id in label_map.items()
    }
    n_class = len(label_map)

    file_name, label = _read_last_sample(args.path_file)
    print(file_name)
    print(label)

    p = torch.load('graphcam/prob.pt').cpu().detach().numpy()[0]

    slide_path = os.path.join(
        args.path_WSI, 'TCGA', label, '{}.svs'.format(file_name))
    print('L', slide_path)
    # The context manager closes the slide handle even if reading fails.
    with openslide.OpenSlide(slide_path) as ori:
        patches = _read_patch_names(
            os.path.join(args.path_graph, file_name, 'c_idx.txt'))
        width, height = ori.dimensions
        w_r, h_r = int(width / 20), int(height / 20)
        # NOTE(review): this requests a full-resolution thumbnail before
        # shrinking it, which can be very memory hungry for large slides.
        resized_img = ori.get_thumbnail((width, height)).resize((w_r, h_r))

    patch_size = 512
    reduction_factor = 10
    w, h = int(width / patch_size), int(height / patch_size)
    ratio_w = width / resized_img.width
    ratio_h = height / resized_img.height
    print('ratios ', ratio_w, ratio_h)
    # Size of one patch cell in thumbnail pixels. The thumbnail is 1/20 of
    # the slide while the cell is 512/10 -- presumably because a 512 px patch
    # at 20x spans 1024 px at the slide's base level; TODO confirm.
    w_s = float(patch_size / reduction_factor)
    h_s = float(patch_size / reduction_factor)
    print(w_s, h_s)

    # Reverse the channel axis so the array is in OpenCV's BGR order.
    output_img = np.asarray(resized_img)[:, :, ::-1].copy()

    #-------------------------------------------------------------------------#
    # GraphCAM
    print('visulize GraphCAM')
    assign_matrix = torch.load('graphcam/s_matrix_ori.pt')
    assign_matrix = nn.Softmax(dim=1)(assign_matrix)

    # Thresholding for better visualization
    p = np.clip(p, 0.4, 1)

    gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
    # Min-max normalize the thumbnail to [0, 1]. A constant image has zero
    # range; it is left at zero instead of dividing by zero (NaN overlay).
    image_transformer_attribution = output_img.astype(np.float64)
    image_transformer_attribution -= image_transformer_attribution.min()
    pixel_span = image_transformer_attribution.max()
    if pixel_span > 0:
        image_transformer_attribution /= pixel_span

    visualizations = []
    print(len(patches))
    os.makedirs('graphcam_vis', exist_ok=True)
    for class_i in range(n_class):
        # Load graphcam for each class
        cam_matrix = torch.load(f'graphcam/cam_{class_i}.pt')
        print(cam_matrix.shape)
        cam_matrix = torch.mm(assign_matrix, cam_matrix.transpose(1, 0)).cpu()
        print(assign_matrix.shape)
        print(cam_matrix.shape)

        # Normalize the graphcam to [0, 1]; a constant activation map stays
        # at zero rather than becoming NaN.
        cam_matrix = cam_matrix - cam_matrix.min()
        cam_span = cam_matrix.max()
        if cam_span > 0:
            cam_matrix = cam_matrix / cam_span
        # Scale by the (clipped) class probability.
        cam_matrix = np.clip(p[class_i] * cam_matrix.detach().numpy(), 0, 1)
        print(cam_matrix.shape)

        mask = cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s)
        print('mask shape ', mask.shape)
        print('imgtf attr ', image_transformer_attribution.shape)
        vis = show_cam_on_image(image_transformer_attribution, mask)
        vis = np.uint8(255 * vis)
        visualizations.append(vis)
        print()
        cv2.imwrite(
            'graphcam_vis/{}_all_types_cam_{}.png'.format(
                file_name, label_name_from_id[class_i]),
            vis)

    # Stack along the shorter side so the merged image stays roughly square.
    img_h, img_w, _ = output_img.shape
    if img_h > img_w:
        vis_merge = cv2.hconcat([output_img] + visualizations)
    else:
        vis_merge = cv2.vconcat([output_img] + visualizations)
    cv2.imwrite(
        'graphcam_vis/{}_all_types_cam_all.png'.format(file_name), vis_merge)
    cv2.imwrite(
        'graphcam_vis/{}_all_types_ori.png'.format(file_name), output_img)
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them and hand the
    # resulting namespace to main().
    parser = argparse.ArgumentParser(description='GraphCAM')
    # (flag, default, help) for the string options that carry a default.
    for flag, default, help_text in (
        ('--path_file', 'test.txt', 'txt file contains test sample'),
        ('--path_patches', '', ''),
        ('--path_WSI', '', ''),
        ('--path_graph', '', ''),
    ):
        parser.add_argument(flag, type=str, default=default, help=help_text)
    # No default: the attribute is None when the flag is omitted.
    parser.add_argument(
        '--dataset_metadata_path',
        type=str,
        help='Location of the metadata associated with the created dataset: label mapping, splits and so on',
    )
    args = parser.parse_args()
    main(args)