Spaces:
Sleeping
Sleeping
import os | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
from torchmetrics.functional import dice, jaccard_index, accuracy | |
from segmentation_models_pytorch.losses import DiceLoss, TverskyLoss, FocalLoss | |
from src.medicalDataLoader import MedicalImageDataset | |
from src.utils import getTargetSegmentation, plot_img | |
from UNET_perso import UNET | |
## Parameters & Hyperparameters ##
EPOCHS = 2  # number of passes over the training loader
BATCH_SIZE_TRAIN = 8
BATCH_SIZE_VAL = 8  # shared by the validation and test loaders
LR = 1e-3
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
torch.cuda.empty_cache()

## Model ##
# 1 input channel; 4 output channels, matching num_classes=4 used by the metrics.
model = UNET(in_channels=1, out_channels=4).to(DEVICE)

## Losses ##
lossCE = nn.CrossEntropyLoss().to(DEVICE)
lossDice = DiceLoss(mode='multiclass').to(DEVICE)

## Optimizer ##
optimizer = torch.optim.NAdam(model.parameters(), lr=LR)

## Data ##
transform = transforms.Compose([transforms.ToTensor()])
ROOT_DIR = './Data'

# Every split applies the same ToTensor pipeline to images and masks, with
# equalization off; only the training split enables augmentation.
_split_kwargs = dict(transform=transform, mask_transform=transform, equalize=False)

train_set = MedicalImageDataset('train', ROOT_DIR, augment=True, **_split_kwargs)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE_TRAIN, shuffle=True)

val_set = MedicalImageDataset('val', ROOT_DIR, **_split_kwargs)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE_VAL, shuffle=False)

test_set = MedicalImageDataset('test', ROOT_DIR, **_split_kwargs)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE_VAL, shuffle=False)
def train(dataLoader, model, optimizer, epoch, loss_fn1, loss_fn2=None):
    """Run one training epoch and return the model with its mean batch loss.

    Args:
        dataLoader: iterable of ``(img, labels, name)`` batches; ``labels`` are
            passed through ``getTargetSegmentation`` before use.
        model: network to optimise; switched to training mode here.
        optimizer: optimiser stepping ``model``'s parameters.
        epoch: epoch number, used only for logging.
        loss_fn1: primary loss, called as ``loss_fn1(yPred, labels)``.
        loss_fn2: optional second loss; when given, the objective is the
            equal-weight blend ``0.5 * loss_fn1 + 0.5 * loss_fn2``.

    Returns:
        ``(model, mean_loss)``: ``mean_loss`` is the average of the per-batch
        losses over the epoch, or 0.0 if the loader yields no batches.
    """
    print(f'~~~ train for epoch {epoch} ~~~')
    model.train()
    num_batches = len(dataLoader)
    loop = tqdm(dataLoader)
    train_loss = 0.0
    for img, labels, _name in loop:
        labels = getTargetSegmentation(labels)
        img, labels = img.to(DEVICE), labels.to(DEVICE)
        yPred = model(img)
        if loss_fn2 is not None:
            loss = 0.5 * loss_fn1(yPred, labels) + 0.5 * loss_fn2(yPred, labels)
        else:
            loss = loss_fn1(yPred, labels)
        batch_loss = loss.item()
        train_loss += batch_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Show the current batch loss itself, not a value scaled by the batch count.
        loop.set_postfix(loss=batch_loss)
    # Each batch loss is already a mean over its samples (CrossEntropyLoss defaults to
    # reduction='mean'), so average over batches, as `test` does. Dividing by the number
    # of samples would under-report by roughly the batch size.
    mean_loss = train_loss / num_batches if num_batches else 0.0
    print('total train loss : {:.4f}\n'.format(mean_loss))
    return model, mean_loss
