# Author: Andres Felipe Ruiz-Hurtado
# Commit: initial (bc97962)
import torch
from torch.utils.data import DataLoader
import numpy as np
import random
from tqdm import tqdm
import argparse
from model import SegRoot
from dataloader import StaticTrainDataset, TestDataset, TrainDataset, LoopSampler
from utils import (
dice_score,
init_weights,
evaluate,
get_ids,
load_vgg16,
set_random_seed,
)
# Command-line interface: training hyper-parameters and data-mode switch.
parser = argparse.ArgumentParser()
for flag, kwargs in (
    ("--seed", dict(default=42, type=int, help="set random seed")),
    ("--width", dict(default=8, type=int, help="width of SegRoot")),
    ("--depth", dict(default=5, type=int, help="depth of SegRoot")),
    ("--bs", dict(default=64, type=int, help="batch size of dataloaders")),
    ("--lr", dict(default=1e-2, type=float, help="learning rate")),
    ("--epochs", dict(default=200, type=int, help="max epochs of training")),
    ("--verbose", dict(default=5, type=int, help="intervals to save and validate model")),
    ("--dynamic", dict(action="store_true", help="use dynamic sub-images during training")),
):
    parser.add_argument(flag, **kwargs)
def train_one_epoch(model, train_iter, optimizer, device):
    """Run one optimization pass over `train_iter`.

    The loss is the batch-mean soft-dice loss, 1 - dice_score(target, prediction),
    summed per sample and divided by the batch size. Updates `model` in place
    via `optimizer`; returns nothing.
    """
    model.train()
    # re-enable gradients on every parameter (may have been frozen elsewhere)
    for param in model.parameters():
        param.requires_grad = True
    for batch_x, batch_y in train_iter:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        optimizer.zero_grad()
        prediction = model(batch_x)
        # per-sample dice loss, then averaged over the batch
        per_sample_loss = 1 - dice_score(batch_y, prediction)
        mean_loss = torch.sum(per_sample_loss) / batch_x.shape[0]
        mean_loss.backward()
        optimizer.step()
if __name__ == "__main__":
    opts = parser.parse_args()
    seed, bs, lr = opts.seed, opts.bs, opts.lr
    width, depth = opts.width, opts.depth
    epochs, verbose = opts.epochs, opts.verbose
    # fix all RNGs for reproducibility
    set_random_seed(seed)
    # prefer the first GPU, fall back to CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # split the 65 annotated images into train / valid / test id lists
    train_ids, valid_ids, test_ids = get_ids(65)
    # training loader: dynamic sub-images per epoch, or a precomputed static set
    if opts.dynamic:
        # NOTE(review): LoopSampler is passed as imported — verify it is a valid
        # `sampler` argument for DataLoader (instance vs. class) in dataloader.py.
        train_iter = DataLoader(
            TrainDataset(train_ids), batch_size=bs, num_workers=6, sampler=LoopSampler
        )
    else:
        train_iter = DataLoader(
            StaticTrainDataset(train_ids), batch_size=bs, num_workers=6, shuffle=True
        )
    # full-image datasets used only for dice evaluation
    train_tdata = TestDataset(train_ids)
    valid_tdata = TestDataset(valid_ids)
    test_tdata = TestDataset(test_ids)
    # build and initialize the network
    model = SegRoot(width, depth).to(device)
    model = model.apply(init_weights)
    # Adam with weight decay; halve the lr when validation dice stops improving
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, verbose=True, patience=5
    )
    print(f"Start training SegRoot-({width},{depth}))......")
    print(f"Random seed is {seed}, batch size is {bs}......")
    print(f"learning rate is {lr}, max epochs is {epochs}......")
    best_valid = float("-inf")
    for epoch in tqdm(range(epochs)):
        train_one_epoch(model, train_iter, optimizer, device)
        # evaluate, step the scheduler, and maybe checkpoint every `verbose` epochs
        if epoch % verbose == 0:
            train_dice = evaluate(model, train_tdata, device)
            valid_dice = evaluate(model, valid_tdata, device)
            scheduler.step(valid_dice)
            print(
                "Epoch {:05d}, train dice: {:.4f}, valid dice: {:.4f}".format(
                    epoch, train_dice, valid_dice
                )
            )
            # keep only the weights with the best validation dice so far
            if valid_dice > best_valid:
                best_valid = valid_dice
                test_dice = evaluate(model, test_tdata, device)
                print("New best validation, test dice: {:.4f}".format(test_dice))
                torch.save(
                    model.state_dict(),
                    f"../weights/best_segroot-({width},{depth}).pt",
                )