File size: 3,869 Bytes
bc97962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
from torch.utils.data import DataLoader
import numpy as np
import random
from tqdm import tqdm
import argparse

from model import SegRoot
from dataloader import StaticTrainDataset, TestDataset, TrainDataset, LoopSampler
from utils import (
    dice_score,
    init_weights,
    evaluate,
    get_ids,
    load_vgg16,
    set_random_seed,
)


parser = argparse.ArgumentParser()

# Reproducibility.
parser.add_argument("--seed", type=int, default=42, help="set random seed")

# Architecture of the SegRoot network.
parser.add_argument("--width", type=int, default=8, help="width of SegRoot")
parser.add_argument("--depth", type=int, default=5, help="depth of SegRoot")

# Optimisation schedule.
parser.add_argument("--bs", type=int, default=64, help="batch size of dataloaders")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument("--epochs", type=int, default=200, help="max epochs of training")
parser.add_argument(
    "--verbose",
    type=int,
    default=5,
    help="intervals to save and validate model",
)

# Boolean switch: present on the command line -> True, absent -> False.
parser.add_argument(
    "--dynamic",
    action="store_true",
    help="use dynamic sub-images during training",
)


def train_one_epoch(model, train_iter, optimizer, device):
    """Run a single optimisation pass over every batch of ``train_iter``.

    The model is switched to training mode and every parameter is marked as
    requiring gradients. For each ``(x, y)`` batch the loss is
    ``sum(1 - dice_score(y, y_pred)) / batch_size``. It is back-propagated and
    followed by one ``optimizer.step()``.

    Args:
        model: network being trained; called as ``model(x)``.
        train_iter: iterable yielding ``(input, target)`` tensor pairs.
        optimizer: optimiser over ``model``'s parameters.
        device: device the batches are moved to before the forward pass.

    Returns:
        None. The model and optimiser are updated in place.
    """
    model.train()
    # Gradients are re-enabled explicitly on every parameter before training.
    for param in model.parameters():
        param.requires_grad = True

    for batch_x, batch_y in train_iter:
        inputs = batch_x.to(device)
        targets = batch_y.to(device)
        optimizer.zero_grad()
        predictions = model(inputs)
        # Presumably dice_score returns one score per sample, which makes the
        # value below the batch-mean Dice loss -- TODO confirm in utils.
        per_sample_loss = 1 - dice_score(targets, predictions)
        batch_loss = torch.sum(per_sample_loss) / inputs.shape[0]
        batch_loss.backward()
        optimizer.step()


if __name__ == "__main__":
    # Local import keeps this script block self-contained; used to make sure
    # the checkpoint directory exists before the first torch.save.
    import os

    args = parser.parse_args()
    seed = args.seed
    bs = args.bs
    lr = args.lr
    width = args.width
    depth = args.depth
    epochs = args.epochs
    verbose = args.verbose
    # `epoch % verbose` below raises ZeroDivisionError for 0 and never fires
    # for negative values, so reject those up front with a usage error.
    if verbose < 1:
        parser.error("--verbose must be a positive integer")

    # set random seed
    set_random_seed(seed)
    # define the device for training
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # get training ids
    train_ids, valid_ids, test_ids = get_ids(65)
    # define dataloaders
    if args.dynamic:
        train_data = TrainDataset(train_ids)
        # DataLoader needs a sampler *instance*. Passing the class itself fails
        # as soon as the loader is iterated ("'type' object is not iterable").
        # NOTE(review): assumes LoopSampler follows the torch Sampler
        # convention of taking the dataset as its argument -- confirm against
        # dataloader.py.
        train_iter = DataLoader(
            train_data,
            batch_size=bs,
            num_workers=6,
            sampler=LoopSampler(train_data),
        )
    else:
        train_data = StaticTrainDataset(train_ids)
        train_iter = DataLoader(train_data, batch_size=bs, num_workers=6, shuffle=True)

    # Full-image datasets used only for Dice evaluation.
    train_tdata = TestDataset(train_ids)
    valid_tdata = TestDataset(valid_ids)
    test_tdata = TestDataset(test_ids)

    # define model
    model = SegRoot(width, depth).to(device)
    model = model.apply(init_weights)

    # define optimizer and lr_scheduler
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    # mode="max": the monitored quantity is validation Dice (higher is better).
    # The scheduler is stepped only on evaluation epochs, so `patience` counts
    # evaluations, not epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, verbose=True, patience=5
    )

    # Create the checkpoint directory now rather than crashing in torch.save
    # after a potentially long stretch of training.
    weights_dir = "../weights"
    os.makedirs(weights_dir, exist_ok=True)
    best_ckpt_path = os.path.join(weights_dir, f"best_segroot-({width},{depth}).pt")

    print(f"Start training SegRoot-({width},{depth})......")
    print(f"Random seed is {seed}, batch size is {bs}......")
    print(f"learning rate is {lr}, max epochs is {epochs}......")
    best_valid = float("-inf")
    for epoch in tqdm(range(epochs)):
        train_one_epoch(model, train_iter, optimizer, device)
        # Validate every `verbose` epochs, and always after the final epoch so
        # the last trained weights are considered for checkpointing.
        if epoch % verbose != 0 and epoch != epochs - 1:
            continue
        train_dice = evaluate(model, train_tdata, device)
        valid_dice = evaluate(model, valid_tdata, device)
        scheduler.step(valid_dice)
        # tqdm.write prints above the progress bar instead of corrupting it.
        tqdm.write(
            "Epoch {:05d}, train dice: {:.4f}, valid dice: {:.4f}".format(
                epoch, train_dice, valid_dice
            )
        )
        if valid_dice > best_valid:
            best_valid = valid_dice
            # The test split is scored only when validation improves; it is
            # reported but never used for model selection.
            test_dice = evaluate(model, test_tdata, device)
            tqdm.write("New best validation, test dice: {:.4f}".format(test_dice))
            torch.save(model.state_dict(), best_ckpt_path)