import argparse
import logging
import os
import random
import sys
from pathlib import Path

import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from dataloader.dataset_ete import SegmentationDataset_train, SegmentationDataset
from utils.endtoend import dice_loss
from utils.func import parse_config, load_config
from evaluate import evaluate, evaluate_3d_iou
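
# Binary segmentation (background vs. foreground); seed numpy, random, and
# torch so runs are reproducible.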
num_classes = 2
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
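

# Train the network end to end: validate every epoch, keep the checkpoint with
# the best validation Dice score, and finally report test Dice/IoU for both the
# best-on-validation model and the last-epoch model.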
def train_net(net,
              cfg,
              trial,
              device,
              epochs: int = 30,
              train_batch_size: int = 128,
              val_batch_size: int = 128,
              learning_rate: float = 0.1,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale=(224, 224),
              amp: bool = True,
              out_dir: str = './checkpoint/'):

    train_dir_img = Path(cfg.dataloader.train_dir_img)
    train_dir_mask = Path(cfg.dataloader.train_dir_mask)
    val_dir_img = Path(cfg.dataloader.valid_dir_img)
    val_dir_mask = Path(cfg.dataloader.valid_dir_mask)
    test_dir_img = Path(cfg.dataloader.test_dir_img)
    test_dir_mask = Path(cfg.dataloader.test_dir_mask)
    non_label_text = cfg.dataloader.non_label
    have_label_text = cfg.dataloader.have_label

    dir_checkpoint = Path(out_dir)
    dir_checkpoint.mkdir(parents=True, exist_ok=True)
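
    # The training set is built from the have-label/non-label list files in the
    # config, while the validation and test sets read image/mask directories.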
    train_dataset = SegmentationDataset_train(nonlabel_path=non_label_text, havelabel_path=have_label_text,
                                              dataset=cfg.base.dataset_name, scale=img_scale)
    val_dataset = SegmentationDataset(name_dataset=cfg.base.dataset_name, images_dir=val_dir_img,
                                      masks_dir=val_dir_mask, scale=img_scale)
    test_dataset = SegmentationDataset(name_dataset=cfg.base.dataset_name, images_dir=test_dir_img,
                                       masks_dir=test_dir_mask, scale=img_scale)

    n_train = len(train_dataset)
    n_val = len(val_dataset)

    loader_args = dict(num_workers=10, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, **loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, drop_last=True, batch_size=val_batch_size, **loader_args)
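    # No batch_size is passed for the test loader, so it uses the DataLoader
    # default of one sample per batch.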
    test_loader = DataLoader(test_dataset, shuffle=False, drop_last=True, **loader_args)
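
    # Log the run configuration to Weights & Biases.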
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(dict(epochs=epochs, train_batch_size=train_batch_size, val_batch_size=val_batch_size,
                                  learning_rate=learning_rate, val_percent=val_percent,
                                  save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp))

    logging.info(f'''Starting training:
        Epochs:           {epochs}
        Train batch size: {train_batch_size}
        Val batch size:   {val_batch_size}
        Learning rate:    {learning_rate}
        Training size:    {n_train}
        Validation size:  {n_val}
        Checkpoints:      {save_checkpoint}
        Device:           {device.type}
        Images scaling:   {img_scale}
        Mixed precision:  {amp}
    ''')

    optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(cfg.train.beta1, cfg.train.beta2),
                           eps=1e-08, weight_decay=cfg.train.weight_decay)
    if cfg.train.scheduler:
        print("Use scheduler")
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-05)

    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()
    global_step = 0
    best_value = 0
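
    # Standard supervised training loop; per-step losses are streamed to wandb
    # and the tqdm bar tracks progress in images.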
    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask_ete']

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)
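
                # Forward pass under autocast; the objective is cross-entropy
                # plus a multiclass Dice loss on the softmax probabilities.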
                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    loss = criterion(masks_pred, true_masks) \
                        + dice_loss(F.softmax(masks_pred, dim=1).float(),
                                    F.one_hot(true_masks, num_classes).permute(0, 3, 1, 2).float(),
                                    multiclass=True)
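
                # Backward pass through the gradient scaler; unscale before
                # clipping so the norm is computed on the true gradients.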
                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.unscale_(optimizer)
                max_grad_norm = 1.0
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})
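
        # Step the LR scheduler once per epoch, then run validation and keep
        # the checkpoint with the best validation Dice score.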
        if cfg.train.scheduler:
            scheduler.step()

        if global_step % max(n_train // train_batch_size, 1) == 0:
            val_dice_score, val_iou_score = evaluate(net, val_loader, device, 1)
            val_score = val_dice_score

            if val_score > best_value:
                best_value = val_score
                logging.info("New best dice score: {} at epoch {}".format(best_value, epoch + 1))
                torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_{}_{}_best_{}.pth'.format(
                    cfg.base.dataset_name, cfg.base.original_checkpoint, str(trial))))

            logging.info('Validation Dice score: {}, IoU score {}'.format(val_dice_score, val_iou_score))

        if epoch + 1 == epochs:
            val_dice_score, val_iou_score = evaluate(net, val_loader, device, 1)
            logging.info('Validation Dice score: {}, IoU score {}'.format(val_dice_score, val_iou_score))

        if save_checkpoint:
            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
            logging.info(f'Checkpoint {epoch + 1} saved!')
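
            # Keep only the most recent per-epoch checkpoint to limit disk use.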
            if epoch > 0:
                os.remove(str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
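
    # After training, evaluate on the test set twice: once with the model that
    # scored best on validation and once with the last-epoch model.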
    logging.info("Evaluating on test set")
    logging.info("Loading best model on validation")
    net.load_state_dict(torch.load(str(dir_checkpoint / 'checkpoint_{}_{}_best_{}.pth'.format(
        cfg.base.dataset_name, cfg.base.original_checkpoint, str(trial)))))
    test_dice, test_iou = evaluate(net, test_loader, device, 1)
    logging.info("Test dice score {}, IoU score {}".format(test_dice, test_iou))

    logging.info("Loading model from last epoch %d" % epochs)
    net.load_state_dict(torch.load(str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epochs))))
    test_dice_last, test_iou_last = evaluate(net, test_loader, device, 1)
    logging.info("Test dice score {}, IoU score {}".format(test_dice_last, test_iou_last))

    return test_dice, test_iou, test_dice_last, test_iou_last
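

# Evaluate a previously trained model (the best-on-validation checkpoint of the
# given trial) on the test set, without retraining.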
def eval(cfg, out_dir, net, device, img_scale, trial):
    test_dir_img = Path(cfg.dataloader.test_dir_img)
    test_dir_mask = Path(cfg.dataloader.test_dir_mask)
    test_dataset = SegmentationDataset(name_dataset=cfg.base.dataset_name, images_dir=test_dir_img,
                                       masks_dir=test_dir_mask, scale=img_scale)
    loader_args = dict(num_workers=10, pin_memory=True)
    test_loader = DataLoader(test_dataset, shuffle=False, drop_last=True, **loader_args)
    dir_checkpoint = Path(out_dir)

    print("Trial", trial + 1)
    logging.info("Evaluating on test set")
    logging.info("Loading best model on validation")
    net.load_state_dict(torch.load(str(dir_checkpoint / 'checkpoint_{}_{}_best_{}.pth'.format(
        cfg.base.dataset_name, cfg.base.original_checkpoint, str(trial)))))
    test_dice, test_iou = evaluate(net, test_loader, device, 1)
    logging.info("Test dice score {}, IoU score {}".format(test_dice, test_iou))
    return test_dice, test_iou
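

# Entry point for the 2D ResNet-50 U-Net experiments: run three independent
# trials (training, or evaluation only in test mode) and report the mean/std
# of the Dice and IoU scores.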
def train_2d_R50(yml_args, cfg):
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    cuda_string = 'cuda:' + cfg.base.gpu_id
    device = torch.device(cuda_string if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
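
    # Collect per-trial scores so mean/std can be reported at the end.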
    try:
        _2d_dices = []
        _2d_ious = []
        _2d_dices_last = []
        _2d_ious_last = []

        if not yml_args.use_test_mode:
            for trial in range(3):
                print("----" * 3)
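                # Build the encoder from scratch or from the configured
                # pre-trained weights.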
                if cfg.base.original_checkpoint == "scratch":
                    net = smp.Unet(encoder_name="resnet50", encoder_weights=None, in_channels=3, classes=num_classes)
                else:
                    print("Using pre-trained models from", cfg.base.original_checkpoint)
                    net = smp.Unet(encoder_name="resnet50", encoder_weights=cfg.base.original_checkpoint,
                                   in_channels=3, classes=num_classes)

                net.to(device=device)
print("Trial", trial + 1) |
|
_2d_dice, _2d_iou, _2d_dice_last, _2d_iou_last = train_net(net=net, cfg=cfg, trial=trial, |
|
epochs=cfg.train.num_epochs, |
|
train_batch_size=cfg.train.train_batch_size, |
|
val_batch_size=cfg.train.valid_batch_size, |
|
learning_rate=cfg.train.learning_rate, |
|
device=device, |
|
val_percent=10.0 / 100, |
|
img_scale = (cfg.base.image_shape, cfg.base.image_shape), |
|
amp=False, |
|
out_dir= cfg.base.best_valid_model_checkpoint) |
|
_2d_dices.append(_2d_dice.item()) |
|
_2d_ious.append(_2d_iou.item()) |
|
_2d_dices_last.append(_2d_dice_last.item()) |
|
_2d_ious_last.append(_2d_iou_last.item()) |
|
|
|
print ("Average performance on best valid set") |
|
print("2d dice {}, mean {}, std {}".format(_2d_dices, np.mean(_2d_dices), np.std(_2d_dices))) |
|
print("2d iou {}, mean {}, std {}".format(_2d_ious, np.mean(_2d_ious), np.std(_2d_ious))) |
|
|
|
|
|
print ("Average performance on the last epoch") |
|
print("2d dice {}, mean {}, std {}".format(_2d_dices_last, np.mean(_2d_dices_last), np.std(_2d_dices_last))) |
|
print("2d iou {}, mean {}, std {}".format(_2d_ious_last, np.mean(_2d_ious_last), np.std(_2d_ious_last))) |
|
|
|

        else:
            for trial in range(3):
                print("----" * 3)
                if cfg.base.original_checkpoint == "scratch":
                    net = smp.Unet(encoder_name="resnet50", encoder_weights=None, in_channels=3, classes=num_classes)
                else:
                    print("Using pre-trained models from", cfg.base.original_checkpoint)
                    net = smp.Unet(encoder_name="resnet50", encoder_weights=cfg.base.original_checkpoint,
                                   in_channels=3, classes=num_classes)

                net.to(device=device)
                _2d_dice, _2d_iou = eval(cfg=cfg, out_dir=cfg.base.best_valid_model_checkpoint, net=net,
                                         device=device, img_scale=(cfg.base.image_shape, cfg.base.image_shape),
                                         trial=trial)
                _2d_dices.append(_2d_dice.item())
                _2d_ious.append(_2d_iou.item())
            print("Average performance on best valid set")
            print("2d dice {}, mean {}, std {}".format(_2d_dices, np.mean(_2d_dices), np.std(_2d_dices)))
            print("2d iou {}, mean {}, std {}".format(_2d_ious, np.mean(_2d_ious), np.std(_2d_ious)))
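
    # On keyboard interrupt, save the current weights before exiting.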
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        sys.exit(0)