from __future__ import absolute_import, division, print_function

import argparse
import glob
import os
import pickle
from collections import OrderedDict

import cv2
import numpy as np
import openslide
import torch
import torch.nn as nn
import torchvision.models as models
from easydict import EasyDict as edict

from feature_extractor import cl
from feature_extractor.build_graph_utils import bag_dataset, adj_matrix
from models.GraphTransformer import Classifier
from slide_tiling import save_tiles
from src.vis_graphcam import show_cam_on_image, cam_to_mask


class Predictor:
    """End-to-end WSI pipeline: tile the slide, extract per-patch features,
    build the patch graph, classify it, and render GraphCAM visualizations."""

    def __init__(self):
        # Mapping from class name to class id, e.g. {'LUAD': 0, ...}.
        self.classdict = pickle.load(open(os.environ['CLASS_METADATA'], 'rb'))
        self.label_map_inv = {label_id: label_name
                              for label_name, label_id in self.classdict.items()}
        iclf_weights = os.environ['FEATURE_EXTRACTOR_WEIGHT_PATH']
        graph_transformer_weights = os.environ['GT_WEIGHT_PATH']
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.__init_iclf(iclf_weights, backbone='resnet18')
        self.__init_graph_transformer(graph_transformer_weights)

    def predict(self, slide_path):
        # Tile the given WSI slide into patches.
        save_tiles(slide_path)
        filename = os.path.basename(slide_path)
        FILEID = filename.rsplit('.', maxsplit=1)[0]
        patches_glob_path = os.path.join(os.environ['PATCHES_DIR'],
                                         f'{FILEID}_files', '*', '*.jpeg')
        patches_paths = glob.glob(patches_glob_path)
        sample = self.iclf_predict(patches_paths)
        # GraphCAM needs gradients, so re-enable them after the no-grad
        # feature-extraction pass.
        torch.set_grad_enabled(True)
        node_feat, adjs, masks = Predictor.preparefeatureLabel(
            sample['image'], sample['adj_s'], self.device)
        pred, labels, loss, graphcam_tensors = self.model.forward(
            node_feat=node_feat, labels=None, adj=adjs, mask=masks,
            graphcam_flag=True, to_file=False)
        patches_coords = sample['c_idx'][0]
        viz_dict = self.get_graphcams(graphcam_tensors, patches_coords,
                                      slide_path, FILEID)
        return self.label_map_inv[pred.item()], viz_dict

    def iclf_predict(self, patches_paths):
        feats_list = []
        args = edict({'batch_size': 128, 'num_workers': 0})
        dataloader, bag_size = bag_dataset(args, patches_paths)
        with torch.no_grad():
            for iteration, batch in enumerate(dataloader):
                patches = batch['input'].float().to(self.device)
                feats, classes = self.i_classifier(patches)
                feats_list.extend(feats)
        output = torch.stack(feats_list, dim=0).to(self.device)
        # Build the adjacency matrix from the patches' spatial locations.
        adj_s = adj_matrix(patches_paths, output)
        # Patch filenames encode their grid coordinates as '{x}_{y}.jpeg'.
        patch_infos = []
        for path in patches_paths:
            x, y = os.path.basename(path).split('.')[0].split('_')
            patch_infos.append((x, y))
        return {'image': [output], 'adj_s': [adj_s], 'c_idx': [patch_infos]}
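    # Note on coordinates: patch filenames encode tile-grid positions, and the
    # CAM mask is painted onto a thumbnail downscaled by REDUCTION_FACTOR (20).
    # A hypothetical worked example of that mapping (illustrative only; the
    # actual projection is done by src.vis_graphcam.cam_to_mask):
    #
    #   col, row = map(int, '12_34.jpeg'.split('.')[0].split('_'))
    #   x0 = int(col * 512 / 20)   # left edge of the tile on the thumbnail
    #   y0 = int(row * 512 / 20)   # top edge of the tile on the thumbnail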
    def get_graphcams(self, graphcam_tensors, patches_coords, slide_path, FILEID):
        label_name_from_id = self.label_map_inv
        n_class = len(self.classdict)
        p = graphcam_tensors['prob'].cpu().detach().numpy()[0]

        ori = openslide.OpenSlide(slide_path)
        width, height = ori.dimensions
        REDUCTION_FACTOR = 20
        # Tile-grid dimensions (512x512 tiles) and thumbnail dimensions.
        w, h = int(width / 512), int(height / 512)
        w_r, h_r = int(width / REDUCTION_FACTOR), int(height / REDUCTION_FACTOR)
        resized_img = ori.get_thumbnail((width, height)).resize((w_r, h_r))
        # Size of one tile on the downscaled thumbnail.
        w_s, h_s = float(512 / REDUCTION_FACTOR), float(512 / REDUCTION_FACTOR)

        patches = []
        xmax, ymax = 0, 0
        for x, y in patches_coords:
            xmax = max(xmax, int(x))
            ymax = max(ymax, int(y))
            patches.append('{}_{}.jpeg'.format(x, y))
        # PIL gives RGB; OpenCV expects BGR.
        output_img = np.asarray(resized_img)[:, :, ::-1].copy()

        # GraphCAM: map node-level attention scores back to patch locations.
        assign_matrix = graphcam_tensors['s_matrix_ori']
        assign_matrix = nn.Softmax(dim=1)(assign_matrix)
        # Clip the class probabilities for better visualization.
        p = np.clip(p, 0.4, 1)
        output_img_copy = np.copy(output_img)
        gray = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
        image_transformer_attribution = ((output_img_copy - output_img_copy.min())
                                         / (output_img_copy.max() - output_img_copy.min()))

        cam_matrices, masks, visualizations = [], [], []
        viz_dict = dict()
        SAMPLE_VIZ_DIR = os.path.join(os.environ['GRAPHCAM_DIR'], FILEID)
        os.makedirs(SAMPLE_VIZ_DIR, exist_ok=True)
        for class_i in range(n_class):
            # Project the per-class CAM from cluster space back to nodes.
            cam_matrix = graphcam_tensors[f'cam_{class_i}']
            cam_matrix = torch.mm(assign_matrix, cam_matrix.transpose(1, 0)).cpu()
            # Min-max normalize, then scale by the (clipped) class probability.
            cam_matrix = (cam_matrix - cam_matrix.min()) / (cam_matrix.max() - cam_matrix.min())
            cam_matrix = cam_matrix.detach().numpy()
            cam_matrix = np.clip(p[class_i] * cam_matrix, 0, 1)
            mask = cam_to_mask(gray, patches, cam_matrix, w, h, w_s, h_s)
            vis = np.uint8(255 * show_cam_on_image(image_transformer_attribution, mask))
            cam_matrices.append(cam_matrix)
            masks.append(mask)
            visualizations.append(vis)
            viz_dict[label_name_from_id[class_i]] = vis
            cv2.imwrite(os.path.join(SAMPLE_VIZ_DIR, '{}_all_types_cam_{}.png'.format(
                FILEID, label_name_from_id[class_i])), vis)

        # Concatenate the original thumbnail and all per-class CAMs along the
        # shorter image axis.
        h, w, _ = output_img.shape
        if h > w:
            vis_merge = cv2.hconcat([output_img] + visualizations)
        else:
            vis_merge = cv2.vconcat([output_img] + visualizations)
        cv2.imwrite(os.path.join(SAMPLE_VIZ_DIR,
                                 '{}_all_types_cam_all.png'.format(FILEID)), vis_merge)
        viz_dict['ALL'] = vis_merge
        cv2.imwrite(os.path.join(SAMPLE_VIZ_DIR,
                                 '{}_all_types_ori.png'.format(FILEID)), output_img)
        viz_dict['ORI'] = output_img
        return viz_dict

    @staticmethod
    def preparefeatureLabel(batch_graph, batch_adjs, device='cpu'):
        # Zero-pad node features, adjacency matrices, and masks to the largest
        # graph in the batch.
        batch_size = len(batch_graph)
        max_node_num = 0
        for i in range(batch_size):
            max_node_num = max(max_node_num, batch_graph[i].shape[0])

        masks = torch.zeros(batch_size, max_node_num)
        adjs = torch.zeros(batch_size, max_node_num, max_node_num)
        batch_node_feat = torch.zeros(batch_size, max_node_num, 512)
        for i in range(batch_size):
            cur_node_num = batch_graph[i].shape[0]
            # Node attribute features.
            batch_node_feat[i, 0:cur_node_num] = batch_graph[i]
            # Adjacency matrix.
            adjs[i, 0:cur_node_num, 0:cur_node_num] = batch_adjs[i]
            # Valid-node mask.
            masks[i, 0:cur_node_num] = 1

        node_feat = batch_node_feat.to(device)  # was `.to()` with no device argument
        adjs = adjs.to(device)
        masks = masks.to(device)
        return node_feat, adjs, masks

    def __init_graph_transformer(self, graph_transformer_weights):
        n_class = len(self.classdict)
        model = Classifier(n_class)
        model = nn.DataParallel(model)
        model.load_state_dict(torch.load(graph_transformer_weights,
                                         map_location=torch.device(self.device)))
        if torch.cuda.is_available():
            model = model.cuda()
        self.model = model
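    # Checkpoint-format note: the Graph Transformer weights are assumed to
    # have been saved from an nn.DataParallel-wrapped model, so their keys
    # carry a 'module.' prefix; wrapping the fresh Classifier in DataParallel
    # before load_state_dict (above) makes the key names line up. Had the
    # checkpoint been saved unwrapped, stripping the prefix would work too,
    # e.g. (hypothetical sketch):
    #
    #   state = torch.load(path, map_location='cpu')
    #   state = {k.replace('module.', '', 1): v for k, v in state.items()}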
    def __init_iclf(self, iclf_weights, backbone='resnet18'):
        if backbone == 'resnet18':
            resnet = models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d)
            num_feats = 512
        elif backbone == 'resnet34':
            resnet = models.resnet34(pretrained=False, norm_layer=nn.InstanceNorm2d)
            num_feats = 512
        elif backbone == 'resnet50':
            resnet = models.resnet50(pretrained=False, norm_layer=nn.InstanceNorm2d)
            num_feats = 2048
        elif backbone == 'resnet101':
            resnet = models.resnet101(pretrained=False, norm_layer=nn.InstanceNorm2d)
            num_feats = 2048
        else:
            raise ValueError('Unsupported backbone: {}'.format(backbone))

        # The backbone is a frozen feature extractor; its classification head
        # is replaced by an identity so it emits raw features.
        for param in resnet.parameters():
            param.requires_grad = False
        resnet.fc = nn.Identity()
        i_classifier = cl.IClassifier(resnet, num_feats, output_class=2).to(self.device)

        # Load the feature-extractor weights, copying only the 'features'
        # parameters and renaming them positionally to the IClassifier layout.
        state_dict_weights = torch.load(iclf_weights,
                                        map_location=torch.device(self.device))
        state_dict_init = i_classifier.state_dict()
        new_state_dict = OrderedDict()
        for (k, v), (k_0, v_0) in zip(state_dict_weights.items(),
                                      state_dict_init.items()):
            if 'features' not in k:
                continue
            new_state_dict[k_0] = v
        i_classifier.load_state_dict(new_state_dict, strict=False)
        self.i_classifier = i_classifier


# Pipeline overview:
#  0. load the metadata dictionary for class names
#  1. tile the image
#  2. feed the tiles to the feature extractor
#  3. produce the graph
#  4. predict and render GraphCAMs

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Classification')
    parser.add_argument('--slide_path', type=str, help='path to the WSI slide')
    args = parser.parse_args()

    predictor = Predictor()
    predicted_class, viz_dict = predictor.predict(args.slide_path)
    print('Class prediction is:', predicted_class)
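# Example invocation (script name and paths are hypothetical; the environment
# variables are the ones read by __init__, predict, and get_graphcams above):
#
#   export CLASS_METADATA=metadata/label_map.pkl
#   export FEATURE_EXTRACTOR_WEIGHT_PATH=weights/feature_extractor.pth
#   export GT_WEIGHT_PATH=weights/graph_transformer.pth
#   export PATCHES_DIR=patches
#   export GRAPHCAM_DIR=graphcam
#   python predict.py --slide_path /path/to/slide.svs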