import os
from os.path import isfile, join

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from IPython.display import Image, display
from medpy.metric.binary import dc, hd, asd, assd
from scipy import ndimage
from torch.autograd import Variable

# NOTE(review): ``printProgressBar`` (called by ``inference``) is neither
# defined nor imported in the visible part of this file — confirm it is
# provided before calling ``inference``.

# Human-readable names for the class indices of a binary segmentation.
labels = {0: 'Background', 1: 'Foreground'}


def computeDSC(pred, gt):
    """Return the mean Dice similarity coefficient over a batch.

    For every sample ``i_b`` the Dice score between channel 1 of ``pred`` and
    channel 0 of ``gt`` is computed with medpy's ``dc``; the per-sample scores
    are then averaged.

    Args:
        pred: batched prediction tensor; its channel 1 is compared
            (presumably the foreground channel — confirm with callers).
        gt: batched ground-truth tensor; its channel 0 is compared.

    Returns:
        The mean Dice score as a NumPy scalar.
    """
    dscAll = [
        dc(pred[i_b, 1, :].cpu().data.numpy(), gt[i_b, 0, :].cpu().data.numpy())
        for i_b in range(pred.shape[0])
    ]
    return np.asarray(dscAll).mean()


def getImageImageList(imagesFolder):
    """Return the sorted names of the regular files directly inside ``imagesFolder``.

    Sub-directories are skipped. A folder that does not exist yields an empty
    list instead of an error.
    """
    if not os.path.exists(imagesFolder):
        return []
    return sorted(f for f in os.listdir(imagesFolder) if isfile(join(imagesFolder, f)))


def to_var(x):
    """Move ``x`` to the GPU when CUDA is available and wrap it in ``Variable``.

    ``Variable`` is deprecated in modern PyTorch (it simply returns a Tensor);
    it is kept so existing callers see the same type as before.
    """
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)


def DicesToDice(Dices):
    """Aggregate per-sample Dice terms into a single Dice score.

    Column 0 of ``Dices`` is summed as the intersection term and column 1 as
    the denominator term (assumed to be |A| + |B| — confirm with the producer
    of ``Dices``). Returns ``(2 * sum(col0) + 1e-8) / (sum(col1) + 1e-8)``.
    """
    sums = Dices.sum(dim=0)
    return (2 * sums[0] + 1e-8) / (sums[1] + 1e-8)


def predToSegmentation(pred):
    """One-hot encode ``pred`` by marking its maximum along dim 1.

    Returns a float tensor with the same shape as ``pred``: 1.0 where the value
    equals the maximum over dim 1, 0.0 elsewhere. Ties mark every tied channel.
    """
    Max = pred.max(dim=1, keepdim=True)[0]
    # Compare against the max directly; dividing by it and testing for 1
    # produced NaN/inf (and an all-zero mask) whenever the max was exactly 0.
    return (pred == Max).float()


def getTargetSegmentation(batch):
    """Convert a normalized 1-channel label image into integer class indices.

    Expected input values are 0, 0.33333334, 0.6666667 and 0.94117647;
    dividing by 1/3 and rounding maps them to 0, 1, 2 and 3.

    Args:
        batch: label tensor, assumed ``(B, 1, H, W)``.

    Returns:
        A ``long`` tensor with the channel axis (dim 1) removed.
    """
    denom = 0.33333334  # for ACDC this value
    # squeeze(1) drops only the channel axis; a bare squeeze() would also drop
    # a batch axis of size 1 and break CrossEntropyLoss for batch size 1.
    return (batch / denom).round().long().squeeze(1)


