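# Source: Hugging Face file view — commit e6f4cd4 "add training" by thov (5.27 kB).
# (Page UI residue from the original paste: "thov's picture", "raw history blame", "No virus".)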
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchmetrics.functional import dice, jaccard_index, accuracy
from segmentation_models_pytorch.losses import DiceLoss, TverskyLoss, FocalLoss
from src.medicalDataLoader import MedicalImageDataset
from src.utils import getTargetSegmentation, plot_img
from UNET_perso import UNET
# ---------------------------------------------------------------------------
# Parameters & hyperparameters
# ---------------------------------------------------------------------------
EPOCHS = 2
BATCH_SIZE_TRAIN = 8
BATCH_SIZE_VAL = 8
LR = 1e-3

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
torch.cuda.empty_cache()

# ---------------------------------------------------------------------------
# Model: single-channel input, 4 output channels (matches num_classes=4 used
# by the evaluation metrics).
# ---------------------------------------------------------------------------
model = UNET(in_channels=1, out_channels=4).to(DEVICE)

# ---------------------------------------------------------------------------
# Losses: cross-entropy and multiclass Dice.
# ---------------------------------------------------------------------------
lossCE = nn.CrossEntropyLoss().to(DEVICE)
lossDice = DiceLoss(mode='multiclass').to(DEVICE)

# ---------------------------------------------------------------------------
# Optimizer
# ---------------------------------------------------------------------------
optimizer = torch.optim.NAdam(model.parameters(), lr=LR)
# Images and masks share the same tensor conversion.
transform = transforms.Compose([transforms.ToTensor()])
ROOT_DIR = './Data'


def _build_split(mode, batch_size, shuffle, **extra):
    """Build the MedicalImageDataset for split *mode* and a DataLoader over it.

    Every split uses the shared ``transform`` for both images and masks and
    disables equalization; ``extra`` forwards split-specific options (only the
    training split passes ``augment=True``).
    """
    dataset = MedicalImageDataset(mode,
                                  ROOT_DIR,
                                  transform=transform,
                                  mask_transform=transform,
                                  equalize=False,
                                  **extra)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataset, loader


# Only the training data is shuffled.
train_set, train_loader = _build_split('train', BATCH_SIZE_TRAIN, True, augment=True)
val_set, val_loader = _build_split('val', BATCH_SIZE_VAL, False)
test_set, test_loader = _build_split('test', BATCH_SIZE_VAL, False)
def train(dataLoader, model, optimizer, epoch, loss_fn1, loss_fn2=None):
    """Run one training epoch over ``dataLoader``.

    Args:
        dataLoader: iterable of ``(img, labels, name)`` batches; ``labels`` are
            converted with ``getTargetSegmentation`` before the loss.
        model: network to train; switched to train mode and updated in place.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: epoch number, used only for logging.
        loss_fn1: primary loss, called as ``loss_fn1(prediction, target)``.
        loss_fn2: optional second loss; when given, the two losses are
            combined with equal 0.5 weights.

    Returns:
        ``(model, mean_loss)`` where ``mean_loss`` is the average of the
        per-batch losses over the epoch (0.0 if the loader yields nothing).
    """
    print(f'~~~ train for epoch {epoch} ~~~')
    model.train()
    loop = tqdm(dataLoader)
    train_loss = 0.0
    n_batches = 0
    for img, labels, _name in loop:
        labels = getTargetSegmentation(labels)
        img, labels = img.to(DEVICE), labels.to(DEVICE)
        yPred = model(img)
        if loss_fn2 is not None:
            loss = 0.5 * loss_fn1(yPred, labels) + 0.5 * loss_fn2(yPred, labels)
        else:
            loss = loss_fn1(yPred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        n_batches += 1
        # Show the current batch loss itself (not divided by the batch count).
        loop.set_postfix(loss=batch_loss)

    # Each summed value is one loss per batch (a per-batch mean, assuming the
    # losses use mean reduction like the default CrossEntropyLoss), so average
    # over batches — as test() does — not over the number of samples.
    mean_loss = train_loss / n_batches if n_batches else 0.0
    print('total train loss : {:.4f}\n'.format(mean_loss))
    return model, mean_loss
def test(dataLoader, model, loss_fn, epoch):
    """Evaluate ``model`` on ``dataLoader``: loss, per-class Dice and IoU.

    Dice and IoU are computed with ``num_classes=4`` and ``average='none'``.
    Only classes 1-3 are reported; class 0 is skipped (presumably background
    — TODO confirm). All values are averaged over batches.

    Args:
        dataLoader: loader yielding ``(img, labels, name)`` batches; ``labels``
            are converted with ``getTargetSegmentation``.
        model: network to evaluate; switched to eval mode.
        loss_fn: loss reported as the test loss, ``loss_fn(prediction, target)``.
        epoch: epoch number, used only for logging.

    Returns:
        ``(loss, dice1, dice2, dice3, iou1, iou2, iou3)``, each the mean over
        batches (zeros when the loader is empty).
    """
    print(f'~~~ validation for epoch {epoch} ~~~')
    model.eval()
    size = len(dataLoader)
    loop = tqdm(dataLoader)
    test_loss = 0
    Dsc1, Dsc2, Dsc3 = 0, 0, 0
    IOU1, IOU2, IOU3 = 0, 0, 0
    # Inference only: without no_grad every batch would build an autograd
    # graph, wasting memory and time.
    with torch.no_grad():
        for img, labels, _name in loop:
            labels = getTargetSegmentation(labels)
            img, labels = img.to(DEVICE), labels.to(DEVICE)
            yPred = model(img)
            batch_loss = loss_fn(yPred, labels).item()
            test_loss += batch_loss
            loop.set_postfix(loss=batch_loss)
            Dsc = dice(yPred, labels, average='none', num_classes=4).cpu()
            IOU = jaccard_index(yPred, labels, task='multiclass',
                                average='none', num_classes=4).cpu()
            Dsc1 += Dsc[1]
            Dsc2 += Dsc[2]
            Dsc3 += Dsc[3]
            IOU1 += IOU[1]
            IOU2 += IOU[2]
            IOU3 += IOU[3]

    # Avoid ZeroDivisionError on an empty loader; the sums are then all 0.
    denom = max(size, 1)
    results = (test_loss / denom,
               Dsc1 / denom, Dsc2 / denom, Dsc3 / denom,
               IOU1 / denom, IOU2 / denom, IOU3 / denom)
    print('total test loss : {:.4f}\nDice score 1 : {:.4f} | Dice score 2 : {:.4f} | Dice score 3 : {:.4f}\nIOU 1 : {:.4f} | IOU 2 : {:.4f} | IOU 3 : {:.4f}\n'.format(*results))
    return results
def main(train_loader, test_loader, model, optimizer, loss1, loss2, history=None):
    """Train ``model`` for ``EPOCHS`` epochs, evaluating after each epoch.

    Training minimises ``0.5 * loss1 + 0.5 * loss2`` (see ``train``), while the
    per-epoch evaluation loss reported by ``test`` uses ``loss1`` alone.

    Args:
        train_loader: loader used for training.
        test_loader: loader evaluated after every epoch.
        model: network to train.
        optimizer: optimizer for ``model``.
        loss1: primary loss (also the evaluation loss).
        loss2: secondary training loss, or ``None``.
        history: optional dict. If given, it is filled with per-epoch lists
            under the keys ``'train_loss'``, ``'test_loss'``, ``'dice1'``-
            ``'dice3'`` and ``'iou1'``-``'iou3'`` (existing keys are replaced),
            so callers can plot learning curves.

    Returns:
        The trained model.
    """
    if history is None:
        history = {}
    # Order must match the tuple returned by test().
    eval_keys = ('test_loss', 'dice1', 'dice2', 'dice3', 'iou1', 'iou2', 'iou3')
    history['train_loss'] = []
    for key in eval_keys:
        history[key] = []

    for epoch in range(1, EPOCHS + 1):
        model, train_loss = train(train_loader, model, optimizer, epoch,
                                  loss_fn1=loss1, loss_fn2=loss2)
        eval_results = test(test_loader, model, loss1, epoch)
        history['train_loss'].append(train_loss)
        for key, value in zip(eval_keys, eval_results):
            history[key].append(value)
    return model
if __name__ == '__main__':
    # Monitor each epoch on the validation split; the test split is kept out
    # of training-time evaluation and only used for the final plot below.
    model = main(train_loader, val_loader, model, optimizer,
                 loss1=lossCE, loss2=lossDice)
    plot_img(test_loader, 8, model)