Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| from pathlib import Path | |
| import torch | |
| from torch import autocast | |
| from torch.cuda.amp import GradScaler | |
| from tqdm import tqdm | |
| from config import parser | |
| from dataloaders.CXR14 import build_dataloaders | |
| from models.diffusion_model import DiffusionModel | |
| from trainers.utils import (TensorboardLogger, compare_configs, sample_plot_image, | |
| seed_everything) | |
def train(config, model, optimizer, train_loader, val_loader, logger, scaler, step):
    """Run the training loop until ``config.max_steps`` steps have completed.

    Validates every ``config.val_freq`` steps and checkpoints the model to
    ``config.log_dir/best_model.pt`` whenever validation loss improves.

    Args:
        config: Namespace with device, mixed_precision, log_freq, val_freq,
            max_steps, debug, and log_dir attributes.
        model: Diffusion model exposing ``train_step(x) -> loss``.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader / val_loader: Iterables yielding input batches.
        logger: Logger with ``log(dict, step=...)``.
        scaler: ``GradScaler`` for mixed-precision training.
        step: Step counter to resume from (0 for a fresh run).

    Returns:
        The trained model.
    """
    best_val_loss = float('inf')
    train_losses = []
    pbar = tqdm(total=config.val_freq, desc='Training')
    while True:
        for x in train_loader:
            pbar.update(1)
            step += 1
            x = x.to(config.device)

            # Forward pass
            optimizer.zero_grad()
            with autocast(device_type=config.device, enabled=config.mixed_precision):
                loss = model.train_step(x)
            scaler.scale(loss).backward()
            # BUG FIX: the optimizer must be stepped *through* the scaler so
            # gradients are unscaled first (and skipped on inf/NaN), and the
            # scale factor must be updated every iteration. Calling
            # optimizer.step() directly applied scaled gradients.
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss.item())
            pbar.set_description(f'Training loss: {loss.item():.4f}')

            if step % config.log_freq == 0 or config.debug:
                avg_train_loss = sum(train_losses) / len(train_losses)
                print(f'Step {step} - Train loss: {avg_train_loss:.4f}')
                logger.log({'train/loss': avg_train_loss}, step=step)
                # BUG FIX: reset the window so each log reflects only the last
                # log_freq steps instead of a running average over all training.
                train_losses = []

            if step % config.val_freq == 0 or config.debug:
                val_results = validate(config, model, val_loader)
                logger.log(val_results, step=step)
                if val_results['val/loss'] < best_val_loss and not config.debug:
                    print(f'Step {step} - New best validation loss: '
                          f'{val_results["val/loss"]:.4f}, saving model '
                          f'in {config.log_dir}')
                    best_val_loss = val_results['val/loss']
                    # BUG FIX: main() turns config.log_dir into a pathlib.Path,
                    # and Path + str raises TypeError; os.path.join accepts both.
                    save(
                        model,
                        optimizer,
                        config,
                        os.path.join(config.log_dir, 'best_model.pt'),
                        step
                    )

            if step >= config.max_steps or config.debug:
                return model
def validate(config, model, val_loader):
    """Compute validation loss over up to ``config.max_val_steps`` batches and
    sample example images.

    Puts the model in eval mode for the duration and restores train mode
    before returning.

    Returns:
        Dict with 'val/loss' (float) and 'val/sampled images' (whatever
        ``sample_plot_image`` returns — presumably an image grid; confirm
        against its definition).
    """
    model.eval()
    losses = []
    # BUG FIX: evaluation must not build autograd graphs — without no_grad()
    # every forward pass retains activations for backward, wasting memory.
    with torch.no_grad():
        for i, x in tqdm(enumerate(val_loader), desc='Validating'):
            x = x.to(config.device)
            with autocast(device_type=config.device, enabled=config.mixed_precision):
                loss = model.train_step(x)
            losses.append(loss.item())
            if i + 1 == config.max_val_steps or config.debug:
                break
        avg_loss = sum(losses) / len(losses)
        print(f'Validation loss: {avg_loss:.4f}')

        # Sample images for qualitative inspection (only 1 image in debug mode).
        with autocast(device_type=config.device, enabled=config.mixed_precision):
            sampled_imgs = sample_plot_image(
                model,
                config.timesteps,
                config.img_size,
                config.n_sampled_imgs if not config.debug else 1,
                normalized=config.normalize,
            )
    model.train()
    return {
        'val/loss': avg_loss,
        'val/sampled images': sampled_imgs
    }
def save(model, optimizer, config, path, step):
    """Write a training checkpoint (model/optimizer state, config, step) to ``path``."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config,
        'step': step,
    }
    torch.save(checkpoint, path)
def load(new_config, path):
    """Restore a checkpoint written by ``save``.

    Warns about config mismatches via ``compare_configs``, rebuilds the model
    from the *old* config, and recreates an Adam optimizer with the *new*
    learning rate before restoring its state.

    Returns:
        Tuple of (model, optimizer, step).
    """
    device = torch.device(new_config.device)
    checkpoint = torch.load(path, map_location=device)

    old_config = checkpoint['config']
    compare_configs(old_config, new_config)

    model = DiffusionModel(old_config).to(new_config.device)
    model.load_state_dict(checkpoint['model_state_dict'])

    optimizer = torch.optim.Adam(model.parameters(), lr=new_config.lr)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, optimizer, checkpoint['step']
def main(config):
    """Set up the experiment (log dir, config dump, seed, model, data) and train.

    Args:
        config: Parsed argument namespace (see ``config.parser``).
    """
    # Nest the experiment under the dataset name: <parent>/CXR14/<name>
    config.log_dir = Path(config.log_dir).parent / "CXR14" / Path(config.log_dir).name
    os.makedirs(config.log_dir, exist_ok=True)

    # Dump the config namespace into the log dir for reproducibility.
    # (f-strings already call str() on any value, so the original
    # type-dependent branching was redundant and has been collapsed.)
    with open(config.log_dir / 'config.txt', 'w') as f:
        for k, v in vars(config).items():
            f.write(f'{k}: {v}\n')

    # Random seed
    seed_everything(config.seed)

    # Init model and optimizer, resuming from a checkpoint if requested
    if config.resume_path is not None:
        print('Loading model from', config.resume_path)
        diffusion_model, optimizer, step = load(config, config.resume_path)
    else:
        diffusion_model = DiffusionModel(config)
        optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=config.lr)  # , betas=config.adam_betas)
        step = 0
    diffusion_model.to(config.device)
    diffusion_model.train()
    scaler = GradScaler()

    # Load data
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']

    # Logger (disabled entirely in debug mode)
    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)

    train(config, diffusion_model, optimizer, train_dl, val_dl, logger, scaler, step)