thov's picture
minor changement
7e2ae4e
raw
history blame
No virus
8.72 kB
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
import os
from os.path import isfile, join
from medpy.metric.binary import dc, hd, asd, assd
import matplotlib.pyplot as plt
from IPython.display import Image, display
# Display names keyed by integer class id.
labels = {
    0: 'Background',
    1: 'Foreground',
}
def computeDSC(pred, gt):
    """Return the mean Dice coefficient over the samples of a batch.

    For each index along the first axis, channel 1 of ``pred`` is compared
    with channel 0 of ``gt`` using ``dc`` (imported from
    ``medpy.metric.binary``).

    Args:
        pred: Prediction tensor indexed as ``pred[sample, channel, ...]``.
        gt: Ground-truth tensor indexed as ``gt[sample, channel, ...]``.

    Returns:
        The mean of the per-sample Dice values.
    """
    per_sample = [
        dc(pred[b, 1, :].cpu().data.numpy(), gt[b, 0, :].cpu().data.numpy())
        for b in range(pred.shape[0])
    ]
    return np.asarray(per_sample).mean()
def getImageImageList(imagesFolder):
    """Return the sorted names of the regular files found in ``imagesFolder``.

    Sub-directories are skipped. When ``imagesFolder`` does not exist an
    empty list is returned instead of failing on an unbound local variable.

    Args:
        imagesFolder: Path of the directory to list.

    Returns:
        A list of file names (not full paths) in ascending order.
    """
    if not os.path.exists(imagesFolder):
        return []
    return sorted(
        f for f in os.listdir(imagesFolder) if isfile(join(imagesFolder, f))
    )
def to_var(x):
    """Wrap ``x`` in a ``Variable``, moving it to the GPU when CUDA is available."""
    tensor = x.cuda() if torch.cuda.is_available() else x
    return Variable(tensor)
def DicesToDice(Dices):
    """Collapse per-row Dice components into a single ratio.

    The rows of ``Dices`` are summed; the result is
    ``(2 * total[0] + 1e-8) / (total[1] + 1e-8)``.
    NOTE(review): column 0 is presumably the intersection and column 1 the
    summed cardinalities - confirm against the caller.
    """
    eps = 1e-8
    totals = Dices.sum(dim=0)
    numerator = 2 * totals[0] + eps
    denominator = totals[1] + eps
    return numerator / denominator
def predToSegmentation(pred):
    """Turn per-class scores into a float mask marking the maximum along dim 1.

    Every position whose value equals the maximum over dim 1 is set to 1.0
    (ties mark all tied classes); everything else is 0.0.

    The maximum is compared directly rather than dividing by it: dividing
    produced NaN when the maximum was exactly 0, so no class was selected
    for those positions.

    Args:
        pred: Tensor with the class axis at dim 1.

    Returns:
        Float tensor of the same shape as ``pred`` containing 0.0 / 1.0.
    """
    channel_max = pred.max(dim=1, keepdim=True)[0]
    return (pred == channel_max).float()
def getTargetSegmentation(batch, denom=0.33333334):
    """Convert normalised label intensities into integer class indices.

    The input is a 1-channel tensor of values between 0 and 1; with the
    default ``denom`` the values 0, 0.33333334, 0.6666667 and 0.94117647
    map to the classes 0, 1, 2 and 3.

    Only the channel axis (dim 1) is squeezed. A bare ``squeeze()`` also
    removed the batch axis when the batch held a single sample, which breaks
    losses that expect a leading batch dimension.

    Args:
        batch: Label tensor, e.g. shaped ``(batch, 1, H, W)``.
        denom: Intensity step between two consecutive classes
            (0.33333334 is the value used for ACDC).

    Returns:
        ``torch.long`` tensor of class indices with the size-1 channel axis
        removed.
    """
    classes = (batch / denom).round().long()
    if classes.dim() > 1 and classes.size(1) == 1:
        classes = classes.squeeze(1)
    return classes
