import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision
import os
import skimage.transform as skiTransf
import scipy.io as sio
import pdb
import time
import re
from os.path import isfile, join
import statistics
from PIL import Image
from medpy.metric.binary import dc, hd, asd, assd
import scipy.spatial
import matplotlib.pyplot as plt
# NOTE(review): this rebinds `Image`, shadowing `PIL.Image` imported above;
# any later use of `Image` in this module refers to IPython's display class.
from IPython.display import Image, display
from skimage import io
import cv2
from scipy import ndimage
# from scipy.spatial.distance import directed_hausdorff

labels = {0: 'Background', 1: 'Foreground'}


def computeDSC(pred, gt):
    """Return the mean Dice similarity coefficient over a batch.

    For every sample ``i`` the Dice score between ``pred[i, 1]`` and
    ``gt[i, 0]`` is computed with medpy's ``dc`` and the scores are averaged.

    Args:
        pred: Prediction tensor indexed as ``pred[batch, channel, ...]``;
            channel 1 is compared (presumably the foreground channel -
            TODO confirm with callers).
        gt: Ground-truth tensor indexed as ``gt[batch, channel, ...]``;
            channel 0 is compared.

    Returns:
        Mean of the per-sample Dice scores.
    """
    dsc_all = [
        dc(pred[i_b, 1, :].cpu().data.numpy(), gt[i_b, 0, :].cpu().data.numpy())
        for i_b in range(pred.shape[0])
    ]
    return np.asarray(dsc_all).mean()


def getImageImageList(imagesFolder):
    """Return the sorted names of the regular files directly inside a folder.

    Args:
        imagesFolder: Directory to list.

    Returns:
        Alphabetically sorted list of file names (sub-directories excluded).

    Raises:
        FileNotFoundError: If ``imagesFolder`` does not exist, instead of
            failing on an unassigned local variable.
    """
    if not os.path.exists(imagesFolder):
        raise FileNotFoundError(f"Image folder not found: {imagesFolder}")
    return sorted(
        f for f in os.listdir(imagesFolder) if isfile(join(imagesFolder, f))
    )


def to_var(x):
    """Move ``x`` to the GPU when CUDA is available and wrap it in ``Variable``.

    ``Variable`` has been deprecated since PyTorch 0.4 and simply returns the
    tensor; it is kept so existing callers behave exactly as before.
    """
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)


def DicesToDice(Dices):
    """Aggregate per-sample Dice components into a single Dice score.

    ``Dices`` is summed over dim 0 and the result is
    ``(2 * sums[0] + 1e-8) / (sums[1] + 1e-8)``; the epsilons avoid 0/0.
    Presumably column 0 holds intersections and column 1 the summed areas
    of prediction and target - TODO confirm with callers.
    """
    sums = Dices.sum(dim=0)
    return (2 * sums[0] + 1e-8) / (sums[1] + 1e-8)


def predToSegmentation(pred):
    """Return a float mask that is 1.0 where ``pred`` attains its max over dim 1.

    Ties all become 1.0. A slice whose maximum is 0 divides by zero
    (0/0 -> NaN, negative/0 -> -inf), so that slice comes back all zeros.
    """
    Max = pred.max(dim=1, keepdim=True)[0]
    x = pred / Max
    return (x == 1).float()


def getTargetSegmentation(batch):
    """Convert a normalised label image into integer class indices.

    The input holds one channel of values in [0, 1]. For ACDC the grey levels
    are 0, 0.33333334, 0.6666667 and 0.94117647, which map to classes
    0, 1, 2 and 3 after dividing by 1/3 and rounding.

    Args:
        batch: Label tensor, expected shape (N, 1, H, W).

    Returns:
        Long tensor of class indices without the channel axis
        (shape (N, H, W) for 4-D input).
    """
    denom = 0.33333334  # grey-level step of the ACDC label images
    classes = (batch / denom).round().long()
    if classes.dim() == 4:
        # Squeeze only the channel axis: a bare squeeze() would also drop the
        # batch axis when N == 1, which breaks CrossEntropyLoss on the last batch.
        return classes.squeeze(1)
    return classes.squeeze()


