# LVM-Med / segmentation_2d / train_R50_seg_adam_optimizer_2d.py
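"""Fine-tune a ResNet-50 U-Net (segmentation_models_pytorch) for 2D segmentation
with the Adam optimizer, repeated over three trials.

Training is logged to Weights & Biases; both the best-on-validation checkpoint
and the last-epoch checkpoint are evaluated on the test set.
"""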
import argparse
import logging
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from torch.optim.lr_scheduler import ExponentialLR
import os
from dataloader.dataset_ete import SegmentationDataset_train, SegmentationDataset
from utils.endtoend import dice_loss
from utils.func import (
parse_config,
load_config
)
from evaluate import evaluate, evaluate_3d_iou
#from models.segmentation import UNet
import segmentation_models_pytorch as smp
import numpy as np
import random
num_classes = 2
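# Fix random seeds so repeated runs are reproducible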
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
def train_net(net,
cfg,
trial,
device,
epochs: int = 30,
train_batch_size: int = 128,
val_batch_size: int = 128,
learning_rate: float = 0.1,
val_percent: float = 0.1,
save_checkpoint: bool = True,
              img_scale=(224, 224),
              amp: bool = True,
              out_dir: str = './checkpoint/'):
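    """Train `net` on the configured dataset and return
    (best-val test Dice, best-val test IoU, last-epoch test Dice, last-epoch test IoU)."""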
# 1. Create dataset
train_dir_img = Path(cfg.dataloader.train_dir_img)
train_dir_mask = Path(cfg.dataloader.train_dir_mask)
val_dir_img = Path(cfg.dataloader.valid_dir_img)
val_dir_mask = Path(cfg.dataloader.valid_dir_mask)
test_dir_img = Path(cfg.dataloader.test_dir_img)
test_dir_mask = Path(cfg.dataloader.test_dir_mask)
non_label_text = cfg.dataloader.non_label
have_label_text = cfg.dataloader.have_label
dir_checkpoint = Path(out_dir)
    dir_checkpoint.mkdir(parents=True, exist_ok=True)
    # Training images come from the non_label/have_label path-list text files;
    # validation and test sets are read directly from image/mask directories.
    train_dataset = SegmentationDataset_train(nonlabel_path=non_label_text, havelabel_path=have_label_text, dataset=cfg.base.dataset_name, scale=img_scale)
    val_dataset = SegmentationDataset(name_dataset=cfg.base.dataset_name, images_dir=val_dir_img, masks_dir=val_dir_mask, scale=img_scale)
    test_dataset = SegmentationDataset(name_dataset=cfg.base.dataset_name, images_dir=test_dir_img, masks_dir=test_dir_mask, scale=img_scale)
n_train = len(train_dataset)
n_val = len(val_dataset)
    # 2. Create data loaders
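    # num_workers=10 assumes plenty of CPU cores; lower it on smaller machines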
loader_args = dict(num_workers=10, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, **loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=True, batch_size=val_batch_size, **loader_args)
    # batch_size defaults to 1 here, so drop_last never drops anything for the test loader
    test_loader = DataLoader(test_dataset, shuffle=False, drop_last=True, **loader_args)
experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
experiment.config.update(dict(epochs=epochs, train_batch_size=train_batch_size, val_batch_size=val_batch_size, learning_rate=learning_rate,
val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
amp=amp))
logging.info(f'''Starting training:
Epochs: {epochs}
Train batch size: {train_batch_size}
Val batch size: {val_batch_size}
Learning rate: {learning_rate}
Training size: {n_train}
Validation size: {n_val}
Checkpoints: {save_checkpoint}
Device: {device.type}
Images scaling: {img_scale}
Mixed Precision: {amp}
''')
    # 3. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
# optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(cfg.train.beta1, cfg.train.beta2), eps=1e-08, weight_decay=cfg.train.weight_decay)
if cfg.train.scheduler:
print("Use scheduler")
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-05)
# optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-8)
# scheduler = ExponentialLR(optimizer, gamma=1.11)
# optimizer= optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2) # goal: maximize Dice score
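    # GradScaler is effectively a no-op when amp=False (as train_2d_R50 passes below)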
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss()
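    # The CE term is combined with a multiclass Dice loss inside the training loop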
global_step = 0
best_value = 0
    # 4. Begin training
for epoch in range(epochs):
net.train()
epoch_loss = 0
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
for batch in train_loader:
images = batch['image']
true_masks = batch['mask_ete']
images = images.to(device=device, dtype=torch.float32)
true_masks = true_masks.to(device=device, dtype=torch.long)
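                # Forward pass runs under autocast; F.one_hot produces NHWC masks,
                # hence the permute to NCHW before the Dice loss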
with torch.cuda.amp.autocast(enabled=amp):
masks_pred = net(images)
loss = criterion(masks_pred, true_masks) \
+ dice_loss(F.softmax(masks_pred, dim=1).float(),
F.one_hot(true_masks, num_classes).permute(0, 3, 1, 2).float(),
multiclass=True)
                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                # Unscale before clipping so the norm is computed on the true gradients
                # (a no-op when amp=False)
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
                grad_scaler.step(optimizer)
                grad_scaler.update()
pbar.update(images.shape[0])
global_step += 1
epoch_loss += loss.item()
experiment.log({
'train loss': loss.item(),
'step': global_step,
'epoch': epoch
})
pbar.set_postfix(**{'loss (batch)': loss.item()})
                # Note: the cosine scheduler is stepped every batch here, not once per epoch
                if cfg.train.scheduler:
                    scheduler.step()
                # Evaluation round: run validation roughly once per epoch
                division_step = max(1, n_train // train_batch_size)
                if global_step % division_step == 0:
val_dice_score, val_iou_score = evaluate(net, val_loader, device, 1)
val_score = val_dice_score
if (val_score > best_value):
best_value = val_score
logging.info("New best dice score: {} at epochs {}".format(best_value, epoch+1))
torch.save(net.state_dict(), str(dir_checkpoint/'checkpoint_{}_{}_best_{}.pth'.format(cfg.base.dataset_name, cfg.base.original_checkpoint, str(trial))))
logging.info('Validation Dice score: {}, IoU score {}'.format(val_dice_score, val_iou_score))
if epoch + 1 == epochs:
val_dice_score, val_iou_score = evaluate(net, val_loader, device, 1)
logging.info('Validation Dice score: {}, IoU score {}'.format(val_dice_score, val_iou_score))
if save_checkpoint:
torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
logging.info(f'Checkpoint {epoch + 1} saved!')
            # Drop the previous epoch's checkpoint so only the most recent one
            # (plus the best-on-validation file) remains on disk
            if epoch > 0 and epoch != (epochs % 2 - 1):
                os.remove(str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
logging.info("Evalutating on test set")
logging.info("Loading best model on validation")
net.load_state_dict(torch.load(str(dir_checkpoint/'checkpoint_{}_{}_best_{}.pth'.format(cfg.base.dataset_name, cfg.base.original_checkpoint, str(trial)))))
test_dice, test_iou = evaluate(net, test_loader, device, 1)
logging.info("Test dice score {}, IoU score {}".format(test_dice, test_iou))
logging.info("Loading model at last epochs %d" %epochs)
net.load_state_dict(torch.load(str(dir_checkpoint/'checkpoint_epoch{}.pth'.format(epochs))))
test_dice_last, test_iou_last = evaluate(net, test_loader, device, 1)
logging.info("Test dice score {}, IoU score {}".format(test_dice_last, test_iou_last))
return test_dice, test_iou, test_dice_last, test_iou_last
def eval_net(cfg, out_dir, net, device, img_scale, trial):
    """Load the best-on-validation checkpoint and report Dice/IoU on the test set."""
test_dir_img = Path(cfg.dataloader.test_dir_img)
test_dir_mask = Path(cfg.dataloader.test_dir_mask)
test_dataset = SegmentationDataset(name_dataset=cfg.base.dataset_name, images_dir = test_dir_img, masks_dir= test_dir_mask, scale = img_scale)
loader_args = dict(num_workers=10, pin_memory=True)
test_loader = DataLoader(test_dataset, shuffle=False, drop_last=True, **loader_args)
dir_checkpoint = Path(out_dir)
print("Trial", trial+1)
logging.info("Evalutating on test set")
logging.info("Loading best model on validation")
net.load_state_dict(torch.load(str(dir_checkpoint/'checkpoint_{}_{}_best_{}.pth'.format(cfg.base.dataset_name, cfg.base.original_checkpoint, str(trial)))))
test_dice, test_iou = evaluate(net, test_loader, device, 1)
logging.info("Test dice score {}, IoU score {}".format(test_dice, test_iou))
return test_dice, test_iou
#if __name__ == '__main__':
def train_2d_R50(yml_args, cfg):
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    cuda_string = 'cuda:' + str(cfg.base.gpu_id)
device = torch.device(cuda_string if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
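    # Config fields read by this script (names taken from the attribute accesses
    # in this file; the YAML layout itself comes from the repo's config files):
    #   base:       dataset_name, gpu_id, image_shape, original_checkpoint,
    #               best_valid_model_checkpoint
    #   train:      num_epochs, train_batch_size, valid_batch_size, learning_rate,
    #               beta1, beta2, weight_decay, scheduler
    #   dataloader: train/valid/test image and mask dirs, non_label, have_label
    #   (plus yml_args.use_test_mode to switch between training and evaluation)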
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
try:
_2d_dices = []
_2d_ious = []
_2d_dices_last = []
_2d_ious_last = []
if not yml_args.use_test_mode:
for trial in range(3):
print ("----"*3)
if cfg.base.original_checkpoint == "scratch":
net = smp.Unet(encoder_name="resnet50", encoder_weights=None, in_channels=3, classes=num_classes)
else:
print ("Using pre-trained models from", cfg.base.original_checkpoint)
net = smp.Unet(encoder_name="resnet50", encoder_weights=cfg.base.original_checkpoint,
in_channels=3, classes=num_classes)
net.to(device=device)
print("Trial", trial + 1)
_2d_dice, _2d_iou, _2d_dice_last, _2d_iou_last = train_net(net=net, cfg=cfg, trial=trial,
epochs=cfg.train.num_epochs,
train_batch_size=cfg.train.train_batch_size,
val_batch_size=cfg.train.valid_batch_size,
learning_rate=cfg.train.learning_rate,
device=device,
val_percent=10.0 / 100,
img_scale = (cfg.base.image_shape, cfg.base.image_shape),
amp=False,
out_dir= cfg.base.best_valid_model_checkpoint)
_2d_dices.append(_2d_dice.item())
_2d_ious.append(_2d_iou.item())
_2d_dices_last.append(_2d_dice_last.item())
_2d_ious_last.append(_2d_iou_last.item())
print ("Average performance on best valid set")
print("2d dice {}, mean {}, std {}".format(_2d_dices, np.mean(_2d_dices), np.std(_2d_dices)))
print("2d iou {}, mean {}, std {}".format(_2d_ious, np.mean(_2d_ious), np.std(_2d_ious)))
print ("Average performance on the last epoch")
print("2d dice {}, mean {}, std {}".format(_2d_dices_last, np.mean(_2d_dices_last), np.std(_2d_dices_last)))
print("2d iou {}, mean {}, std {}".format(_2d_ious_last, np.mean(_2d_ious_last), np.std(_2d_ious_last)))
else:
for trial in range(3):
print ("----"*3)
if cfg.base.original_checkpoint == "scratch":
net = smp.Unet(encoder_name="resnet50", encoder_weights=None, in_channels=3, classes=num_classes)
else:
print ("Using pre-trained models from", cfg.base.original_checkpoint)
net = smp.Unet(encoder_name="resnet50", encoder_weights=cfg.base.original_checkpoint ,in_channels=3,
classes=num_classes)
net.to(device=device)
            _2d_dice, _2d_iou = eval_net(cfg=cfg, out_dir=cfg.base.best_valid_model_checkpoint, net=net, device=device,
                                         img_scale=(cfg.base.image_shape, cfg.base.image_shape), trial=trial)
_2d_dices.append(_2d_dice.item())
_2d_ious.append(_2d_iou.item())
print ("Average performance on best valid set")
print("2d dice {}, mean {}, std {}".format(_2d_dices, np.mean(_2d_dices), np.std(_2d_dices)))
print("2d iou {}, mean {}, std {}".format(_2d_ious, np.mean(_2d_ious), np.std(_2d_ious)))
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
sys.exit(0)
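

# Hypothetical entry point, shown only as a sketch: this file is normally driven
# by the repo's top-level runner, and the exact signatures of parse_config and
# load_config are defined in utils/func.py.
#
#     if __name__ == '__main__':
#         yml_args = parse_config()
#         cfg = load_config(yml_args.config)
#         train_2d_R50(yml_args, cfg)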