TEDM-demo / trainers /train_datasetDM.py
anonymous
first commit without models
a2dba58
from argparse import Namespace
import os
from pathlib import Path
from dataloaders.JSRT import build_dataloaders
import torch
from trainers.utils import seed_everything, TensorboardLogger
from torch.cuda.amp import GradScaler
from trainers.train_baseline import train
from models.datasetDM_model import DatasetDM
def _save_config(config: Namespace, log_dir: Path) -> None:
    """Dump every attribute of ``config`` to ``log_dir / 'config.txt'``.

    One ``key: value`` line is written per attribute, in the order returned
    by ``vars(config)``. Values are rendered with ``str()``, so non-primitive
    values (paths, lists, ...) are stored as their string representation.
    """
    lines = [f'{key}: {str(value)}\n' for key, value in vars(config).items()]
    with open(log_dir / 'config.txt', 'w', encoding='utf-8') as f:
        f.writelines(lines)


def _build_shared_classifier(n_steps: int, out_channels: int) -> torch.nn.Sequential:
    """Build a classifier head whose weights are shared across timesteps.

    The leading ``Rearrange`` splits the channel axis into ``n_steps`` groups
    and folds the group axis into the batch axis, so the same 1x1-conv stack
    is applied to every group. Each group must hold 960 channels, which is
    the ``in_channels`` of the first convolution.

    Args:
        n_steps: Number of groups the channel axis is split into.
        out_channels: Number of channels produced by the final convolution.
    """
    # Imported lazily: only needed when the shared-weights option is enabled.
    import torch.nn as nn
    from einops.layers.torch import Rearrange

    return nn.Sequential(
        Rearrange('b (step act) h w -> (b step) act h w', step=n_steps),
        nn.Conv2d(960, 128, 1),
        nn.ReLU(),
        nn.BatchNorm2d(128),
        nn.Conv2d(128, 32, 1),
        nn.ReLU(),
        nn.BatchNorm2d(32),
        # nn.Conv2d takes (in_channels, out_channels, kernel_size): produce
        # `out_channels` maps with a 1x1 kernel so the spatial size is kept.
        nn.Conv2d(32, out_channels, kernel_size=1),
    )


def main(config: Namespace) -> None:
    """Set up and launch training of the DatasetDM classifier head.

    Creates the experiment folder, stores the configuration next to the
    logs, seeds the RNGs, builds the model, optimizer, data loaders and
    logger, and hands everything to ``train``.

    Args:
        config: Experiment settings. This function reads ``log_dir`` (``str``
            or ``Path``), ``seed``, ``shared_weights_over_timesteps``,
            ``out_channels``, ``device``, ``lr``, ``weight_decay``,
            ``data_dir``, ``img_size``, ``batch_size``, ``num_workers``,
            ``n_labelled_images`` and ``debug``. The whole namespace is also
            forwarded to ``DatasetDM`` and ``train``.
    """
    # Normalise once so a plain-string log_dir works with the `/` join below.
    log_dir = Path(config.log_dir)
    os.makedirs(log_dir, exist_ok=True)
    print('Experiment folder: %s' % (config.log_dir))

    # Keep a copy of the configuration inside the experiment folder.
    _save_config(config, log_dir)

    # Random seed
    seed_everything(config.seed)

    model = DatasetDM(config)
    if config.shared_weights_over_timesteps:
        # Swap in a head that reuses the same weights for every timestep.
        model.classifier = _build_shared_classifier(
            len(model.steps), config.out_channels
        )
    model = model.to(config.device)
    model.train()

    # Only the classifier parameters are handed to the optimizer.
    optimizer = torch.optim.Adam(
        model.classifier.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay,
    )
    step = 0
    scaler = GradScaler()

    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
        config.n_labelled_images,
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']

    # Logger
    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)

    train(config, model, optimizer, train_dl, val_dl, logger, scaler, step)