def inference(net, img_batch, modelName, epoch):
    """Run ``net`` over a loader, save visualisations, and return the mean CE loss.

    For every ``(images, labels, img_names)`` batch the prediction is scored
    against the labels with cross-entropy. A PNG stacking the images, the
    labels and the predicted masks (scaled by 1/3) is written to
    ``./Results/Images/<modelName>/<epoch>/<i>.png``.

    Gradients are disabled and the network is put in eval mode for the pass.
    Its previous train/eval mode is restored afterwards.

    Args:
        net: Segmentation network returning per-class scores on dim 1.
        img_batch: Sized iterable of ``(images, labels, img_names)`` batches.
        modelName: Sub-directory used for the saved images.
        epoch: Epoch number, used as a sub-directory name.

    Returns:
        Mean cross-entropy loss over all batches.
    """
    total = len(img_batch)
    was_training = net.training
    net.eval()
    # Softmax / CrossEntropyLoss hold no parameters, so no .cuda() is needed;
    # dropping it lets this run on CPU-only machines too.
    softMax = nn.Softmax(dim=1)
    CE_loss = nn.CrossEntropyLoss()

    path = os.path.join('./Results/Images/', modelName, str(epoch))
    os.makedirs(path, exist_ok=True)

    losses = []
    try:
        with torch.no_grad():
            for i, data in enumerate(img_batch):
                # NOTE(review): printProgressBar is neither defined nor imported
                # in this file - confirm it is provided elsewhere.
                printProgressBar(i, total, prefix="[Inference] Getting segmentations...", length=30)
                images, labels, img_names = data
                images = to_var(images)
                labels = to_var(labels)

                net_predictions = net(images)
                segmentation_classes = getTargetSegmentation(labels)
                CE_loss_value = CE_loss(net_predictions, segmentation_classes)
                losses.append(CE_loss_value.item())

                pred_y = softMax(net_predictions)
                masks = torch.argmax(pred_y, dim=1)

                # unsqueeze(1) restores the channel axis for any H x W
                # instead of assuming 256 x 256.
                torchvision.utils.save_image(
                    torch.cat([images.data, labels.data, masks.unsqueeze(1).data / 3.0]),
                    os.path.join(path, str(i) + '.png'),
                    padding=0)
    finally:
        net.train(was_training)

    printProgressBar(total, total, done="[Inference] Segmentation Done !")
    return np.asarray(losses).mean()


class MaskToTensor(object):
    """Transform converting a mask (anything ``np.array`` accepts) to a float tensor.

    The input is cast to ``np.int32`` first, then converted to float.
    """

    def __call__(self, img):
        return torch.from_numpy(np.array(img, dtype=np.int32)).float()


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialise ``state`` to ``filename`` with ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model):
    """Load ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])


def check_accuracy(loader, model, device="cuda"):
    """Print pixel accuracy and mean per-batch Dice of a binary segmentation model.

    Predictions are ``sigmoid(model(x)) > 0.5``. Targets get a channel axis
    via ``unsqueeze(1)``. The model is evaluated in eval mode without
    gradients and is switched back to train mode afterwards.

    NOTE(review): expects ``loader`` to yield ``(x, y)`` pairs, whereas the
    batches consumed by ``inference`` carry three items - confirm which loader
    callers pass here.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0.0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = (torch.sigmoid(model(x)) > 0.5).float()
            # .item() keeps the running totals as plain numbers so the report
            # prints values rather than tensor reprs.
            num_correct += (preds == y).sum().item()
            num_pixels += torch.numel(preds)
            dice_score += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
    model.train()