def inference(net, img_batch, modelName, epoch):
    """Evaluate ``net`` on ``img_batch``, save visualisations, return mean CE loss.

    For every batch ``(images, labels, names)`` the network output is scored
    with cross-entropy against ``getTargetSegmentation(labels)``. A PNG
    stacking the input images, the label images and the predicted class map
    (argmax over dim 1, divided by 3.0 for display) is saved to
    ``./Results/Images/<modelName>/<epoch>/<i>.png``.

    Note: ``net`` is switched to eval mode and is not switched back.

    Returns:
        Mean of the per-batch cross-entropy losses (NaN if ``img_batch`` is empty).
    """
    total = len(img_batch)
    net.eval()
    softMax = nn.Softmax(dim=1)
    CE_loss = nn.CrossEntropyLoss()
    losses = []

    path = os.path.join('./Results/Images/', modelName, str(epoch))
    os.makedirs(path, exist_ok=True)

    # No gradients are needed for evaluation; this saves a lot of memory.
    with torch.no_grad():
        for i, data in enumerate(img_batch):
            printProgressBar(i, total, prefix="[Inference] Getting segmentations...", length=30)
            images, gt_masks, _img_names = data
            images = to_var(images)
            gt_masks = to_var(gt_masks)

            net_predictions = net(images)
            segmentation_classes = getTargetSegmentation(gt_masks)
            losses.append(CE_loss(net_predictions, segmentation_classes).item())

            masks = torch.argmax(softMax(net_predictions), dim=1)
            # unsqueeze(1) restores the channel axis for any image size
            # (previously hard-coded to 256x256).
            torchvision.utils.save_image(
                torch.cat([images, gt_masks, masks.unsqueeze(1) / 3.0]),
                os.path.join(path, str(i) + '.png'),
                padding=0)

    printProgressBar(total, total, done="[Inference] Segmentation Done !")
    return np.asarray(losses).mean()


class MaskToTensor(object):
    """Transform converting a mask into a float tensor.

    The input can be anything ``np.array`` accepts (presumably a PIL mask).
    It is first cast to int32, which truncates fractional values, and then
    converted to float.
    """

    def __call__(self, img):
        return torch.from_numpy(np.array(img, dtype=np.int32)).float()


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialize ``state`` to ``filename`` with ``torch.save``."""
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model):
    """Load ``checkpoint["state_dict"]`` into ``model``."""
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])


def check_accuracy(loader, model, device="cuda"):
    """Print pixel accuracy and mean per-batch Dice of a binary segmenter.

    Each batch ``(x, y)`` from ``loader`` is run through ``model``. Sigmoid
    outputs above 0.5 count as foreground. ``y`` gets a channel axis at dim 1
    so it matches the prediction. The model is put in eval mode for the
    evaluation and switched back to train mode afterwards.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1)
            preds = (torch.sigmoid(model(x)) > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)

    if num_pixels == 0:
        # Avoid ZeroDivisionError on an empty loader.
        print("check_accuracy: loader produced no pixels; nothing to report.")
    else:
        # Convert to Python numbers so the report shows plain values rather
        # than "tensor(..., device='cuda:0')".
        num_correct = int(num_correct)
        print(f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}")
        print(f"Dice score: {float(dice_score)/len(loader)}")
    model.train()


