Spaces:
Runtime error
Runtime error
from PIL import Image | |
from matplotlib.pyplot import imshow, show | |
import matplotlib.pyplot as plt | |
from torchvision import models, transforms | |
from torch.autograd import Variable | |
from torch.nn import functional as F | |
import torch | |
import torch.nn as nn | |
from torch import topk | |
import numpy as np | |
import os | |
import skimage.transform | |
import cv2 | |
import math | |
import openslide | |
import argparse | |
import pickle | |
def show_cam_on_image(img, mask): | |
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) | |
heatmap = np.float32(heatmap) / 255 | |
cam = heatmap + np.float32(img) | |
cam = cam / np.max(cam) | |
return cam | |
def cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s): | |
mask = np.full_like(gray, 0.).astype(np.float32) | |
for ind1, patch in enumerate(patches): | |
x, y = patch.split('.')[0].split('_') | |
x, y = int(x), int(y) | |
#if y <5 or x>w-5 or y>h-5: | |
# continue | |
mask[int(y*h_s):int((y+1)*h_s), int(x*w_s):int((x+1)*w_s)].fill(cam_matrix[ind1][0]) | |
return mask | |
def main(args): | |
label_map = pickle.load(open(os.path.join(args.dataset_metadata_path, 'label_map.pkl'), 'rb')) | |
label_name_from_id = dict() | |
for label_name, label_id in label_map.items(): | |
label_name_from_id[label_id] = label_name | |
n_class = len(label_map)#args.n_class | |
file_name, label = open(args.path_file, 'r').readlines()[-1].split('\t') | |
label = label.rstrip().strip() | |
#site, file_name = file_name.split('/') | |
file_path = os.path.join(args.path_patches, '{}_files/20.0/'.format(file_name)) | |
print(file_name) | |
print(label) | |
p = torch.load('graphcam/prob.pt').cpu().detach().numpy()[0] | |
file_path = os.path.join(args.path_patches, '{}_files/20.0/'.format(file_name)) | |
#ori = openslide.OpenSlide(os.path.join(args.path_WSI, '{}.svs').format(file_name)) | |
ORIGINAL_FILEPATH = os.path.join(args.path_WSI,'TCGA',label, '{}.svs'.format(file_name)) | |
print('L', ORIGINAL_FILEPATH) | |
ori = openslide.OpenSlide(ORIGINAL_FILEPATH) | |
patch_info = open(os.path.join(args.path_graph, file_name, 'c_idx.txt'), 'r') | |
width, height = ori.dimensions | |
REDUCTION_FACTOR = 10 | |
w, h = int(width/512), int(height/512) | |
w_r, h_r = int(width/20), int(height/20) | |
resized_img = ori.get_thumbnail((width,height))#ori.get_thumbnail((w_r,h_r)) | |
resized_img = resized_img.resize((w_r,h_r)) | |
ratio_w, ratio_h = width/resized_img.width, height/resized_img.height | |
print('ratios ', ratio_w, ratio_h) | |
w_s, h_s = float(512/REDUCTION_FACTOR), float(512/REDUCTION_FACTOR) | |
print(w_s, h_s) | |
patch_info = patch_info.readlines() | |
patches = [] | |
xmax, ymax = 0, 0 | |
for patch in patch_info: | |
x, y = patch.strip('\n').split('\t') | |
if xmax < int(x): xmax = int(x) | |
if ymax < int(y): ymax = int(y) | |
patches.append('{}_{}.jpeg'.format(x,y)) | |
output_img = np.asarray(resized_img)[:,:,::-1].copy() | |
#-----------------------------------------------------------------------------------------------------# | |
# GraphCAM | |
print('visulize GraphCAM') | |
assign_matrix = torch.load('graphcam/s_matrix_ori.pt') | |
m = nn.Softmax(dim=1) | |
assign_matrix = m(assign_matrix) | |
# Thresholding for better visualization | |
p = np.clip(p, 0.4, 1) | |
output_img_copy =np.copy(output_img) | |
gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) | |
image_transformer_attribution = (output_img_copy - output_img_copy.min()) / (output_img_copy.max() - output_img_copy.min()) | |
cam_matrices = [] | |
masks = [] | |
visualizations = [] | |
print(len(patches)) | |
os.makedirs('graphcam_vis', exist_ok=True) | |
for class_i in range(n_class): | |
# Load graphcam for each class | |
cam_matrix = torch.load(f'graphcam/cam_{class_i}.pt') | |
print(cam_matrix.shape) | |
cam_matrix = torch.mm(assign_matrix, cam_matrix.transpose(1,0)) | |
cam_matrix = cam_matrix.cpu() | |
print(assign_matrix.shape) | |
print(cam_matrix.shape) | |
# Normalize the graphcam | |
cam_matrix = (cam_matrix - cam_matrix.min()) / (cam_matrix.max() - cam_matrix.min()) | |
cam_matrix = cam_matrix.detach().numpy() | |
cam_matrix = p[class_i] * cam_matrix | |
cam_matrix = np.clip(cam_matrix, 0, 1) | |
print(cam_matrix.shape) | |
#print() | |
mask = cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s) | |
print('mask shape ', mask.shape) | |
print('imgtf attr ', image_transformer_attribution.shape) | |
vis = show_cam_on_image(image_transformer_attribution, mask) | |
vis = np.uint8(255 * vis) | |
cam_matrices.append(cam_matrix) | |
masks.append(mask) | |
visualizations.append(vis) | |
print() | |
cv2.imwrite('graphcam_vis/{}_all_types_cam_{}.png'.format(file_name, label_name_from_id[class_i] ), vis) | |
h, w, _ = output_img.shape | |
if h > w: | |
vis_merge = cv2.hconcat([output_img] + visualizations) | |
else: | |
vis_merge = cv2.vconcat([output_img] + visualizations) | |
cv2.imwrite('graphcam_vis/{}_all_types_cam_all.png'.format(file_name), vis_merge) | |
cv2.imwrite('graphcam_vis/{}_all_types_ori.png'.format(file_name ), output_img) | |
''' | |
# Load graphcam for differnet class | |
cam_matrix_0 = torch.load('graphcam/cam_0.pt') | |
cam_matrix_0 = torch.mm(assign_matrix, cam_matrix_0.transpose(1,0)) | |
cam_matrix_0 = cam_matrix_0.cpu() | |
cam_matrix_1 = torch.load('graphcam/cam_1.pt') | |
cam_matrix_1 = torch.mm(assign_matrix, cam_matrix_1.transpose(1,0)) | |
cam_matrix_1 = cam_matrix_1.cpu() | |
cam_matrix_2 = torch.load('graphcam/cam_2.pt') | |
cam_matrix_2 = torch.mm(assign_matrix, cam_matrix_2.transpose(1,0)) | |
cam_matrix_2 = cam_matrix_2.cpu() | |
# Normalize the graphcam | |
cam_matrix_0 = (cam_matrix_0 - cam_matrix_0.min()) / (cam_matrix_0.max() - cam_matrix_0.min()) | |
cam_matrix_0 = cam_matrix_0.detach().numpy() | |
cam_matrix_0 = p[0] * cam_matrix_0 | |
cam_matrix_0 = np.clip(cam_matrix_0, 0, 1) | |
cam_matrix_1 = (cam_matrix_1 - cam_matrix_1.min()) / (cam_matrix_1.max() - cam_matrix_1.min()) | |
cam_matrix_1 = cam_matrix_1.detach().numpy() | |
cam_matrix_1 = p[1] * cam_matrix_1 | |
cam_matrix_1 = np.clip(cam_matrix_1, 0, 1) | |
cam_matrix_2 = (cam_matrix_2 - cam_matrix_2.min()) / (cam_matrix_2.max() - cam_matrix_2.min()) | |
cam_matrix_2 = cam_matrix_2.detach().numpy() | |
cam_matrix_2 = p[2] * cam_matrix_2 | |
cam_matrix_2 = np.clip(cam_matrix_2, 0, 1) | |
output_img_copy =np.copy(output_img) | |
gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) | |
image_transformer_attribution = (output_img_copy - output_img_copy.min()) / (output_img_copy.max() - output_img_copy.min()) | |
mask0 = cam_to_mask(gray, patches, cam_matrix_0, w, h, w_s, h_s) | |
vis0 = show_cam_on_image(image_transformer_attribution, mask0) | |
vis0 = np.uint8(255 * vis0) | |
mask1 = cam_to_mask(gray, patches, cam_matrix_1, w, h, w_s, h_s) | |
vis1 = show_cam_on_image(image_transformer_attribution, mask1) | |
vis1 = np.uint8(255 * vis1) | |
mask2 = cam_to_mask(gray, patches, cam_matrix_2, w, h, w_s, h_s) | |
vis2 = show_cam_on_image(image_transformer_attribution, mask2) | |
vis2 = np.uint8(255 * vis2) | |
########################################## | |
h, w, _ = output_img.shape | |
if h > w: | |
vis_merge = cv2.hconcat([output_img, vis0, vis1, vis2]) | |
else: | |
vis_merge = cv2.vconcat([output_img, vis0, vis1, vis2]) | |
#cv2.imwrite('graphcam_vis/{}_{}_all_types_cam_all.png'.format(file_name, site), vis_merge) | |
#cv2.imwrite('graphcam_vis/{}_{}_all_types_ori.png'.format(file_name, site), output_img) | |
#cv2.imwrite('graphcam_vis/{}_{}_all_types_cam_luad.png'.format(file_name, site), vis1) | |
#cv2.imwrite('graphcam_vis/{}_{}_all_types_cam_lscc.png'.format(file_name, site), vis2) | |
cv2.imwrite('graphcam_vis/{}_all_types_cam_all.png'.format(file_name, ), vis_merge) | |
cv2.imwrite('graphcam_vis/{}_all_types_ori.png'.format(file_name ), output_img) | |
cv2.imwrite('graphcam_vis/{}_all_types_cam_luad.png'.format(file_name ), vis1) | |
cv2.imwrite('graphcam_vis/{}_all_types_cam_lscc.png'.format(file_name ), vis2) | |
''' | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description='GraphCAM') | |
parser.add_argument('--path_file', type=str, default='test.txt', help='txt file contains test sample') | |
parser.add_argument('--path_patches', type=str, default='', help='') | |
parser.add_argument('--path_WSI', type=str, default='', help='') | |
parser.add_argument('--path_graph', type=str, default='', help='') | |
parser.add_argument('--dataset_metadata_path', type=str, help='Location of the metadata associated with the created dataset: label mapping, splits and so on') | |
args = parser.parse_args() | |
main(args) |