import os

import torch
from torch import autocast
from torch.cuda.amp import GradScaler
from tqdm import tqdm

from dataloaders.JSRT import build_dataloaders as build_dataloaders_JSRT
from models.diffusion_model import DiffusionModel
from trainers.utils import (TensorboardLogger, compare_configs,
                            sample_plot_image, seed_everything)


def train(config, model, optimizer, train_loader, val_loader, logger, scaler,
          step):
    best_val_loss = float('inf')
    train_losses = []
    if config.joint_training and config.experiment == 'joint_and_cond':
        img_losses = []
        seg_losses = []
    pbar = tqdm(total=config.val_freq, desc='Training')
    cond = None
    while True:
        for x, y in train_loader:
            pbar.update(1)
            step += 1

            if config.experiment == "joint":
                x = torch.cat([x, y], dim=1)
            elif config.experiment == "conditional":
                cond = x.to(config.device)  # condition on the image
                x = y  # predict the segmentation
            elif config.experiment == "joint_and_cond":
                cond = y.to(config.device)
                x = x.to(config.device)

            # Forward pass
            optimizer.zero_grad()
            with autocast(device_type=config.device,
                          enabled=config.mixed_precision):
                loss = model.train_step(x, cond)
            if config.joint_training and config.experiment == 'joint_and_cond':
                img_loss, seg_loss = loss
                loss = img_loss + seg_loss
                img_losses.append(img_loss.item())
                seg_losses.append(seg_loss.item())

            # Backward pass: step through the GradScaler so gradients are
            # unscaled before the optimizer update (a no-op when AMP is off)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss.item())
            if config.joint_training and config.experiment == 'joint_and_cond':
                pbar.set_description(
                    f'Training loss: {loss.item():.4f} '
                    f'(img: {img_loss.item():.4f}, seg: {seg_loss.item():.4f})')
            else:
                pbar.set_description(f'Training loss: {loss.item():.4f}')

            if step % config.log_freq == 0 or config.debug:
                avg_train_loss = sum(train_losses) / len(train_losses)
                print(f'Step {step} - Train loss: {avg_train_loss:.4f}')
                logger.log({'train/loss': avg_train_loss}, step=step)
                if config.joint_training and config.experiment == 'joint_and_cond':
                    avg_img_loss = sum(img_losses) / len(img_losses)
                    avg_seg_loss = sum(seg_losses) / len(seg_losses)
                    logger.log({'train_loss/img': avg_img_loss}, step=step)
                    logger.log({'train_loss/seg': avg_seg_loss}, step=step)
                    img_losses = []
                    seg_losses = []
                train_losses = []  # reset the logging window

            if step % config.val_freq == 0 or config.debug:
                val_results = validate(config, model, val_loader)
                logger.log(val_results, step=step)
                if val_results['val/loss'] < best_val_loss and not config.debug:
                    print(f'Step {step} - New best validation loss: '
                          f'{val_results["val/loss"]:.4f}, saving model '
                          f'in {config.log_dir}')
                    best_val_loss = val_results['val/loss']
                    save(model, optimizer, config,
                         config.log_dir / 'best_model.pt', step)
                pbar.reset()  # the bar counts steps until the next validation

            if step >= config.max_steps or config.debug:
                return model
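
# Summary of the experiment modes, as inferred from the branches in train()
# above and validate() below (documentation only; the mode names are the ones
# accepted by main() at the bottom of this file):
#   "img_only":        x = image, cond = None (unconditional image diffusion)
#   "joint":           x = image and segmentation concatenated along the
#                      channel axis, cond = None
#   "conditional":     x = segmentation, cond = image
#   "joint_and_cond":  x = image, cond = segmentation; with
#                      config.joint_training the model returns a separate
#                      (img_loss, seg_loss) pair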
@torch.no_grad()
def validate(config, model, val_loader):
    model.eval()
    losses = []
    if config.joint_training and config.experiment == 'joint_and_cond':
        img_losses = []
        seg_losses = []
    cond = None
    for i, (x, y) in tqdm(enumerate(val_loader), desc='Validating',
                          total=len(val_loader)):
        if config.experiment == "joint":
            x = torch.cat([x, y], dim=1)
        elif config.experiment == "conditional":
            cond = x.to(config.device)  # condition on the image
            x = y  # predict the segmentation
        elif config.experiment == "joint_and_cond":
            cond = y.to(config.device)
            x = x.to(config.device)

        with autocast(device_type=config.device,
                      enabled=config.mixed_precision):
            if len(val_loader.dataset) > 1000:
                # val set is too large to evaluate every timestep;
                # fall back to random timesteps as in training
                loss = model.train_step(x, cond)
            else:
                loss = model.val_step(x, cond, t_steps=config.val_steps)
        if config.joint_training and config.experiment == 'joint_and_cond':
            img_loss, seg_loss = loss
            loss = img_loss + seg_loss
            img_losses.append(img_loss.item())
            seg_losses.append(seg_loss.item())
        losses.append(loss.item())

        if i + 1 == config.max_val_steps or config.debug:
            break

    avg_loss = sum(losses) / len(losses)
    if config.joint_training and config.experiment == 'joint_and_cond':
        avg_img_loss = sum(img_losses) / len(img_losses)
        avg_seg_loss = sum(seg_losses) / len(seg_losses)
        print(f'Validation loss: {avg_loss:.4f} '
              f'(img: {avg_img_loss:.4f}, seg: {avg_seg_loss:.4f})')
    else:
        print(f'Validation loss: {avg_loss:.4f}')

    # Build visualisations
    if config.debug:
        config.n_sampled_imgs = 1  # sample fewer images in debug mode

    # Select images to condition on. Note that there is no randomness in how
    # they are selected: if the last validation batch holds enough images we
    # keep those; otherwise we take the first images from a fresh pass over
    # the validation loader.
    if config.experiment == "conditional":
        N_missing = config.n_sampled_imgs - cond.shape[0]
        iter_val_loader = iter(val_loader)  # new validation epoch
        while N_missing > 0:  # are we missing images?
            new_cond, _ = next(iter_val_loader)
            cond = torch.cat([cond, new_cond], dim=0)
            N_missing = config.n_sampled_imgs - cond.shape[0]
        del iter_val_loader
        if N_missing < 0:  # do we have too many images?
            cond = cond[:config.n_sampled_imgs]

    with autocast(device_type=config.device, enabled=config.mixed_precision):
        sampled_imgs = sample_plot_image(
            model,
            config.timesteps,
            config.img_size,
            config.n_sampled_imgs,  # set to 1 in debug mode above
            # out_channels stands for the channels of the segmentation
            config.channels if config.experiment != "joint_and_cond"
            else config.channels + config.out_channels,
            cond=cond,
            normalized=config.normalize,
        )

    model.train()
    return {
        'val/loss': avg_loss,
        'val/sampled images': sampled_imgs
    }


def save(model, optimizer, config, path, step):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config,
        'step': step
    }, path)


def load(new_config, path):
    checkpoint = torch.load(path, map_location=torch.device(new_config.device))
    old_config = checkpoint['config']
    compare_configs(old_config, new_config)
    # rebuild the model with the architecture it was trained with
    model = DiffusionModel(old_config)
    model.to(new_config.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer = torch.optim.Adam(model.parameters(), lr=new_config.lr)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    step = checkpoint['step']
    return model, optimizer, step
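
# Illustrative sketch of the save()/load() round-trip used for resuming
# training (variable names mirror main() below; this block is a comment, not
# executed code):
#
#   save(model, optimizer, config, config.log_dir / 'best_model.pt', step)
#   ...
#   model, optimizer, step = load(config, config.log_dir / 'best_model.pt')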
def main(config):
    # create the log directory and persist the config namespace into it
    os.makedirs(config.log_dir, exist_ok=True)
    with open(config.log_dir / 'config.txt', 'w') as f:
        for k, v in vars(config).items():
            f.write(f'{k}: {v}\n')

    # Random seed
    seed_everything(config.seed)

    # Init model and optimizer
    if config.resume_path is not None:
        print('Loading model from', config.resume_path)
        diffusion_model, optimizer, step = load(config, config.resume_path)
    else:
        if config.experiment in ["img_only", "joint", "conditional",
                                 "joint_and_cond"]:
            diffusion_model = DiffusionModel(config)
        else:
            raise ValueError(f"Unknown experiment: {config.experiment}")
        optimizer = torch.optim.Adam(diffusion_model.parameters(),
                                     lr=config.lr,
                                     weight_decay=config.weight_decay)  # , betas=config.adam_betas)
        step = 0
    diffusion_model.to(config.device)
    diffusion_model.train()
    scaler = GradScaler(enabled=config.mixed_precision)

    # Load data
    if config.dataset == "JSRT":
        build_dataloaders = build_dataloaders_JSRT
    else:
        raise ValueError(f"Unknown dataset: {config.dataset}")
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
        config.n_labelled_images,
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']

    # Logger
    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)

    train(config, diffusion_model, optimizer, train_dl, val_dl, logger, scaler,
          step)
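
# Minimal entry-point sketch: this module defines main(config) but no CLI, so
# the config namespace is built elsewhere in the project. A hypothetical
# wiring (get_config is a stand-in, not part of this file):
#
#   if __name__ == '__main__':
#       from configs import get_config  # hypothetical helper
#       main(get_config())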