def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Save thresholded predictions and targets for every batch of ``loader``.

    For batch ``idx`` the prediction (sigmoid > 0.5) is written to
    ``<folder>/pred_<idx>.png`` and the target, with a channel axis added at
    dim 1, to ``<folder>/<idx>.png``. ``folder`` is created if missing. The
    model is switched to eval mode and back to train mode afterwards.
    """
    if folder:
        os.makedirs(folder, exist_ok=True)
    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = (torch.sigmoid(model(x)) > 0.5).float()
        torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
        # os.path.join supplies the separator that f"{folder}{idx}.png" omitted
        # when ``folder`` had no trailing slash.
        torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f"{idx}.png"))
    model.train()


# converting tensor to image
def image_convert(image):
    """Convert a 3-D channels-first image tensor into a channels-last NumPy array.

    The tensor is copied to the CPU, transposed with axes ``(1, 2, 0)`` and
    multiplied by 255. The result stays floating-point, not uint8; presumably
    the input is in [0, 1] — confirm.
    """
    image = image.clone().cpu().numpy()
    image = image.transpose((1, 2, 0))
    return image * 255


def mask_convert(mask):
    """Return a detached CPU NumPy copy of ``mask`` with size-1 axes squeezed out."""
    mask = mask.clone().cpu().detach().numpy()
    return np.squeeze(mask)
# When ``model`` is given, plot_img also runs inference on the sampled batch
# and shows the predicted mask next to the ground truth.
def plot_img(loader, no_, model=None):
    """Show the first ``no_`` samples of one batch drawn from ``loader``.

    Each sample gets its own figure with three panels: the image, the
    ground-truth mask, and either an image/mask overlay (``model is None``) or
    the model's predicted mask (argmax over dim 1 of the softmax output).

    Args:
        loader: iterable yielding ``(images, target, name)`` batches; only the
            first batch is used.
        no_: number of samples to plot; capped at the batch size.
        model: optional network run on the batch to produce predictions.
    """
    images, target, name = next(iter(loader))
    # The batch may hold fewer than ``no_`` samples (e.g. a short last batch).
    n_show = min(no_, len(images))

    pred_masks = None
    if model is not None:
        # Run the model once for the whole batch rather than once per plotted
        # sample; gradients are not needed for display.
        softMax = nn.Softmax(dim=1)
        with torch.no_grad():
            yhat = model(to_var(images))
        pred_masks = torch.argmax(softMax(yhat), dim=1)

    for idx in range(n_show):
        plt.figure(figsize=(12, 12))

        # Input image
        image = image_convert(images[idx])
        plt.subplot(1, 3, 1)
        plt.imshow(image)
        plt.title('Original Image')

        # Ground-truth target mask
        mask = mask_convert(target[idx])
        plt.subplot(1, 3, 2)
        plt.imshow(mask)
        plt.title('Original Mask')

        plt.subplot(1, 3, 3)
        if pred_masks is None:
            # Overlay the target mask on the image
            plt.imshow(image)
            plt.imshow(mask, alpha=0.6)
            plt.title('Superposition')
        else:
            plt.imshow(mask_convert(pred_masks[idx]))
            plt.title('Prediction')
        plt.show()


# NOTE(review): the string literal below is a disabled draft of a get_loaders
# helper and is never executed. Before re-enabling it:
#   * it relies on names not imported in this file (A, ToTensorV2, DataLoader,
#     medicalDataLoader, IMAGE_HEIGHT, IMAGE_WIDTH);
#   * ``worker_init_fn=np.random.seed(0)`` calls seed() once when the loader is
#     built and passes its return value (None) as the worker init function;
#   * the NUM_WORKERS and PIN_MEMORY parameters are ignored (num_workers=0 is
#     hard-coded).
"""
def get_loaders(root_dir, batch_size, NUM_WORKERS, PIN_MEMORY, test = False):
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    val_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    ## DUE TO THE CUSTOM LOADING CLASS, WE NEED TO USE TWO STEPS TO LOAD DATA
    train_set_full = medicalDataLoader.MedicalImageDataset('train',
                                                           root_dir,
                                                           transform=train_transform,
                                                           mask_transform=train_transform,
                                                           augment=False,
                                                           equalize=False)

    train_loader_full = DataLoader(train_set_full,
                                   batch_size=batch_size,
                                   worker_init_fn=np.random.seed(0),
                                   num_workers= 0,
                                   shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=val_transform,
                                                    mask_transform=val_transform,
                                                    equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            worker_init_fn=np.random.seed(0),
                            num_workers = 0,
                            shuffle=False)

    if test:
        test_set = medicalDataLoader.MedicalImageDataset('test',
                                                         root_dir,
                                                         transform=None,
                                                         mask_transform=None,
                                                         equalize=False)

        test_loader = DataLoader(test_set,
                                 batch_size=batch_size,
                                 num_workers=0,
                                 shuffle=False)
        return test_loader

    return train_loader_full, val_loader
"""