def test(dataLoader, model, loss_fn, epoch):
    """Evaluate ``model`` on ``dataLoader`` and report loss, Dice and IoU.

    Args:
        dataLoader: iterable of ``(img, labels, name)`` batches; ``labels`` are
            passed through ``getTargetSegmentation`` before use.
        model: network to evaluate; switched to eval mode here.
        loss_fn: loss called as ``loss_fn(yPred, labels)``.
        epoch: epoch number, used only for logging.

    Returns:
        ``(loss, dice1, dice2, dice3, iou1, iou2, iou3)`` as floats. Each value
        is the mean over batches of the per-batch value. Class 0 is scored but
        not reported (presumably background; TODO confirm). All values are 0.0
        if the loader yields no batches.
    """
    print(f'~~~ validation for epoch {epoch} ~~~')
    model.eval()
    size = len(dataLoader)
    loop = tqdm(dataLoader)
    test_loss = 0.0
    # Running per-class sums over batches (4 classes, matching num_classes below).
    dsc_sum = torch.zeros(4)
    iou_sum = torch.zeros(4)
    # Evaluation never calls backward(), so skip building autograd graphs.
    with torch.no_grad():
        for img, labels, _name in loop:
            labels = getTargetSegmentation(labels)
            img, labels = img.to(DEVICE), labels.to(DEVICE)
            yPred = model(img)
            batch_loss = loss_fn(yPred, labels).item()
            test_loss += batch_loss
            loop.set_postfix(loss=batch_loss)
            # Hard class predictions. yPred is assumed to have classes on dim 1,
            # as the cross-entropy loss used here requires. Passing indices keeps
            # the metrics independent of how the library treats raw logits.
            pred_cls = yPred.argmax(dim=1)
            dsc_sum += dice(pred_cls, labels, average='none', num_classes=4).cpu()
            iou_sum += jaccard_index(pred_cls, labels, task='multiclass', average='none', num_classes=4).cpu()
    if size == 0:
        results = (0.0,) * 7
    else:
        results = (
            test_loss / size,
            *(float(v) / size for v in dsc_sum[1:4]),
            *(float(v) / size for v in iou_sum[1:4]),
        )
    print('total test loss : {:.4f}\nDice score 1 : {:.4f} | Dice score 2 : {:.4f} | Dice score 3 : {:.4f}\nIOU 1 : {:.4f} | IOU 2 : {:.4f} | IOU 3 : {:.4f}\n'.format(*results))
    return results
def main(train_loader, test_loader, model, optimizer, loss1, loss2, history=None):
    """Train for ``EPOCHS`` epochs, evaluating on ``test_loader`` after each one.

    Args:
        train_loader: loader passed to ``train`` each epoch.
        test_loader: loader passed to ``test`` after each epoch.
        model: network being trained.
        optimizer: optimiser for ``model``.
        loss1: primary loss, used in training and as the evaluation loss.
        loss2: optional secondary loss, blended into training only.
        history: optional dict. When given, it is filled with per-epoch lists
            under the keys 'train_loss', 'test_loss', 'dice1', 'dice2', 'dice3',
            'iou1', 'iou2' and 'iou3'. Previously these values were collected
            and then discarded.

    Returns:
        The trained model.
    """
    # Note: evaluation uses loss1 alone, so its loss is not directly comparable
    # to the blended training objective when loss2 is set.
    eval_keys = ('test_loss', 'dice1', 'dice2', 'dice3', 'iou1', 'iou2', 'iou3')
    records = {key: [] for key in ('train_loss',) + eval_keys}
    for epoch in range(1, EPOCHS + 1):
        model, train_loss = train(train_loader, model, optimizer, epoch, loss_fn1=loss1, loss_fn2=loss2)
        eval_results = test(test_loader, model, loss1, epoch)
        records['train_loss'].append(train_loss)
        for key, value in zip(eval_keys, eval_results):
            records[key].append(value)
    if history is not None:
        history.update(records)
    return model
if __name__ == '__main__':
    # Monitor on the validation split each epoch; the test split is kept for the
    # final visual check. Bind the returned model (previously misspelled `mdoel`).
    model = main(train_loader, val_loader, model, optimizer, loss1=lossCE, loss2=lossDice)
    plot_img(test_loader, 8, model)