# trainers/datasetDM_per_step.py
from argparse import Namespace
import os
from pathlib import Path
from dataloaders.JSRT import build_dataloaders
import torch
from tqdm.auto import tqdm
from trainers.utils import seed_everything, TensorboardLogger
from torch.cuda.amp import GradScaler
from torch import Tensor, nn
from typing import Dict, Optional
from trainers.train_baseline import train
from models.datasetDM_model import DatasetDM
from einops import repeat
from einops.layers.torch import Rearrange


class ModDatasetDM(DatasetDM):
    # The idea here is to pool information per timestep, so that the
    # aggregate can then be used for feature importance.
    def __init__(self, args: Namespace) -> None:
        super().__init__(args)
        # Running feature statistics, filled in by main() before training.
        # Each diffusion step contributes 960 feature channels, so the
        # stacked feature map has len(self.steps) * 960 channels in total.
        self.mean = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        self.mean_squared = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        self.std = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        # A single 1x1 convolution acts as a linear per-pixel classifier.
        self.classifier = nn.Conv2d(len(self.steps) * 960, 1, 1)

    def forward(self, x: Tensor) -> Tensor:
        features = self.extract_features(x).to(x.device)
        # Normalise the features with the precomputed statistics, then classify.
        out = (features - self.mean) / self.std
        out = self.classifier(out)
        return out
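
# A minimal shape sanity check for ModDatasetDM (an illustrative sketch, not
# executed by this script; it assumes grayscale inputs and that
# extract_features returns a (B, len(steps) * 960, H, W) tensor):
#
#   model = ModDatasetDM(args)
#   x = torch.randn(2, 1, args.img_size, args.img_size)
#   logits = model(x)  # expected shape: (2, 1, args.img_size, args.img_size)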


class OneStepPredDatasetDM(DatasetDM):
    # The idea here is to predict a segmentation from every timestep
    # independently, sharing one classifier head across steps, so that the
    # per-step predictions can be compared for feature importance.
    def __init__(self, args: Namespace) -> None:
        super().__init__(args)
        # Running feature statistics, filled in by main() before training.
        self.mean = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        self.mean_squared = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        self.std = torch.zeros(len(self.steps) * 960, args.img_size, args.img_size, requires_grad=False)
        self.classifier = nn.Sequential(
            # Fold the step dimension into the batch so that the same head
            # is applied to the 960 activation channels of every timestep.
            Rearrange('b (step act) h w -> (b step) act h w', step=len(self.steps)),
            nn.Conv2d(960, 128, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 32, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            # Final 1x1 convolution down to the number of output channels.
            nn.Conv2d(32, args.out_channels, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        features = self.extract_features(x).to(x.device)
        # Normalise the features with the precomputed statistics, then classify.
        out = (features - self.mean) / self.std
        out = self.classifier(out)
        return out
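
# Illustrative sketch (not executed by this script): the shared head returns
# logits of shape ((B * steps), out_channels, H, W). To inspect each
# timestep's prediction separately, they can be regrouped with einops;
# `n_steps` below is a hypothetical local name for len(model.steps).
#
#   from einops import rearrange
#   logits = model(x)  # ((B * steps), C, H, W)
#   per_step = rearrange(logits, '(b step) c h w -> b step c h w', step=n_steps)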


def main(config: Namespace) -> None:
    # Create the experiment folder.
    os.makedirs(config.log_dir, exist_ok=True)
    print(f'Experiment folder: {config.log_dir}')

    # Save the config namespace into the log dir for reproducibility.
    with open(config.log_dir / 'config.txt', 'w') as f:
        for k, v in vars(config).items():
            f.write(f'{k}: {v}\n')
    # Random seed
    seed_everything(config.seed)

    model = ModDatasetDM(config)
    model = model.to(config.device)
    model.train()

    # Only the classifier head is optimised; the rest of the model is frozen.
    optimizer = torch.optim.Adam(
        model.classifier.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay,
    )  # betas are left at the torch defaults
    step = 0
    scaler = GradScaler()

    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
        config.n_labelled_images,
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']

    # Logger
    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)

    # One pass over the training data to estimate the mean and variance of
    # the extracted features; these statistics are used to normalise the
    # features inside forward().
    model.mean = model.mean.to(config.device)
    model.mean_squared = model.mean_squared.to(config.device)
    with torch.no_grad():
        for x, _ in tqdm(train_dl, desc="Calculating mean and variance"):
            x = x.to(config.device)
            features = model.extract_features(x)
            model.mean += features.sum(dim=0)
            model.mean_squared += (features ** 2).sum(dim=0)
    model.mean = model.mean / len(train_dl.dataset)
    # Var[X] = E[X^2] - E[X]^2; the epsilon guards against division by zero.
    model.std = (model.mean_squared / len(train_dl.dataset) - model.mean ** 2).sqrt() + 1e-6
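
    # Quick illustrative check of the running-statistics identity used above
    # (a standalone sketch, not executed by this script):
    #
    #   xs = torch.randn(1000, 4)
    #   mean = xs.sum(dim=0) / len(xs)
    #   var = (xs ** 2).sum(dim=0) / len(xs) - mean ** 2
    #   # matches the biased variance torch.var(xs, dim=0, unbiased=False)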

    train(config, model, optimizer, train_dl, val_dl, logger, scaler, step)
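

# A minimal command-line entry point, kept as a commented sketch: the
# argument names below (--log_dir, --data_dir, ...) are assumptions, and the
# real experiment config for this repo may be built elsewhere with more
# options (lr, weight_decay, device, seed, ...).
#
# if __name__ == '__main__':
#     from argparse import ArgumentParser
#     parser = ArgumentParser()
#     parser.add_argument('--log_dir', type=Path, default=Path('logs/datasetDM_per_step'))
#     parser.add_argument('--data_dir', type=Path, required=True)
#     main(parser.parse_args())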