import torch from torch.utils.data import DataLoader import numpy as np import random from tqdm import tqdm import argparse from model import SegRoot from dataloader import StaticTrainDataset, TestDataset, TrainDataset, LoopSampler from utils import ( dice_score, init_weights, evaluate, get_ids, load_vgg16, set_random_seed, ) parser = argparse.ArgumentParser() parser.add_argument("--seed", default=42, type=int, help="set random seed") parser.add_argument("--width", default=8, type=int, help="width of SegRoot") parser.add_argument("--depth", default=5, type=int, help="depth of SegRoot") parser.add_argument("--bs", default=64, type=int, help="batch size of dataloaders") parser.add_argument("--lr", default=1e-2, type=float, help="learning rate") parser.add_argument("--epochs", default=200, type=int, help="max epochs of training") parser.add_argument( "--verbose", default=5, type=int, help="intervals to save and validate model" ) parser.add_argument( "--dynamic", action="store_true", help="use dynamic sub-images during training" ) def train_one_epoch(model, train_iter, optimizer, device): model.train() for p in model.parameters(): p.requires_grad = True for x, y in train_iter: x, y = x.to(device), y.to(device) bs = x.shape[0] optimizer.zero_grad() y_pred = model(x) loss = 1 - dice_score(y, y_pred) loss = torch.sum(loss) / bs loss.backward() optimizer.step() if __name__ == "__main__": args = parser.parse_args() seed = args.seed bs = args.bs lr = args.lr width = args.width depth = args.depth epochs = args.epochs verbose = args.verbose # set random seed set_random_seed(seed) # define the device for training device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # get training ids train_ids, valid_ids, test_ids = get_ids(65) # define dataloaders if args.dynamic: train_data = TrainDataset(train_ids) train_iter = DataLoader( train_data, batch_size=bs, num_workers=6, sampler=LoopSampler ) else: train_data = StaticTrainDataset(train_ids) train_iter = DataLoader(train_data, batch_size=bs, num_workers=6, shuffle=True) train_tdata = TestDataset(train_ids) valid_tdata = TestDataset(valid_ids) test_tdata = TestDataset(test_ids) # define model model = SegRoot(width, depth).to(device) model = model.apply(init_weights) # define optimizer and lr_scheduler optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="max", factor=0.5, verbose=True, patience=5 ) print(f"Start training SegRoot-({width},{depth}))......") print(f"Random seed is {seed}, batch size is {bs}......") print(f"learning rate is {lr}, max epochs is {epochs}......") best_valid = float("-inf") for epoch in tqdm(range(epochs)): train_one_epoch(model, train_iter, optimizer, device) if epoch % verbose == 0: train_dice = evaluate(model, train_tdata, device) valid_dice = evaluate(model, valid_tdata, device) scheduler.step(valid_dice) print( "Epoch {:05d}, train dice: {:.4f}, valid dice: {:.4f}".format( epoch, train_dice, valid_dice ) ) if valid_dice > best_valid: best_valid = valid_dice test_dice = evaluate(model, test_tdata, device) print("New best validation, test dice: {:.4f}".format(test_dice)) torch.save( model.state_dict(), f"../weights/best_segroot-({args.width},{args.depth}).pt", )