from scipy import ndimage
def inference(net, img_batch, modelName, epoch):
    """Evaluate ``net`` on ``img_batch`` and save one image per batch.

    For every batch the cross-entropy loss against the discretised labels is
    recorded, and the inputs, the labels and the predicted masks are written
    side by side to ``./Results/Images/<modelName>/<epoch>/<batch index>.png``.

    Args:
        net: Segmentation network. It is switched to eval mode and left there.
        img_batch: Sized iterable yielding ``(images, labels, img_names)``.
        modelName: Name of the sub-folder that receives the images.
        epoch: Epoch identifier, used as the innermost folder name.

    Returns:
        The mean of the per-batch cross-entropy losses.
    """
    total = len(img_batch)
    net.eval()
    # The class axis is given explicitly; relying on the implicit softmax
    # dimension is deprecated. Neither module has parameters, so no device
    # move is needed.
    softMax = nn.Softmax(dim=1)
    CE_loss = nn.CrossEntropyLoss()
    losses = []

    # The output folder does not depend on the batch: create it once.
    path = os.path.join('./Results/Images/', modelName, str(epoch))
    os.makedirs(path, exist_ok=True)

    # Pure evaluation: skip building the autograd graph to save memory.
    with torch.no_grad():
        for i, data in enumerate(img_batch):
            printProgressBar(i, total, prefix="[Inference] Getting segmentations...", length=30)
            images, targets, img_names = data
            images = to_var(images)
            targets = to_var(targets)

            net_predictions = net(images)
            segmentation_classes = getTargetSegmentation(targets)
            CE_loss_value = CE_loss(net_predictions, segmentation_classes)
            losses.append(CE_loss_value.detach().cpu().numpy())

            masks = torch.argmax(softMax(net_predictions), dim=1)
            # Give the masks the labels' spatial size instead of a
            # hard-coded 256x256 so other resolutions work too.
            masks = masks.view(targets.shape[0], 1, *targets.shape[-2:])

            # Dividing by 3 maps class indices 0..3 into [0, 1] for saving
            # (assumes at most 4 classes - TODO confirm).
            torchvision.utils.save_image(
                torch.cat([images, targets, masks / 3.0]),
                os.path.join(path, str(i) + '.png'), padding=0)

    printProgressBar(total, total, done="[Inference] Segmentation Done !")
    return np.asarray(losses).mean()
class MaskToTensor(object):
    """Callable transform that converts a mask into a float tensor."""

    def __call__(self, img):
        """Return ``img`` as a float tensor after an int32 NumPy conversion."""
        int_array = np.array(img, dtype=np.int32)
        as_tensor = torch.from_numpy(int_array)
        return as_tensor.float()
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then serialise ``state`` to ``filename``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)
def load_checkpoint(checkpoint, model):
    """Announce the load on stdout, then restore ``model`` from ``checkpoint["state_dict"]``."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)
def check_accuracy(loader, model, device="cuda"):
    """Print the pixel accuracy and the mean Dice score of ``model`` on ``loader``.

    The model output is passed through a sigmoid and thresholded at 0.5
    before being compared with the targets. The model is put in eval mode
    for the measurement and switched to train mode before returning.

    Args:
        loader: Sized iterable yielding ``(inputs, targets)`` pairs.
        model: Network to evaluate.
        device: Device the inputs and targets are moved to.
    """
    correct_total, pixel_total, dice_total = 0, 0, 0
    model.eval()

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            # Add a channel axis so the targets line up with the predictions.
            targets = targets.to(device).unsqueeze(1)
            binary = torch.sigmoid(model(inputs)).gt(0.5).float()

            correct_total += (binary == targets).sum()
            pixel_total += torch.numel(binary)

            overlap = (binary * targets).sum()
            # The epsilon keeps the ratio defined when both masks are empty.
            dice_total += (2 * overlap) / ((binary + targets).sum() + 1e-8)

    accuracy = correct_total / pixel_total * 100
    print(f"Got {correct_total}/{pixel_total} with acc {accuracy:.2f}")
    print(f"Dice score: {dice_total/len(loader)}")
    model.train()
def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Save the thresholded predictions and the targets of every batch as PNGs.

    For batch ``idx`` the prediction goes to ``<folder>/pred_<idx>.png`` and
    the target to ``<folder>/<idx>.png``. Both paths are built with
    ``os.path.join`` so the files land inside ``folder`` whether or not it
    ends with a separator; the folder is created when it is missing.

    Args:
        loader: Iterable yielding ``(inputs, targets)`` pairs.
        model: Network producing the logits; it is put in eval mode while
            saving and switched to train mode before returning.
        folder: Destination directory.
        device: Device the inputs are moved to.
    """
    if folder:
        os.makedirs(folder, exist_ok=True)

    model.eval()
    for idx, (x, y) in enumerate(loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, os.path.join(folder, f"pred_{idx}.png")
        )
        # The targets get a channel axis so they are saved as image batches.
        torchvision.utils.save_image(
            y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
        )

    model.train()
