# TEDM-demo/trainers/train_base_diffusion.py
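"""Training script for the base diffusion model.

`config.experiment` selects between unconditional image modelling ("img_only"),
joint image/segmentation modelling ("joint"), segmentation prediction
conditioned on the image ("conditional"), and a combined "joint_and_cond"
setup that reports separate image and segmentation losses.
"""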
import os
import torch
from torch import autocast
from torch.cuda.amp import GradScaler
from tqdm import tqdm
from dataloaders.JSRT import build_dataloaders as build_dataloaders_JSRT
from models.diffusion_model import DiffusionModel
from trainers.utils import (TensorboardLogger, compare_configs, sample_plot_image,
seed_everything)
def train(config, model, optimizer, train_loader, val_loader, logger, scaler, step):
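    """Train until `config.max_steps`, validating every `config.val_freq` steps
    and checkpointing the best model by validation loss."""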
best_val_loss = float('inf')
train_losses = []
if config.joint_training and config.experiment == 'joint_and_cond':
img_losses = []
seg_losses = []
pbar = tqdm(total=config.val_freq, desc='Training')
cond = None
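    # iterate over the dataloader indefinitely; training ends via config.max_steps below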
while True:
for x, y in train_loader:
pbar.update(1)
step += 1
            if config.experiment == "joint":
                # model image and segmentation jointly as one multi-channel tensor
                x = torch.cat([x, y], dim=1)
            elif config.experiment == "conditional":
                cond = x.to(config.device)  # condition on the image
                x = y  # predict the segmentation
            elif config.experiment == "joint_and_cond":
                cond = y.to(config.device)
            x = x.to(config.device)
# Forward pass
optimizer.zero_grad()
with autocast(device_type=config.device, enabled=config.mixed_precision):
loss = model.train_step(x, cond)
if config.joint_training and config.experiment == 'joint_and_cond':
img_loss, seg_loss = loss
loss = img_loss + seg_loss
img_losses.append(img_loss.item())
seg_losses.append(seg_loss.item())
            scaler.scale(loss).backward()
            # step the optimizer through the scaler so gradients are unscaled
            # first (and inf/nan steps skipped), then update the scale factor
            scaler.step(optimizer)
            scaler.update()
            train_losses.append(loss.item())
if config.joint_training and config.experiment == 'joint_and_cond':
pbar.set_description(f'Training loss: {loss.item():.4f} (img: {img_loss.item():.4f}, seg: {seg_loss.item():.4f})')
else:
pbar.set_description(f'Training loss: {loss.item():.4f}')
            if step % config.log_freq == 0 or config.debug:
                avg_train_loss = sum(train_losses) / len(train_losses)
                print(f'Step {step} - Train loss: {avg_train_loss:.4f}')
                logger.log({'train/loss': avg_train_loss}, step=step)
                train_losses.clear()  # average over the current log window only
                if config.joint_training and config.experiment == 'joint_and_cond':
                    avg_img_loss = sum(img_losses) / len(img_losses)
                    avg_seg_loss = sum(seg_losses) / len(seg_losses)
                    logger.log({'train_loss/img': avg_img_loss}, step=step)
                    logger.log({'train_loss/seg': avg_seg_loss}, step=step)
                    img_losses.clear()
                    seg_losses.clear()
if step % config.val_freq == 0 or config.debug:
val_results = validate(config, model, val_loader)
logger.log(val_results, step=step)
if val_results['val/loss'] < best_val_loss and not config.debug:
print(f'Step {step} - New best validation loss: '
f'{val_results["val/loss"]:.4f}, saving model '
f'in {config.log_dir}')
best_val_loss = val_results['val/loss']
                    save(
                        model,
                        optimizer,
                        config,
                        config.log_dir / 'best_model.pt',
                        step,
                    )
                pbar.reset()  # the bar counts steps until the next validation
if step >= config.max_steps or config.debug:
return model
@torch.no_grad()
def validate(config, model, val_loader):
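    """Evaluate `model` on `val_loader` and sample images for visualisation.

    Returns a dict of scalars and images suitable for `logger.log`.
    """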
model.eval()
losses = []
if config.joint_training and config.experiment == 'joint_and_cond':
img_losses = []
seg_losses = []
    cond = None
    for i, (x, y) in enumerate(tqdm(val_loader, desc='Validating')):
        if config.experiment == "joint":
            x = torch.cat([x, y], dim=1)
        elif config.experiment == "conditional":
            cond = x.to(config.device)  # condition on the image
            x = y  # predict the segmentation
        elif config.experiment == "joint_and_cond":
            cond = y.to(config.device)
        x = x.to(config.device)
with autocast(device_type=config.device, enabled=config.mixed_precision):
            if len(val_loader.dataset) > 1000:
                # the val set is too large to evaluate at every timestep,
                # so fall back to random timesteps as in training
                loss = model.train_step(x, cond)
else:
loss = model.val_step(x, cond, t_steps=config.val_steps)
if config.joint_training and config.experiment == 'joint_and_cond':
img_loss, seg_loss = loss
loss = img_loss + seg_loss
img_losses.append(img_loss.item())
seg_losses.append(seg_loss.item())
losses.append(loss.item())
if i + 1 == config.max_val_steps or config.debug:
break
avg_loss = sum(losses) / len(losses)
if config.joint_training and config.experiment == 'joint_and_cond':
avg_img_loss = sum(img_losses) / len(img_losses)
avg_seg_loss = sum(seg_losses) / len(seg_losses)
print(f'Validation loss: {avg_loss:.4f} (img: {avg_img_loss:.4f}, seg: {avg_seg_loss:.4f})')
else:
print(f'Validation loss: {avg_loss:.4f}')
# Build visualisations
# select images for visualisation
if config.experiment == "conditional":
        # Note that there is no randomness in how the images are selected:
        # if the last validation batch holds enough images, we keep those;
        # otherwise, we top up with the first images of a fresh validation epoch.
        config.n_sampled_imgs = config.n_sampled_imgs if not config.debug else 1
        N_missing = config.n_sampled_imgs - cond.shape[0]
        iter_val_loader = iter(val_loader)  # new validation epoch
        while N_missing > 0:  # are we missing images?
            new_cond, _ = next(iter_val_loader)
            # move the fresh batch to `cond`'s device before concatenating
            cond = torch.cat([cond, new_cond.to(config.device)], dim=0)
N_missing = config.n_sampled_imgs - cond.shape[0]
del iter_val_loader
if N_missing < 0: # do we have too many images?
cond = cond[:config.n_sampled_imgs]
with autocast(device_type=config.device, enabled=config.mixed_precision):
sampled_imgs = sample_plot_image(
model,
config.timesteps,
config.img_size,
            config.n_sampled_imgs,  # defaults to 1 in debug mode
            # out_channels is the number of segmentation channels
            config.channels if config.experiment != "joint_and_cond"
            else config.channels + config.out_channels,
cond=cond,
normalized=config.normalize,
)
model.train()
return {
'val/loss': avg_loss,
'val/sampled images': sampled_imgs
}
def save(model, optimizer, config, path, step):
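    """Checkpoint model and optimizer state together with the config and step."""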
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'config': config,
'step': step
}, path)
def load(new_config, path):
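    """Restore a checkpoint written by `save`, comparing the stored config
    against `new_config` via `compare_configs`."""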
checkpoint = torch.load(path, map_location=torch.device(new_config.device))
old_config = checkpoint['config']
compare_configs(old_config, new_config)
model = DiffusionModel(old_config)
model.to(new_config.device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer = torch.optim.Adam(model.parameters(), lr=new_config.lr)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
step = checkpoint['step']
return model, optimizer, step
def main(config):
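    """Set up logging, seeding, model, optimizer, and data, then train."""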
    # create the log directory
    os.makedirs(config.log_dir, exist_ok=True)
    # save the config namespace into the log dir, one `key: value` per line
    with open(config.log_dir / 'config.txt', 'w') as f:
        for k, v in vars(config).items():
            f.write(f'{k}: {v}\n')
# Random seed
seed_everything(config.seed)
# Init model and optimizer
if config.resume_path is not None:
print('Loading model from', config.resume_path)
diffusion_model, optimizer, step = load(config, config.resume_path)
else:
        if config.experiment in ["img_only", "joint", "conditional", "joint_and_cond"]:
diffusion_model = DiffusionModel(config)
else:
raise ValueError(f"Unknown experiment: {config.experiment}")
        optimizer = torch.optim.Adam(
            diffusion_model.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay,
            # betas=config.adam_betas,
        )
step = 0
diffusion_model.to(config.device)
diffusion_model.train()
    # the scaler is a no-op when mixed precision is disabled
    scaler = GradScaler(enabled=config.mixed_precision)
# Load data
if config.dataset == "JSRT":
build_dataloaders = build_dataloaders_JSRT
else:
raise ValueError(f"Unknown dataset: {config.dataset}")
dataloaders = build_dataloaders(
config.data_dir,
config.img_size,
config.batch_size,
config.num_workers,
config.n_labelled_images,
)
train_dl = dataloaders['train']
val_dl = dataloaders['val']
# Logger
logger = TensorboardLogger(config.log_dir, enabled=not config.debug)
train(config, diffusion_model, optimizer, train_dl, val_dl, logger, scaler, step)
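

if __name__ == '__main__':
    # Minimal sketch of an entry point. The repo presumably builds `config`
    # elsewhere (e.g. from argparse); every value below is an illustrative
    # assumption, covering exactly the attributes this file reads.
    from argparse import Namespace
    from pathlib import Path

    example_config = Namespace(
        # experiment / data
        experiment='conditional',    # img_only | joint | conditional | joint_and_cond
        joint_training=False,
        dataset='JSRT',
        data_dir=Path('data/JSRT'),  # hypothetical dataset location
        n_labelled_images=None,      # assumption: None means "use all labels"
        img_size=128,
        channels=1,
        out_channels=1,              # number of segmentation channels
        # optimisation
        lr=1e-4,
        weight_decay=0.0,
        batch_size=16,
        num_workers=4,
        max_steps=100_000,
        mixed_precision=True,
        # diffusion / sampling
        timesteps=1000,
        val_steps=10,                # timesteps evaluated by model.val_step
        n_sampled_imgs=4,
        normalize=True,
        # logging / checkpointing
        log_dir=Path('logs/base_diffusion'),
        log_freq=100,
        val_freq=1000,
        max_val_steps=50,
        resume_path=None,
        seed=42,
        debug=False,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )
    main(example_config)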