# Scraped from a Hugging Face Space file view (Space status: runtime error).
# File size: 9,416 bytes; commit a2dba58.
import os
import torch
from torch import autocast
from torch.cuda.amp import GradScaler
from tqdm import tqdm
from dataloaders.JSRT import build_dataloaders as build_dataloaders_JSRT
from models.diffusion_model import DiffusionModel
from trainers.utils import (TensorboardLogger, compare_configs, sample_plot_image,
seed_everything)
def train(config, model, optimizer, train_loader, val_loader, logger, scaler, step):
    """Train ``model`` until ``config.max_steps`` optimisation steps are reached.

    ``train_loader`` is iterated over and over. Depending on
    ``config.experiment`` each batch ``(x, y)`` is arranged as follows before
    being passed to ``model.train_step(x, cond)``:

    * ``"joint"``: ``x`` and ``y`` are concatenated along dim 1, no condition.
    * ``"conditional"``: ``x`` becomes the condition and ``y`` the target.
    * ``"joint_and_cond"``: ``y`` becomes the condition, ``x`` stays the target.
    * anything else: ``x`` is used unchanged, no condition.

    Every ``config.log_freq`` steps the mean of all training losses recorded
    so far in this call is logged. Every ``config.val_freq`` steps
    ``validate`` is run and, if the validation loss improved, a checkpoint is
    written to ``config.log_dir / 'best_model.pt'``. When ``config.debug`` is
    set, logging and validation run on every step, nothing is saved, and the
    function returns after the first step.

    Args:
        config: Namespace-like object holding the settings read above
            (``experiment``, ``joint_training``, ``device``,
            ``mixed_precision``, ``log_freq``, ``val_freq``, ``max_steps``,
            ``debug``, ``log_dir``).
        model: Object exposing ``train_step(x, cond)``. When
            ``config.joint_training`` is set and the experiment is
            ``"joint_and_cond"`` the step must return an
            ``(img_loss, seg_loss)`` pair, otherwise a single loss.
        optimizer: Optimizer updating the parameters of ``model``.
        train_loader: Iterable yielding ``(x, y)`` batches.
        val_loader: Loader forwarded to ``validate``.
        logger: Object exposing ``log(dict, step=...)``.
        scaler: Gradient scaler used for the backward pass and the step.
        step: Step counter to start from (non-zero when resuming).

    Returns:
        The trained ``model``.
    """
    best_val_loss = float('inf')
    train_losses = []
    split_losses = config.joint_training and config.experiment == 'joint_and_cond'
    if split_losses:
        img_losses = []
        seg_losses = []
    # autocast only accepts a bare device type such as "cuda" or "cpu";
    # normalising here also supports "cuda:0" and torch.device values.
    device_type = torch.device(config.device).type
    pbar = tqdm(total=config.val_freq, desc='Training')
    cond = None
    while True:
        for x, y in train_loader:
            pbar.update(1)
            step += 1

            if config.experiment == "joint":
                x = torch.concat([x, y], dim=1)
            elif config.experiment == "conditional":
                cond = x  # condition on the image
                cond = cond.to(config.device)
                x = y  # predict the segmentation
            elif config.experiment == "joint_and_cond":
                cond = y.to(config.device)
            x = x.to(config.device)

            # Forward pass
            optimizer.zero_grad()
            with autocast(device_type=device_type, enabled=config.mixed_precision):
                loss = model.train_step(x, cond)
            if split_losses:
                img_loss, seg_loss = loss
                loss = img_loss + seg_loss
                img_losses.append(img_loss.item())
                seg_losses.append(seg_loss.item())

            # The backward pass runs on the scaled loss, so the optimizer must
            # be stepped through the scaler: it unscales the gradients, skips
            # steps containing inf/NaN, and update() adapts the scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss.item())
            if split_losses:
                pbar.set_description(f'Training loss: {loss.item():.4f} (img: {img_loss.item():.4f}, seg: {seg_loss.item():.4f})')
            else:
                pbar.set_description(f'Training loss: {loss.item():.4f}')

            if step % config.log_freq == 0 or config.debug:
                avg_train_loss = sum(train_losses) / len(train_losses)
                print(f'Step {step} - Train loss: {avg_train_loss:.4f}')
                logger.log({'train/loss': avg_train_loss}, step=step)
                if split_losses:
                    avg_img_loss = sum(img_losses) / len(img_losses)
                    avg_seg_loss = sum(seg_losses) / len(seg_losses)
                    logger.log({'train_loss/img': avg_img_loss}, step=step)
                    logger.log({'train_loss/seg': avg_seg_loss}, step=step)

            if step % config.val_freq == 0 or config.debug:
                val_results = validate(config, model, val_loader)
                logger.log(val_results, step=step)

                # Checkpoints are never written in debug mode.
                if val_results['val/loss'] < best_val_loss and not config.debug:
                    print(f'Step {step} - New best validation loss: '
                          f'{val_results["val/loss"]:.4f}, saving model '
                          f'in {config.log_dir}')
                    best_val_loss = val_results['val/loss']
                    save(
                        model,
                        optimizer,
                        config,
                        config.log_dir / 'best_model.pt',
                        step
                    )

            if step >= config.max_steps or config.debug:
                return model
