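"""Training script for a diffusion model on the CXR14 chest X-ray dataset.

Covers the training loop with mixed precision, periodic validation with image
sampling, checkpoint saving/loading, and resuming from a checkpoint.
"""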
import argparse
import os
from pathlib import Path

import torch
from torch import autocast
from torch.cuda.amp import GradScaler
from tqdm import tqdm

from config import parser
from dataloaders.CXR14 import build_dataloaders
from models.diffusion_model import DiffusionModel
from trainers.utils import (TensorboardLogger, compare_configs, sample_plot_image,
                            seed_everything)


def train(config, model, optimizer, train_loader, val_loader, logger, scaler, step):
    best_val_loss = float('inf')
    train_losses = []
    pbar = tqdm(total=config.val_freq, desc='Training')
    while True:
        for x in train_loader:
            pbar.update(1)
            step += 1
            x = x.to(config.device)

            # Forward pass under autocast; backward and optimizer step outside it
            optimizer.zero_grad()
            with autocast(device_type=config.device, enabled=config.mixed_precision):
                loss = model.train_step(x)
            scaler.scale(loss).backward()
            scaler.step(optimizer)  # unscales gradients, then steps the optimizer
            scaler.update()

            train_losses.append(loss.item())
            pbar.set_description(f'Training loss: {loss.item():.4f}')

            if step % config.log_freq == 0 or config.debug:
                avg_train_loss = sum(train_losses) / len(train_losses)
                print(f'Step {step} - Train loss: {avg_train_loss:.4f}')
                logger.log({'train/loss': avg_train_loss}, step=step)
                train_losses = []  # reset the running average for the next log window

            if step % config.val_freq == 0 or config.debug:
                val_results = validate(config, model, val_loader)
                logger.log(val_results, step=step)

                if val_results['val/loss'] < best_val_loss and not config.debug:
                    print(f'Step {step} - New best validation loss: '
                          f'{val_results["val/loss"]:.4f}, saving model '
                          f'in {config.log_dir}')
                    best_val_loss = val_results['val/loss']
                    save(
                        model,
                        optimizer,
                        config,
                        config.log_dir / 'best_model.pt',  # log_dir is a Path (set in main)
                        step
                    )

            if step >= config.max_steps or config.debug:
                return model


@torch.no_grad()  # no gradients are needed for validation or sampling
def validate(config, model, val_loader):
    model.eval()
    losses = []
    for i, x in enumerate(tqdm(val_loader, desc='Validating')):
        x = x.to(config.device)
        with autocast(device_type=config.device, enabled=config.mixed_precision):
            loss = model.train_step(x)
        losses.append(loss.item())
        if i + 1 == config.max_val_steps or config.debug:
            break
    avg_loss = sum(losses) / len(losses)
    print(f'Validation loss: {avg_loss:.4f}')

    # Sample a few images for visual inspection in the logger
    with autocast(device_type=config.device, enabled=config.mixed_precision):
        sampled_imgs = sample_plot_image(
            model,
            config.timesteps,
            config.img_size,
            config.n_sampled_imgs if not config.debug else 1,
            normalized=config.normalize,
        )

    model.train()
    return {
        'val/loss': avg_loss,
        'val/sampled images': sampled_imgs
    }


def save(model, optimizer, config, path, step):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config,
        'step': step
    }, path)


def load(new_config, path):
    checkpoint = torch.load(path, map_location=torch.device(new_config.device))
    old_config = checkpoint['config']
    compare_configs(old_config, new_config)
    model = DiffusionModel(old_config).to(new_config.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer = torch.optim.Adam(model.parameters(), lr=new_config.lr)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    step = checkpoint['step']
    return model, optimizer, step


def main(config):
    # Adjust logdir to include the experiment (dataset) name
    config.log_dir = Path(config.log_dir).parent / "CXR14" / Path(config.log_dir).name
    os.makedirs(config.log_dir, exist_ok=True)

    # Save the config namespace into the logdir
    with open(config.log_dir / 'config.txt', 'w') as f:
        for k, v in vars(config).items():
            f.write(f'{k}: {v}\n')  # f-string stringifies non-primitive values

    # Random seed
    seed_everything(config.seed)

    # Init model and optimizer (optionally resuming from a checkpoint)
    if config.resume_path is not None:
        print('Loading model from', config.resume_path)
        diffusion_model, optimizer, step = load(config, config.resume_path)
    else:
        diffusion_model = DiffusionModel(config)
        optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=config.lr)  # , betas=config.adam_betas)
        step = 0
    diffusion_model.to(config.device)
    diffusion_model.train()

    # Gradient scaler for mixed precision; a no-op when mixed precision is disabled
    scaler = GradScaler(enabled=config.mixed_precision)

    # Load data
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']

    # Logger
    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)

    train(config, diffusion_model, optimizer, train_dl, val_dl, logger, scaler, step)
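

if __name__ == '__main__':
    # Assumed entry point (not shown in the original file): `parser` imported
    # from config.py is expected to be an argparse.ArgumentParser, so parsing
    # the CLI arguments here yields the config namespace that main() expects.
    main(parser.parse_args())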