Spaces:
Runtime error
Runtime error
from argparse import Namespace | |
import os | |
from pathlib import Path | |
from dataloaders.JSRT import build_dataloaders | |
import torch | |
from tqdm.auto import tqdm | |
from trainers.utils import seed_everything, TensorboardLogger | |
from torch.cuda.amp import GradScaler | |
from torch import Tensor, nn | |
from typing import Dict, Optional | |
from trainers.train_baseline import train | |
from models.datasetDM_model import DatasetDM | |
from einops import repeat | |
from einops.layers.torch import Rearrange | |
class ModDatasetDM(DatasetDM):
    """DatasetDM with a single 1x1-conv classifier over standardised features.

    The idea is to pool information per timestep so that the aggregate can
    later be used for feature importance.

    ``mean`` / ``std`` are meant to be filled in by the caller (e.g. the
    training script) before training. They start as 0 / 1 so that an
    un-fitted model applies an identity normalisation instead of dividing
    by zero.
    """

    def __init__(self, args: Namespace) -> None:
        """Build the model.

        Args:
            args: Configuration forwarded to ``DatasetDM``; ``args.img_size``
                sets the spatial size of the statistic tensors.
        """
        super().__init__(args)
        # 960 feature channels per diffusion step (hard-coded here);
        # presumably matches what extract_features returns -- TODO confirm.
        n_channels = len(self.steps) * 960
        shape = (n_channels, args.img_size, args.img_size)
        self.mean = torch.zeros(*shape, requires_grad=False)
        self.mean_squared = torch.zeros(*shape, requires_grad=False)
        # Ones (not zeros): avoids inf/NaN if forward runs before fitting.
        self.std = torch.ones(*shape, requires_grad=False)
        self.classifier = nn.Conv2d(n_channels, 1, 1)

    def forward(self, x: Tensor) -> Tensor:
        """Standardise the extracted features and classify every pixel.

        Args:
            x: Input batch passed to ``extract_features``.

        Returns:
            The classifier output for the normalised features (one channel).
        """
        features = self.extract_features(x).to(x.device)
        normalised = (features - self.mean) / self.std
        # Feed the *normalised* features to the head; the previous version
        # computed them and then classified the raw features instead.
        return self.classifier(normalised)
