from argparse import Namespace
import os
from dataloaders.JSRT import build_dataloaders
import torch
from tqdm.auto import tqdm
from trainers.utils import seed_everything, TensorboardLogger
from torch.cuda.amp import GradScaler
from torch import Tensor, nn
from trainers.train_baseline import train
from models.datasetDM_model import DatasetDM
from einops.layers.torch import Rearrange


class ModDatasetDM(DatasetDM):
    # The idea here is to pool information per timestep and normalize it with
    # dataset-wide statistics, so the aggregate can be used for feature importance.

    def __init__(self, args: Namespace) -> None:
        super().__init__(args)
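        # Running per-pixel feature statistics, filled in by the mean/variance
        # pass in main() before the classifier is trained; 960 is assumed to be
        # the number of feature channels extracted per diffusion timestep.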
        self.mean = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        self.mean_squared = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        self.std = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
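        # 1x1 conv mapping the stacked per-timestep features to a single logit map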
        self.classifier = nn.Conv2d(len(self.steps) * 960, 1, 1)
    
    def forward(self, x: Tensor) -> Tensor:
        features = self.extract_features(x).to(x.device)
        # normalize with the per-pixel statistics computed in main(), then classify
        out = (features - self.mean) / self.std
        out = self.classifier(out)
        return out
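
# Minimal sketch (hypothetical sizes) of the normalize-then-classify flow in
# ModDatasetDM.forward; real shapes come from DatasetDM.extract_features and
# the config:
def _mod_datasetdm_shape_sketch() -> None:
    feats = torch.randn(4, 2 * 960, 16, 16)    # (B, steps * 960, H, W), 2 steps
    mean = torch.zeros(2 * 960, 16, 16)        # per-pixel statistics broadcast
    std = torch.ones(2 * 960, 16, 16)          # over the batch dimension
    out = nn.Conv2d(2 * 960, 1, 1)((feats - mean) / std)
    assert out.shape == (4, 1, 16, 16)         # one logit map per image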
    
class OneStepPredDatasetDM(DatasetDM):
    # Variant that folds the timestep axis into the batch dimension, so a single
    # small head makes an independent prediction from each timestep's features.

    def __init__(self, args: Namespace) -> None:
        super().__init__(args)
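        # Same per-pixel statistics as in ModDatasetDM; note that main() below
        # only populates them for ModDatasetDM, so they must be filled in
        # before forward() is called on this class.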
        self.mean = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        self.mean_squared = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        self.std = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        self.classifier = nn.Sequential(
            # fold timesteps into the batch so the 1x1 convs below are shared
            # across timesteps: (B, steps * 960, H, W) -> (B * steps, 960, H, W)
            Rearrange('b (step act) h w -> (b step) act h w', step=len(self.steps)),
            nn.Conv2d(960, 128, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 32, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            # 1x1 conv to a single logit map per (image, timestep) pair
            nn.Conv2d(32, 1, 1),
        )
    

    def forward(self, x: Tensor) -> Tensor:
        features = self.extract_features(x).to(x.device)
        # normalize with the per-pixel statistics, then classify
        out = (features - self.mean) / self.std
        out = self.classifier(out)
        return out
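
# Minimal sketch (hypothetical sizes) of the Rearrange used above: with 2
# timesteps and 960 activations per step, timesteps are folded into the batch
# so the subsequent 1x1 convs are shared across timesteps:
def _rearrange_shape_sketch() -> None:
    feats = torch.randn(4, 2 * 960, 8, 8)      # (B, steps * act, H, W)
    fold = Rearrange('b (step act) h w -> (b step) act h w', step=2)
    assert fold(feats).shape == (4 * 2, 960, 8, 8)  # (B * steps, act, H, W)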


def main(config: Namespace) -> None:
    # create the experiment folder
    os.makedirs(config.log_dir, exist_ok=True)
    print(f'Experiment folder: {config.log_dir}')

    # save the config namespace into the log dir
    with open(config.log_dir / 'config.txt', 'w') as f:
        for k, v in vars(config).items():
            f.write(f'{k}: {v}\n')

    # Random seed
    seed_everything(config.seed)

    model = ModDatasetDM(config)
    model = model.to(config.device)
    model.train()

    # optimize only the classifier head; the diffusion feature extractor is not updated
    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=config.lr, weight_decay=config.weight_decay)  # , betas=config.adam_betas)
    step = 0

    # loss scaler for mixed-precision training, passed to the shared train() loop
    scaler = GradScaler()

    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
        config.n_labelled_images
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']

    # Logger
    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)


    # One pass over the training set to estimate per-pixel feature statistics;
    # forward() uses these to normalize the extracted features.
    # Var[X] = E[X^2] - (E[X])^2, accumulated here as running sums.
    with torch.no_grad():  # the features are fixed, so skip building autograd graphs
        for x, _ in tqdm(train_dl, desc="Calculating mean and variance"):
            x = x.to(config.device)
            # extract_features is assumed to return CPU tensors (cf. the
            # .to(x.device) in forward), matching the CPU-resident accumulators
            features = model.extract_features(x)
            model.mean += features.sum(dim=0)
            model.mean_squared += (features ** 2).sum(dim=0)
    model.mean = model.mean / len(train_dl.dataset)
    # clamp before the sqrt to guard against small negative values caused by
    # floating-point cancellation; the 1e-6 keeps the later division stable
    model.std = (model.mean_squared / len(train_dl.dataset) - model.mean ** 2).clamp_min(0).sqrt() + 1e-6

    model.mean = model.mean.to(config.device)
    model.std = model.std.to(config.device)
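    # Worked example of the moment arithmetic above (made-up numbers): for
    # samples [1.0, 3.0], E[X] = 2.0 and E[X^2] = (1.0 + 9.0) / 2 = 5.0, so
    # Var[X] = 5.0 - 2.0 ** 2 = 1.0 and std = sqrt(1.0) + 1e-6.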
        
    train(config, model, optimizer, train_dl, val_dl, logger, scaler, step)