Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.autograd import Variable | |
import torchvision | |
import os | |
from os.path import isfile, join | |
from medpy.metric.binary import dc, hd, asd, assd | |
import matplotlib.pyplot as plt | |
from IPython.display import Image, display | |
# Display names for the segmentation classes, keyed by class index.
labels = dict(enumerate(("Background", "Foreground")))
def computeDSC(pred, gt):
    """Return the mean Dice similarity coefficient over a batch.

    For every batch index ``b``, ``pred[b, 1, ...]`` is compared with
    ``gt[b, 0, ...]`` using medpy's binary ``dc`` and the per-item scores
    are averaged.

    Args:
        pred: prediction tensor indexed as ``pred[b, 1, ...]`` (presumably a
            binarised/one-hot prediction with foreground in channel 1 -
            TODO confirm with callers).
        gt: ground-truth tensor indexed as ``gt[b, 0, ...]``.

    Returns:
        The mean of the per-item Dice scores as a numpy scalar (``nan``, with
        a numpy RuntimeWarning, when the batch is empty).
    """
    # ``.detach()`` is the supported replacement for the legacy ``.data``
    # attribute when pulling values out of the autograd graph.
    scores = [
        dc(
            pred[i_b, 1, :].detach().cpu().numpy(),
            gt[i_b, 0, :].detach().cpu().numpy(),
        )
        for i_b in range(pred.shape[0])
    ]
    return np.asarray(scores).mean()
def getImageImageList(imagesFolder):
    """Return the sorted names of the regular files directly inside a folder.

    Sub-directories are skipped. If ``imagesFolder`` is not an existing
    directory (missing path, or a path to a file), an empty list is returned.

    Args:
        imagesFolder: path of the folder to scan.

    Returns:
        A sorted list of file names (not full paths).
    """
    # Without this guard a missing folder used to raise UnboundLocalError
    # (the result was only assigned inside the existence check), and a file
    # path made os.listdir fail.
    if not os.path.isdir(imagesFolder):
        return []
    return sorted(
        f for f in os.listdir(imagesFolder) if isfile(join(imagesFolder, f))
    )
def to_var(x):
    """Move ``x`` to the GPU when CUDA is available, then wrap it in ``Variable``.

    On a CPU-only machine the tensor is wrapped as-is.
    """
    target = x.cuda() if torch.cuda.is_available() else x
    return Variable(target)
def DicesToDice(Dices):
    """Collapse per-sample Dice terms into a single smoothed Dice coefficient.

    ``Dices`` is summed over dim 0. Column 0 of the totals is used as the
    overlap term and column 1 as the denominator term (presumably the
    intersection and |A| + |B| - confirm with whatever produces ``Dices``).
    Both sides get a 1e-8 smoothing constant.
    """
    eps = 1e-8
    totals = Dices.sum(dim=0)
    overlap, size_sum = totals[0], totals[1]
    return (2 * overlap + eps) / (size_sum + eps)
def predToSegmentation(pred):
    """Turn per-channel scores into a one-hot mask along dim 1.

    Every position gets 1.0 in each channel that attains the maximum over
    dim 1 and 0.0 elsewhere (ties yield several ones).

    Args:
        pred: score tensor whose dim 1 is the channel axis.

    Returns:
        A float tensor with the same shape as ``pred``.
    """
    channel_max = pred.max(dim=1, keepdim=True)[0]
    # Compare with the maximum directly. Dividing by it and testing for 1
    # produced NaN wherever the maximum was 0, leaving that position with
    # no class at all.
    return (pred == channel_max).float()
def getTargetSegmentation(batch, denom=0.33333334):
    """Map normalised grey-level label images to integer class indices.

    The input holds one channel of values between 0 and 1. For ACDC these
    are 0, 0.33333334, 0.6666667 and 0.94117647; dividing by ``denom`` and
    rounding maps them to the discrete classes 0, 1, 2 and 3.

    Args:
        batch: label tensor, expected to be shaped (B, 1, H, W).
        denom: grey-level spacing between consecutive classes (the default
            is the ACDC value).

    Returns:
        A ``long`` tensor of class indices. For a (B, 1, H, W) input only the
        channel axis is removed, so the result is (B, H, W) even when B == 1.
        Any other shape has all singleton axes squeezed, as before.
    """
    classes = (batch / denom).round().long()
    # A bare squeeze() also drops the batch axis when B == 1, so the target
    # no longer matches (B, C, H, W) logits in CrossEntropyLoss.
    if classes.dim() == 4 and classes.shape[1] == 1:
        return classes.squeeze(1)
    return classes.squeeze()