class OneStepPredDatasetDM(DatasetDM):
    """DatasetDM variant whose head makes a prediction per diffusion step.

    The idea is to pool information per timestep so that the aggregate can
    later be used for feature importance. The classifier's first layer folds
    the step axis into the batch axis, so one shared head sees each step's
    960-channel slice separately.

    NOTE(review): the output therefore has ``batch * len(self.steps)`` rows;
    callers must un-fold it before comparing with per-image targets.

    ``mean`` / ``std`` are meant to be filled in by the caller before
    training; they start as 0 / 1 so an un-fitted model normalises as the
    identity instead of dividing by zero.
    """

    def __init__(self, args: Namespace) -> None:
        """Build the model.

        Args:
            args: Configuration forwarded to ``DatasetDM``; ``args.img_size``
                sets the statistic tensor size and ``args.out_channels`` the
                number of output channels of the head.
        """
        super().__init__(args)
        n_steps = len(self.steps)
        # 960 feature channels per diffusion step (hard-coded here);
        # presumably matches what extract_features returns -- TODO confirm.
        n_channels = n_steps * 960
        shape = (n_channels, args.img_size, args.img_size)
        self.mean = torch.zeros(*shape, requires_grad=False)
        self.mean_squared = torch.zeros(*shape, requires_grad=False)
        # Ones (not zeros): avoids inf/NaN if forward runs before fitting.
        self.std = torch.ones(*shape, requires_grad=False)
        self.classifier = nn.Sequential(
            # (b, step*960, h, w) -> (b*step, 960, h, w): share the head across steps.
            Rearrange('b (step act) h w -> (b step) act h w', step=n_steps),
            nn.Conv2d(960, 128, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 32, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            # 1x1 projection to out_channels. Conv2d's positional order is
            # (in_channels, out_channels, kernel_size); out_channels was
            # previously passed as the kernel size.
            nn.Conv2d(32, args.out_channels, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Standardise the extracted features and run the per-step head.

        Args:
            x: Input batch passed to ``extract_features``.

        Returns:
            Head output for the normalised features, with the step axis
            folded into the batch axis (see class docstring).
        """
        features = self.extract_features(x).to(x.device)
        normalised = (features - self.mean) / self.std
        # Classify the normalised features (previously they were discarded).
        return self.classifier(normalised)
def _estimate_feature_stats(model: nn.Module, train_dl, device) -> None:
    """Set ``model.mean``, ``model.mean_squared`` and ``model.std`` from data.

    Statistics are element-wise over the extracted feature tensor, reduced
    over every sample the loader yields, and left on ``device``.

    Raises:
        ValueError: If ``train_dl`` yields no samples.
    """
    feat_sum = torch.zeros(model.mean.shape, device=device)
    feat_sq_sum = torch.zeros(model.mean_squared.shape, device=device)
    n_samples = 0
    # Statistics need no autograd graph; without no_grad the in-place sums
    # could keep extending a graph across the whole epoch.
    with torch.no_grad():
        for x, _ in tqdm(train_dl, desc="Calculating mean and variance"):
            x = x.to(device)
            features = model.extract_features(x).to(device)
            feat_sum += features.sum(dim=0)
            feat_sq_sum += (features ** 2).sum(dim=0)
            # Count samples actually seen (len(dataset) is wrong with
            # drop_last or subset samplers).
            n_samples += features.shape[0]
    if n_samples == 0:
        raise ValueError('training dataloader yielded no samples; '
                         'cannot estimate feature statistics')
    mean = feat_sum.div_(n_samples)
    mean_sq = feat_sq_sum.div_(n_samples)
    # E[x^2] - E[x]^2 can round slightly below zero; clamp so sqrt is not NaN.
    std = (mean_sq - mean ** 2).clamp_min_(0).sqrt_().add_(1e-6)
    model.mean = mean
    model.mean_squared = mean_sq  # mean of squared features
    model.std = std


def main(config: Namespace) -> None:
    """Train the classifier head of ``ModDatasetDM`` on normalised features.

    Creates the log directory, dumps ``config`` to ``config.txt``, seeds the
    RNGs, builds the model and an Adam optimiser over ``model.classifier``
    only, builds the JSRT dataloaders, estimates feature mean/std over the
    training loader, and then hands everything to ``train``.

    Args:
        config: Experiment configuration. Fields read here: ``log_dir``,
            ``seed``, ``device``, ``lr``, ``weight_decay``, ``data_dir``,
            ``img_size``, ``batch_size``, ``num_workers``,
            ``n_labelled_images`` and ``debug`` (plus whatever the model and
            ``train`` read).

    Raises:
        ValueError: If the training loader yields no samples.
    """
    # Accept str as well as Path; the "/" join below needs a Path.
    log_dir = Path(config.log_dir)
    os.makedirs(log_dir, exist_ok=True)
    print(f'Experiment folder: {config.log_dir}')
    # Save the config namespace into the log dir, one "key: value" per line.
    with open(log_dir / 'config.txt', 'w', encoding='utf-8') as f:
        f.writelines(f'{k}: {v!s}\n' for k, v in vars(config).items())

    seed_everything(config.seed)

    model = ModDatasetDM(config).to(config.device)
    model.train()
    # Only the classifier head is optimised.
    optimizer = torch.optim.Adam(model.classifier.parameters(),
                                 lr=config.lr, weight_decay=config.weight_decay)
    step = 0
    scaler = GradScaler()

    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
        config.n_labelled_images,
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']

    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)

    # Fit the normalisation statistics the model's forward pass relies on.
    _estimate_feature_stats(model, train_dl, config.device)

    train(config, model, optimizer, train_dl, val_dl, logger, scaler, step)