File size: 2,028 Bytes
a2dba58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from argparse import Namespace
import os
from pathlib import Path
from dataloaders.JSRT import build_dataloaders
import torch
from trainers.utils import seed_everything, TensorboardLogger
from torch.cuda.amp import GradScaler
from trainers.train_baseline import train
from models.datasetDM_model import DatasetDM



def main(config: Namespace) -> None:
    # adjust logdir to include experiment name
    os.makedirs(config.log_dir, exist_ok=True)
    print('Experiment folder: %s' % (config.log_dir))

    # save config namespace into logdir
    with open(config.log_dir / 'config.txt', 'w') as f:
        for k, v in vars(config).items():
            if type(v) not in [str, int, float, bool]:
                f.write(f'{k}: {str(v)}\n')
            else:
                f.write(f'{k}: {v}\n')

    # Random seed
    seed_everything(config.seed)

    model = DatasetDM(config)
    if config.shared_weights_over_timesteps:
        import torch.nn as nn
        from einops.layers.torch import Rearrange
        model.classifier = nn.Sequential(
            Rearrange('b (step act) h w -> (b step) act h w', step=len(model.steps)),
            nn.Conv2d(960, 128, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 32, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, config.out_channels)
            )
    model = model.to(config.device)
    model.train()

    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=config.lr, weight_decay=config.weight_decay)  # , betas=config.adam_betas)
    step = 0

    scaler = GradScaler()

    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
        config.n_labelled_images
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']

    # Logger
    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)

    train(config, model, optimizer, train_dl, val_dl, logger, scaler, step)