|
import torch |
|
import os |
|
from monai.utils import set_determinism |
|
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping |
|
import os |
|
from pytorch_lightning.loggers import TensorBoardLogger |
|
from trainer import BRATS |
|
from dataset.utils import get_loader |
|
import pytorch_lightning as pl |
|
# --- Runtime environment ----------------------------------------------------
# Ask CUDA to enumerate GPUs in PCI bus order so device indices are stable.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Seed the run (seed 0) through MONAI's determinism helper.
set_determinism(seed=0)

# Clear the terminal: try "cls" first, fall back to "clear" if it fails.
os.system('cls||clear')
print("Training ...")
|
|
|
# --- Experiment configuration -----------------------------------------------
# Paths handed to get_loader.
data_dir = "/app/brats_2021_task1"
json_list = "/app/info.json"

# Hyperparameters.
roi = (128,) * 3  # same size along all three axes
batch_size = 1
fold = 1
max_epochs = 500
val_every = 10  # presumably the validation interval in epochs

# get_loader returns the train / validation / test loaders, in that order.
train_loader, val_loader, test_loader = get_loader(
    batch_size,
    data_dir,
    json_list,
    fold,
    roi,
    volume=1,
    test_size=0.2,
)
print("Done initialize dataloader !! ")
|
|
|
# Model wrapper from trainer.py. The three data loaders are handed to the
# model itself; trainer.fit() below is called with the model only.
# NOTE(review): use_VAE=True presumably enables a VAE branch - confirm in BRATS.
model = BRATS(
    use_VAE=True,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)
|
# Keep the three checkpoints with the highest val/MeanDiceScore, plus the most
# recent one, under ./checkpoints/SegTransVAE.
checkpoint_callback = ModelCheckpoint(
    # What to track and in which direction.
    monitor='val/MeanDiceScore',
    mode='max',
    # How many to keep.
    save_top_k=3,
    save_last=True,
    # Where and how to name the files. The metric name is not auto-inserted
    # because the key contains "/"; the template spells the labels out instead.
    # NOTE(review): "{epoch:3d}" pads with spaces, so file names contain blanks;
    # "{epoch:03d}" would zero-pad, but changing it renames existing checkpoints.
    dirpath='./checkpoints/{}'.format("SegTransVAE"),
    filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',
    auto_insert_metric_name=False,
)
|
# Stop training once val/MeanDiceScore has not improved by at least 1e-4 for
# 15 consecutive validation runs.
# NOTE(review): patience counts validation runs, not epochs. With validation
# every 10 epochs (see the Trainer below), that is up to 150 epochs without
# improvement before stopping - confirm this is intended.
early_stop_callback = EarlyStopping(
    monitor='val/MeanDiceScore',
    mode='max',
    min_delta=0.0001,
    patience=15,
    verbose=False,
)
|
# TensorBoard event files go under logs/SegTransVAE/.
# NOTE(review): default_hp_metric is documented as a bool; None presumably
# behaves like False (no placeholder hp_metric logged) - confirm.
tensorboardlogger = TensorBoardLogger(
    save_dir='logs',
    name="SegTransVAE",
    default_hp_metric=None,
)
|
# --- Training ---------------------------------------------------------------
# Build the Lightning trainer: device index 0, 16-bit precision, at most
# `max_epochs` epochs, with checkpointing, early stopping and TensorBoard
# logging attached.
trainer = pl.Trainer(
    devices=[0],
    precision=16,
    max_epochs=max_epochs,
    enable_progress_bar=True,
    callbacks=[checkpoint_callback, early_stop_callback],
    # Run a single validation batch before training as a sanity check.
    num_sanity_val_steps=1,
    logger=tensorboardlogger,
    # Use the `val_every` setting defined with the other hyperparameters
    # instead of a second hard-coded 10, so the knob actually takes effect.
    check_val_every_n_epoch=val_every,
)

# The model supplies its own data loaders, so only the model is passed here.
trainer.fit(model)
|
|
|
|
|
|
|
|