from scipy import ndimage | |
def inference(net, img_batch, modelName, epoch):
    """Evaluate ``net`` on a loader, save visualisations and return the mean CE loss.

    For every batch, three things are concatenated and written to
    ``./Results/Images/<modelName>/<epoch>/<batch index>.png``:
    - the input images;
    - the raw label images;
    - the predicted class masks, divided by 3.0 (this assumes 4 classes,
      0..3, so they land in [0, 1]).

    Args:
        net: segmentation network; its output is expected to be per-class
            logits shaped (B, C, H, W).
        img_batch: sized iterable of ``(images, labels, img_names)`` batches.
        modelName: sub-folder name under ``./Results/Images/``.
        epoch: converted with ``str`` and used as the next sub-folder name.

    Returns:
        The mean cross-entropy loss over all batches (``nan`` if
        ``img_batch`` is empty).

    Note:
        ``net`` is switched to eval mode and left there.
    """
    total = len(img_batch)
    net.eval()
    # dim=1 is the class axis of (B, C, H, W) logits. It is also what the
    # deprecated implicit-dim rule chose for 4-D input.
    softMax = nn.Softmax(dim=1)
    CE_loss = nn.CrossEntropyLoss()
    losses = []

    path = os.path.join('./Results/Images/', modelName, str(epoch))
    os.makedirs(path, exist_ok=True)

    # Evaluation needs no gradients. Without this, each forward pass builds
    # an autograd graph that is never used.
    with torch.no_grad():
        for i, data in enumerate(img_batch):
            printProgressBar(i, total, prefix="[Inference] Getting segmentations...", length=30)
            images, labels, img_names = data
            images = to_var(images)
            labels = to_var(labels)

            net_predictions = net(images)
            segmentation_classes = getTargetSegmentation(labels)
            CE_loss_value = CE_loss(net_predictions, segmentation_classes)
            losses.append(CE_loss_value.item())

            pred_y = softMax(net_predictions)
            masks = torch.argmax(pred_y, dim=1)

            # unsqueeze(1) restores the channel axis for any H x W, instead of
            # the previous hard-coded view(B, 1, 256, 256).
            torchvision.utils.save_image(
                torch.cat([images, labels, masks.unsqueeze(1) / 3.0]),
                os.path.join(path, str(i) + '.png'),
                padding=0,
            )

    printProgressBar(total, total, done="[Inference] Segmentation Done !")
    return np.asarray(losses).mean()
class MaskToTensor:
    """Callable transform that converts a label image into a float32 tensor.

    The input is first materialised as an int32 numpy array and only then
    cast to float32.
    """

    def __call__(self, img):
        int_labels = np.array(img, dtype=np.int32)
        return torch.from_numpy(int_labels).to(torch.float32)
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce on stdout, then serialise ``state`` to ``filename`` with ``torch.save``."""
    announcement = "=> Saving checkpoint"
    print(announcement)
    torch.save(state, filename)
def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    A message is printed before the lookup, so it appears even if the key is
    missing.
    """
    print("=> Loading checkpoint")
    state_dict = checkpoint["state_dict"]
    model.load_state_dict(state_dict)
def check_accuracy(loader, model, device="cuda"):
    """Print pixel accuracy and mean per-batch Dice of a binary segmentation model.

    Predictions are ``sigmoid(model(x)) > 0.5``. Targets ``y`` get a channel
    axis inserted at dim 1 before comparison.

    The model is put in eval mode for the evaluation. It is always switched
    back to train mode afterwards, even if evaluation raises.

    Args:
        loader: sized iterable of ``(x, y)`` batches.
        model: network producing one logit per pixel.
        device: device the batches are moved to.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0.0
    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)
                preds = (torch.sigmoid(model(x)) > 0.5).float()
                # .item() keeps the counters as plain Python numbers, so the
                # report does not print "tensor(..., device=...)".
                num_correct += int((preds == y).sum().item())
                num_pixels += torch.numel(preds)
                dice_score += (
                    (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)
                ).item()
    finally:
        model.train()

    if num_pixels == 0:
        # An empty loader previously raised ZeroDivisionError here.
        print("Got 0/0 with acc n/a (loader produced no samples)")
        return
    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(loader)}")
def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Save thresholded predictions and the matching targets as PNG files.

    For batch ``idx``:
    - the binarised predictions (``sigmoid > 0.5``) go to
      ``<folder>/pred_<idx>.png``;
    - the targets, with a channel axis inserted at dim 1, go to
      ``<folder>/<idx>.png``.

    ``folder`` is created if it does not exist. The model is put in eval
    mode for the forward passes and back into train mode afterwards, even
    on error.

    Args:
        loader: iterable of ``(x, y)`` batches.
        model: network producing one logit per pixel.
        folder: output directory.
        device: device the inputs are moved to.
    """
    if folder:
        os.makedirs(folder, exist_ok=True)
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = (torch.sigmoid(model(x)) > 0.5).float()
            # os.path.join works with or without a trailing separator.
            # f"{folder}{idx}.png" wrote next to, not inside, a folder given
            # without one.
            torchvision.utils.save_image(preds, os.path.join(folder, f"pred_{idx}.png"))
            torchvision.utils.save_image(y.unsqueeze(1), os.path.join(folder, f"{idx}.png"))
    finally:
        model.train()