def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Save thresholded predictions and ground-truth masks as PNG files.

    For batch ``idx`` this writes ``<folder>/pred_<idx>.png`` (``sigmoid > 0.5``)
    and ``<folder>/<idx>.png`` (targets with a channel axis added). Paths are
    joined with ``os.path.join``, so ``folder`` works with or without a
    trailing separator. The model is switched back to train mode at the end.
    """
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = (torch.sigmoid(model(x)) > 0.5).float()
        torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
        torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f"{idx}.png"))
    model.train()
# Converting tensor to image
def image_convert(image):
    """Convert a (C, H, W) tensor with values in [0, 1] into an ``imshow``-ready array.

    The tensor is detached, moved to the CPU and transposed to (H, W, C).
    A single channel is squeezed to (H, W). Values are scaled by 255,
    rounded, clipped to [0, 255] and cast to ``uint8``.
    """
    image = image.detach().cpu().numpy().transpose((1, 2, 0))
    if image.shape[-1] == 1:
        # imshow rejects (H, W, 1) arrays; pass a 2-D grayscale array instead.
        image = image[..., 0]
    # imshow clips float RGB data to [0, 1], so return uint8 rather than
    # floats in [0, 255].
    return np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)


def mask_convert(mask):
    """Return ``mask`` as a CPU numpy array with all singleton axes removed."""
    mask = mask.clone().cpu().detach().numpy()
    return np.squeeze(mask)


# If a model is given, this runs inference on one loader batch and plots the prediction.
def plot_img(loader, no_, model=None):
    """Plot samples from the first batch of ``loader`` next to their masks.

    Each of the first ``no_`` samples gets a figure with three panels:
    1. the image;
    2. its ground-truth mask;
    3. the mask overlaid on the image (``model is None``) or the model's
       predicted mask.

    Args:
        loader: DataLoader yielding ``(images, target, name)`` batches.
        no_: Number of samples to plot; capped at the batch size.
        model: Optional network returning per-class scores on dim 1. It is
            run once on the whole batch in eval mode without gradients, and
            its previous train/eval mode is restored afterwards.
    """
    images, target, _ = next(iter(loader))
    n_show = min(no_, len(images))

    pred_masks = None
    if model is not None:
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                yhat = model(to_var(images))
                pred_masks = torch.argmax(nn.Softmax(dim=1)(yhat), dim=1)
        finally:
            model.train(was_training)

    for idx in range(n_show):
        plt.figure(figsize=(12, 12))

        # Image
        image = image_convert(images[idx])
        plt.subplot(1, 3, 1)
        plt.imshow(image)
        plt.title('Original Image')

        # Ground-truth target mask
        mask = mask_convert(target[idx])
        plt.subplot(1, 3, 2)
        plt.imshow(mask)
        plt.title('Original Mask')

        plt.subplot(1, 3, 3)
        if pred_masks is None:
            # Superposition with the target mask
            plt.imshow(image)
            plt.imshow(mask, alpha=0.6)
            plt.title('Superposition')
        else:
            # Prediction computed once for the whole batch above
            plt.imshow(mask_convert(pred_masks[idx]))
            plt.title('Prediction')
        plt.show()


# Disabled loader factory kept for reference. It relies on names that are not
# imported in this file: albumentations `A`, `ToTensorV2`, `IMAGE_HEIGHT`,
# `IMAGE_WIDTH`, `medicalDataLoader` and `DataLoader`. It also ignores its
# NUM_WORKERS / PIN_MEMORY arguments.
# NOTE(review): `worker_init_fn=np.random.seed(0)` calls seed(0) once and passes
# None; per-worker seeding needs a callable such as `lambda wid: np.random.seed(wid)`.
"""
def get_loaders(root_dir, batch_size, NUM_WORKERS, PIN_MEMORY, test = False):
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    ## Due to the custom loading class, data is loaded in two steps
    train_set_full = medicalDataLoader.MedicalImageDataset('train', root_dir, transform=train_transform, mask_transform=train_transform, augment=False, equalize=False)
    train_loader_full = DataLoader(train_set_full, batch_size=batch_size, worker_init_fn=np.random.seed(0), num_workers= 0, shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val', root_dir, transform=val_transform, mask_transform=val_transform, equalize=False)
    val_loader = DataLoader(val_set, batch_size=batch_size, worker_init_fn=np.random.seed(0), num_workers = 0, shuffle=False)

    if test:
        test_set = medicalDataLoader.MedicalImageDataset('test', root_dir, transform=None, mask_transform=None, equalize=False)
        test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=0, shuffle=False)
        return test_loader

    return train_loader_full, val_loader
"""