@torch.no_grad()
def validate(config, model, val_loader):
    """Evaluate ``model`` on ``val_loader`` and sample images for logging.

    The model is switched to eval mode for the duration of the call and back
    to train mode before returning. Batches are arranged per
    ``config.experiment`` exactly as in training. If the validation dataset
    holds more than 1000 items the loss comes from ``model.train_step``,
    otherwise from ``model.val_step`` with ``t_steps=config.val_steps``. At
    most ``config.max_val_steps`` batches are evaluated (a single batch in
    debug mode).

    For the ``"conditional"`` experiment the conditioning batch used for
    sampling is topped up from a fresh pass over ``val_loader`` or truncated
    so that it holds exactly ``config.n_sampled_imgs`` items. In debug mode
    this overwrites ``config.n_sampled_imgs`` with 1.

    Args:
        config: Namespace-like object holding the settings read here.
        model: Object exposing ``eval()``, ``train()``, ``train_step`` and
            ``val_step``.
        val_loader: Loader yielding ``(x, y)`` batches and exposing
            ``.dataset``.

    Returns:
        Dict with the mean loss under ``'val/loss'`` and the output of
        ``sample_plot_image`` under ``'val/sampled images'``.
    """
    model.eval()

    losses = []
    split_losses = config.joint_training and config.experiment == 'joint_and_cond'
    if split_losses:
        img_losses = []
        seg_losses = []
    # autocast only accepts a bare device type such as "cuda" or "cpu";
    # normalising here also supports "cuda:0" and torch.device values.
    device_type = torch.device(config.device).type
    cond = None
    for i, (x, y) in tqdm(enumerate(val_loader), desc='Validating'):
        if config.experiment == "joint":
            x = torch.cat([x, y], dim=1)
        elif config.experiment == "conditional":
            cond = x  # condition on the image
            cond = cond.to(config.device)
            x = y  # predict the segmentation
        elif config.experiment == "joint_and_cond":
            cond = y.to(config.device)
        x = x.to(config.device)

        with autocast(device_type=device_type, enabled=config.mixed_precision):
            if len(val_loader.dataset) > 1000:
                # val set is too large to evaluate for each timestep, take random timesteps
                loss = model.train_step(x, cond)
            else:
                loss = model.val_step(x, cond, t_steps=config.val_steps)
        if split_losses:
            img_loss, seg_loss = loss
            loss = img_loss + seg_loss
            img_losses.append(img_loss.item())
            seg_losses.append(seg_loss.item())
        losses.append(loss.item())
        if i + 1 == config.max_val_steps or config.debug:
            break

    avg_loss = sum(losses) / len(losses)
    if split_losses:
        avg_img_loss = sum(img_losses) / len(img_losses)
        avg_seg_loss = sum(seg_losses) / len(seg_losses)
        print(f'Validation loss: {avg_loss:.4f} (img: {avg_img_loss:.4f}, seg: {avg_seg_loss:.4f})')
    else:
        print(f'Validation loss: {avg_loss:.4f}')

    # Build visualisations
    # select images for visualisation
    if config.experiment == "conditional":
        # note that there is no randomness on how the images are selected
        # if there are enough images on the last validation batch, then we keep those
        # otherwise, we take the first images from a new validation epoch batch
        config.n_sampled_imgs = config.n_sampled_imgs if not config.debug else 1
        N_missing = config.n_sampled_imgs - cond.shape[0]
        iter_val_loader = iter(val_loader)  # new validation epoch
        while N_missing > 0:  # are we missing images?
            new_cond, _ = next(iter_val_loader)
            # cond already lives on the target device while the loader yields
            # fresh batches; move them over before concatenating.
            cond = torch.cat([cond, new_cond.to(cond.device)], dim=0)
            N_missing = config.n_sampled_imgs - cond.shape[0]
        del iter_val_loader
        if N_missing < 0:  # do we have too many images?
            cond = cond[:config.n_sampled_imgs]

    with autocast(device_type=device_type, enabled=config.mixed_precision):
        sampled_imgs = sample_plot_image(
            model,
            config.timesteps,
            config.img_size,
            config.n_sampled_imgs,  # 1 in debug mode for the conditional experiment
            config.channels if config.experiment != "joint_and_cond" else config.channels + config.out_channels,  # out_channels stands for channels in segmentation
            cond=cond,
            normalized=config.normalize,
        )

    model.train()
    return {
        'val/loss': avg_loss,
        'val/sampled images': sampled_imgs
    }