# Convert an image tensor into a numpy array for plotting.
def image_convert(image):
    """Convert a (C, H, W) image tensor into an (H, W, C) numpy array scaled by 255.

    The tensor is detached first, like ``mask_convert`` does, so tensors
    that require grad are accepted. The multiplication allocates a fresh
    array, so the result never shares memory with ``image``.

    Args:
        image: 3-D tensor in channel-first layout, on any device.

    Returns:
        A numpy array in channel-last layout, multiplied by 255 (not cast to
        an integer type).
    """
    array = image.detach().cpu().numpy()
    return array.transpose((1, 2, 0)) * 255
def mask_convert(mask):
    """Return ``mask`` as a numpy array with every singleton axis removed.

    The tensor is detached, moved to the CPU and copied, so the returned
    array does not share memory with ``mask``.
    """
    host_copy = mask.detach().cpu().clone()
    return np.squeeze(host_copy.numpy())
# Plot samples from the first batch of `loader`. If a model is given, also run
# inference on that batch and show the predicted masks.
def plot_img(loader, no_, model=None):
    """Plot image / ground-truth / (overlay or prediction) triplets.

    Args:
        loader: data loader yielding ``(images, target, names)`` batches.
            Only the first batch is used.
        no_: number of samples to plot, capped at the batch size.
        model: optional network. When given, the third panel shows
            ``argmax(softmax(model(images)), dim=1)`` for each sample.
            Otherwise it shows the mask overlaid on the image. Assumes the
            model output is (B, C, H, W) - TODO confirm. The model's
            train/eval mode is not changed.
    """
    images, target, _names = next(iter(loader))
    # Asking for more samples than the batch holds would raise IndexError.
    count = min(no_, len(images))

    masks = None
    if model is not None:
        # One forward pass for the whole batch, instead of one per plotted
        # sample. No gradients are needed for display.
        softMax = nn.Softmax(dim=1)
        with torch.no_grad():
            yhat = model(to_var(images))
            masks = torch.argmax(softMax(yhat), dim=1)

    for idx in range(count):
        plt.figure(figsize=(12, 12))
        image = image_convert(images[idx])
        mask = mask_convert(target[idx])

        # Input image
        plt.subplot(1, 3, 1)
        plt.imshow(image)
        plt.title('Original Image')

        # Ground-truth mask
        plt.subplot(1, 3, 2)
        plt.imshow(mask)
        plt.title('Original Mask')

        # Overlay of the ground truth, or the model's prediction
        plt.subplot(1, 3, 3)
        if masks is None:
            plt.imshow(image)
            plt.imshow(mask, alpha=0.6)
            plt.title('Superposition')
        else:
            plt.imshow(mask_convert(masks[idx]))
            plt.title('Prediction')
        plt.show()
""" | |
def get_loaders(root_dir, batch_size, NUM_WORKERS, PIN_MEMORY, test = False): | |
train_transform = A.Compose( | |
[ | |
A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH), | |
A.Rotate(limit=35, p=1.0), | |
A.HorizontalFlip(p=0.5), | |
A.VerticalFlip(p=0.1), | |
A.Normalize( | |
mean=[0.0, 0.0, 0.0], | |
std=[1.0, 1.0, 1.0], | |
max_pixel_value=255.0, | |
), | |
ToTensorV2(), | |
], | |
) | |
val_transform = A.Compose( | |
[ | |
A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH), | |
A.Normalize( | |
mean=[0.0, 0.0, 0.0], | |
std=[1.0, 1.0, 1.0], | |
max_pixel_value=255.0, | |
), | |
ToTensorV2(), | |
], | |
) | |
## DUE TO THE CUSTOM LOADING CLASS, HE NEED TO USE TO STEP TO LOAD DATA | |
train_set_full = medicalDataLoader.MedicalImageDataset('train', | |
root_dir, | |
transform=train_transform, | |
mask_transform=train_transform, | |
augment=False, | |
equalize=False) | |
train_loader_full = DataLoader(train_set_full, | |
batch_size=batch_size, | |
worker_init_fn=np.random.seed(0), | |
num_workers= 0, | |
shuffle=True) | |
val_set = medicalDataLoader.MedicalImageDataset('val', | |
root_dir, | |
transform=val_transform, | |
mask_transform=val_transform, | |
equalize=False) | |
val_loader = DataLoader(val_set, | |
batch_size=batch_size, | |
worker_init_fn=np.random.seed(0), | |
num_workers = 0, | |
shuffle=False) | |
if test: | |
test_set = medicalDataLoader.MedicalImageDataset('test', | |
root_dir, | |
transform=None, | |
mask_transform=None, | |
equalize=False) | |
test_loader = DataLoader(test_set, | |
batch_size=batch_size, | |
num_workers=0, | |
shuffle=False) | |
return test_loader | |
return train_loader_full, val_loader""" |