# Convert an image tensor to a NumPy array
def image_convert(image):
    """Return ``image`` as a channel-last NumPy array scaled by 255.

    The tensor is detached first so tensors that require grad can be
    converted too, as ``mask_convert`` already does. The multiplication
    allocates a new array, so the input tensor is left untouched.

    Args:
        image: 3-D tensor; its axes are reordered as ``(1, 2, 0)``.

    Returns:
        NumPy array holding the reordered values multiplied by 255.
    """
    array = image.detach().cpu().numpy()
    return array.transpose((1, 2, 0)) * 255
def mask_convert(mask):
    """Return a detached NumPy copy of ``mask`` with all size-1 axes removed."""
    as_array = mask.clone().cpu().detach().numpy()
    return as_array.squeeze()
# If a model is given, run it on one batch from the loader and plot its predictions
def plot_img(loader, no_, model=None):
    """Plot up to ``no_`` samples taken from the first batch of ``loader``.

    Each figure shows the image, its ground-truth mask and, in the third
    panel, either the mask overlaid on the image (``model is None``) or the
    mask predicted by ``model``.

    Args:
        loader: Iterable whose batches unpack as ``(images, target, name)``.
        no_: Number of samples to plot; capped at the batch size.
        model: Optional network used to predict the masks.
    """
    images, target, _ = next(iter(loader))

    predicted = None
    if model is not None:
        # The forward pass covers the whole batch and does not depend on the
        # sample index, so it runs once, without an autograd graph.
        with torch.no_grad():
            scores = nn.Softmax(dim=1)(model(to_var(images)))
        predicted = torch.argmax(scores, dim=1)

    # Never index past the end of the batch.
    for idx in range(min(no_, len(images))):
        plt.figure(figsize=(12, 12))

        # Image
        image = image_convert(images[idx])
        plt.subplot(1, 3, 1)
        plt.imshow(image)
        plt.title('Original Image')

        # Ground-truth target mask
        mask = mask_convert(target[idx])
        plt.subplot(1, 3, 2)
        plt.imshow(mask)
        plt.title('Original Mask')

        plt.subplot(1, 3, 3)
        if predicted is None:
            # Overlay the target mask on the image
            plt.imshow(image)
            plt.imshow(mask, alpha=0.6)
            plt.title('Superposition')
        else:
            # Mask predicted for this sample
            plt.imshow(mask_convert(predicted[idx]))
            plt.title('Prediction')

        plt.show()
"""
def get_loaders(root_dir, batch_size, NUM_WORKERS, PIN_MEMORY, test = False):
train_transform = A.Compose(
[
A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
A.Rotate(limit=35, p=1.0),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.1),
A.Normalize(
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
max_pixel_value=255.0,
),
ToTensorV2(),
],
)
val_transform = A.Compose(
[
A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
A.Normalize(
mean=[0.0, 0.0, 0.0],
std=[1.0, 1.0, 1.0],
max_pixel_value=255.0,
),
ToTensorV2(),
],
)
## DUE TO THE CUSTOM LOADING CLASS, WE NEED TWO STEPS TO LOAD THE DATA
train_set_full = medicalDataLoader.MedicalImageDataset('train',
root_dir,
transform=train_transform,
mask_transform=train_transform,
augment=False,
equalize=False)
train_loader_full = DataLoader(train_set_full,
batch_size=batch_size,
worker_init_fn=np.random.seed(0),
num_workers= 0,
shuffle=True)
val_set = medicalDataLoader.MedicalImageDataset('val',
root_dir,
transform=val_transform,
mask_transform=val_transform,
equalize=False)
val_loader = DataLoader(val_set,
batch_size=batch_size,
worker_init_fn=np.random.seed(0),
num_workers = 0,
shuffle=False)
if test:
test_set = medicalDataLoader.MedicalImageDataset('test',
root_dir,
transform=None,
mask_transform=None,
equalize=False)
test_loader = DataLoader(test_set,
batch_size=batch_size,
num_workers=0,
shuffle=False)
return test_loader
return train_loader_full, val_loader"""