def save(model, optimizer, config, path, step):
    """Write a training checkpoint to ``path`` via ``torch.save``.

    The checkpoint is a dict with the keys ``'model_state_dict'``,
    ``'optimizer_state_dict'``, ``'config'`` and ``'step'``.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        config: Configuration object stored as is in the checkpoint.
        path: Destination handed to ``torch.save``.
        step: Training step stored alongside the states.
    """
    checkpoint = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        config=config,
        step=step,
    )
    torch.save(checkpoint, path)
def load(new_config, path):
    """Restore model, optimizer and step counter from a checkpoint.

    The model is rebuilt from the configuration stored in the checkpoint,
    which is first passed to ``compare_configs`` together with
    ``new_config``. The device and the optimizer's initial learning rate are
    taken from ``new_config``.

    Args:
        new_config: Current run configuration (``device`` and ``lr`` are read).
        path: Location of a checkpoint holding the keys ``'config'``,
            ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'step'``.

    Returns:
        Tuple ``(model, optimizer, step)``.
    """
    # NOTE(review): the checkpoint embeds the config object itself; torch
    # versions that default torch.load to weights_only=True may refuse to
    # unpickle it - confirm against the pinned torch version.
    target_device = torch.device(new_config.device)
    ckpt = torch.load(path, map_location=target_device)

    saved_config = ckpt['config']
    compare_configs(saved_config, new_config)

    net = DiffusionModel(saved_config)
    net.to(new_config.device)
    net.load_state_dict(ckpt['model_state_dict'])

    optim = torch.optim.Adam(net.parameters(), lr=new_config.lr)
    optim.load_state_dict(ckpt['optimizer_state_dict'])

    return net, optim, ckpt['step']
def main(config):
    """Set up a training run from ``config`` and start training.

    Creates ``config.log_dir``, dumps the configuration to ``config.txt``
    inside it, seeds the random number generators, builds (or resumes) the
    model and optimizer, builds the data loaders and the logger, and finally
    calls ``train``.

    Args:
        config: Namespace-like run configuration. ``config.log_dir`` must
            support the ``/`` operator (it is joined with a file name).

    Raises:
        ValueError: If ``config.experiment`` or ``config.dataset`` is not one
            of the supported values.
    """
    # make sure the log directory exists
    os.makedirs(config.log_dir, exist_ok=True)

    # save config namespace into logdir
    with open(config.log_dir / 'config.txt', 'w') as f:
        for k, v in vars(config).items():
            f.write(f'{k}: {v!s}\n')

    # Random seed
    seed_everything(config.seed)

    # Init model and optimizer
    if config.resume_path is not None:
        print('Loading model from', config.resume_path)
        diffusion_model, optimizer, step = load(config, config.resume_path)
    else:
        # "joint_and_cond" is accepted as well: train() and validate() both
        # implement a dedicated branch for it.
        if config.experiment in ["img_only", "joint", "conditional", "joint_and_cond"]:
            diffusion_model = DiffusionModel(config)
        else:
            raise ValueError(f"Unknown experiment: {config.experiment}")
        optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        step = 0
    diffusion_model.to(config.device)
    diffusion_model.train()

    # Loss scaling is only meaningful together with mixed precision; when it
    # is off the scaler acts as a pass-through.
    scaler = GradScaler(enabled=config.mixed_precision)

    # Load data
    if config.dataset == "JSRT":
        build_dataloaders = build_dataloaders_JSRT
    else:
        raise ValueError(f"Unknown dataset: {config.dataset}")
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
        config.n_labelled_images,
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']

    # Logger
    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)

    train(config, diffusion_model, optimizer, train_dl, val_dl, logger, scaler, step)
|