"""Train a UNET on MedicalImageDataset, validating after every epoch.

Module-level code builds the model, losses, optimizer and the train/val/test
data loaders; running the file as a script trains for EPOCHS epochs and then
calls ``plot_img`` on the test split.
"""
import os

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchmetrics.functional import dice, jaccard_index, accuracy
from segmentation_models_pytorch.losses import DiceLoss, TverskyLoss, FocalLoss

from src.medicalDataLoader import MedicalImageDataset
from src.utils import getTargetSegmentation, plot_img
from UNET_perso import UNET

## Parameters & Hyperparameters ##
EPOCHS = 2
BATCH_SIZE_TRAIN = 8
BATCH_SIZE_VAL = 8
LR = 1e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Number of segmentation classes: the UNET's output channels and the
# num_classes passed to the Dice / IoU metrics must agree.
NUM_CLASSES = 4

torch.cuda.empty_cache()

## Model ##
model = UNET(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)

## Loss ##
lossCE = nn.CrossEntropyLoss().to(DEVICE)
lossDice = DiceLoss(mode='multiclass').to(DEVICE)

## optimizer ##
# torch.optim.Adam(model.parameters(), lr=LR) is a drop-in alternative.
optimizer = torch.optim.NAdam(model.parameters(), lr=LR)

transform = transforms.Compose([transforms.ToTensor()])

ROOT_DIR = './Data'

train_set = MedicalImageDataset('train', ROOT_DIR, transform=transform,
                                mask_transform=transform, augment=True,
                                equalize=False)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE_TRAIN, shuffle=True)

val_set = MedicalImageDataset('val', ROOT_DIR, transform=transform,
                              mask_transform=transform, equalize=False)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE_VAL, shuffle=False)

test_set = MedicalImageDataset('test', ROOT_DIR, transform=transform,
                               mask_transform=transform, equalize=False)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE_VAL, shuffle=False)


def train(dataLoader, model, optimizer, epoch, loss_fn1, loss_fn2=None):
    """Run one training epoch over ``dataLoader``.

    Args:
        dataLoader: iterable of ``(img, labels, name)`` batches; ``labels`` are
            converted with ``getTargetSegmentation`` before the loss is computed.
        model: network to optimize; switched to train mode here.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        epoch: epoch number, used only in the log header.
        loss_fn1: loss called as ``loss_fn1(logits, labels)``.
        loss_fn2: optional second loss; when given the objective is
            ``0.5 * loss_fn1 + 0.5 * loss_fn2``.

    Returns:
        ``(model, mean_loss)`` where ``mean_loss`` is the average of the
        per-batch losses (0.0 for an empty loader).
    """
    print(f'~~~ train for epoch {epoch} ~~~')
    model.train()
    loop = tqdm(dataLoader)
    train_loss = 0.0
    for img, labels, _ in loop:
        labels = getTargetSegmentation(labels)
        img, labels = img.to(DEVICE), labels.to(DEVICE)
        yPred = model(img)
        if loss_fn2 is not None:
            loss = 0.5 * loss_fn1(yPred, labels) + 0.5 * loss_fn2(yPred, labels)
        else:
            loss = loss_fn1(yPred, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        # Show the current batch loss (dividing it by the number of batches,
        # as before, produced a number with no meaning).
        loop.set_postfix(loss=batch_loss)

    # Each loss value is already a per-batch mean, so average over batches
    # (same normalization as test()), not over the number of samples.
    mean_loss = train_loss / max(len(dataLoader), 1)
    print('total train loss : {:.4f}\n'.format(mean_loss))
    return model, mean_loss


def test(dataLoader, model, loss_fn, epoch):
    """Evaluate ``model`` on ``dataLoader`` without updating it.

    Dice and IoU are computed per batch for every class (``average='none'``)
    and averaged over batches; class 0 is not reported.

    Args:
        dataLoader: iterable of ``(img, labels, name)`` batches.
        model: network to evaluate; switched to eval mode here.
        loss_fn: loss called as ``loss_fn(logits, labels)``.
        epoch: epoch number, used only in the log header.

    Returns:
        ``(loss, dice1, dice2, dice3, iou1, iou2, iou3)``, each averaged over
        batches (zeros for an empty loader).
    """
    print(f'~~~ validation for epoch {epoch} ~~~')
    model.eval()
    size = max(len(dataLoader), 1)  # guard against division by zero
    loop = tqdm(dataLoader)
    test_loss = 0.0
    dsc_sum = torch.zeros(NUM_CLASSES)
    iou_sum = torch.zeros(NUM_CLASSES)

    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for img, labels, _ in loop:
            labels = getTargetSegmentation(labels)
            img, labels = img.to(DEVICE), labels.to(DEVICE)
            yPred = model(img)
            loss = loss_fn(yPred, labels)
            batch_loss = loss.item()
            test_loss += batch_loss
            loop.set_postfix(loss=batch_loss)

            dsc_sum += dice(yPred, labels, average='none',
                            num_classes=NUM_CLASSES).cpu()
            iou_sum += jaccard_index(yPred, labels, task='multiclass',
                                     average='none',
                                     num_classes=NUM_CLASSES).cpu()

    mean_loss = test_loss / size
    dsc = dsc_sum / size
    iou = iou_sum / size
    print('total test loss : {:.4f}\nDice score 1 : {:.4f} | Dice score 2 : {:.4f} | Dice score 3 : {:.4f}\nIOU 1 : {:.4f} | IOU 2 : {:.4f} | IOU 3 : {:.4f}\n'.format(
        mean_loss, dsc[1], dsc[2], dsc[3], iou[1], iou[2], iou[3]))
    return mean_loss, dsc[1], dsc[2], dsc[3], iou[1], iou[2], iou[3]


def main(train_loader, test_loader, model, optimizer, loss1, loss2, *,
         return_history=False):
    """Train for EPOCHS epochs, evaluating on ``test_loader`` after each one.

    Training uses ``loss1`` and ``loss2`` (see ``train``); evaluation reports
    ``loss1`` only.

    Args:
        train_loader: loader used for optimization.
        test_loader: loader evaluated after every epoch.
        model: network to train.
        optimizer: optimizer over ``model``'s parameters.
        loss1: primary loss (used for training and evaluation).
        loss2: optional secondary training loss, or ``None``.
        return_history: if True, also return the per-epoch metric lists.

    Returns:
        The trained model, or ``(model, history)`` when ``return_history`` is
        True, where ``history`` maps metric names to per-epoch lists.
    """
    eval_keys = ('test_loss', 'dsc1', 'dsc2', 'dsc3', 'iou1', 'iou2', 'iou3')
    history = {'train_loss': [], **{key: [] for key in eval_keys}}

    for epoch in range(1, EPOCHS + 1):
        model, train_loss = train(train_loader, model, optimizer, epoch,
                                  loss_fn1=loss1, loss_fn2=loss2)
        history['train_loss'].append(train_loss)
        for key, value in zip(eval_keys, test(test_loader, model, loss1, epoch)):
            history[key].append(value)

    if return_history:
        return model, history
    return model


if __name__ == '__main__':
    # Monitor each epoch on the validation split; the test split stays
    # held out for the final plot below.
    model = main(train_loader, val_loader, model, optimizer,
                 loss1=lossCE, loss2=lossDice)
    plot_img(test_loader, 8, model)