import os
from argparse import Namespace
from pathlib import Path

import torch
from torch import autocast
from torch.cuda.amp import GradScaler
from torch.nn.functional import binary_cross_entropy_with_logits
from tqdm.auto import tqdm
from einops import rearrange, reduce, repeat

from trainers.utils import seed_everything, TensorboardLogger
from dataloaders.JSRT import build_dataloaders as build_dataloaders_JSRT
from models.unet_model import Unet
from trainers.train_base_diffusion import save

def train(config, model, optimizer, train_dl, val_dl, logger, scaler, step):
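    """Run the training loop until config.max_steps, validating every config.val_freq
    steps and checkpointing the best model (by validation loss) to config.log_dir."""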
    best_val_loss = float('inf')
    train_losses = []
    if config.dataset == "BRATS2D":
        train_losses_per_class = []
    elif config.shared_weights_over_timesteps and config.experiment == 'datasetDM':
        train_losses_per_timestep = []
    pbar = tqdm(total=config.val_freq, desc='Training')
    while True:
        for x, y in train_dl:
            pbar.update(1)
            step += 1
            if config.shared_weights_over_timesteps and config.experiment == 'datasetDM':
                y = repeat(y, 'b c h w -> (b step) c h w', step=len(model.steps))
            x = x.to(config.device)
            y = y.to(config.device)
            optimizer.zero_grad()
            with autocast(device_type=config.device, enabled=config.mixed_precision):
                pred = model(x)
                # cross entropy loss
                # loss = - ((y * torch.log(torch.sigmoid(pred)) + (1 - y) * torch.log(1 - torch.sigmoid(pred)))).mean()
                if config.dataset == "BRATS2D":
                    weights = repeat(torch.Tensor(config.loss_weights).to(config.device),
                                     'c -> b c h w', b=y.shape[0], h=y.shape[2], w=y.shape[3])
                else:
                    weights = None
                expanded_loss = reduce(
                    binary_cross_entropy_with_logits(pred, y, weight=weights, reduction='none'),
                    'b c h w -> b c', 'mean')
                loss = expanded_loss.mean()
            scaler.scale(loss).backward()
            # step through the scaler so gradients are unscaled before the
            # optimizer update, then update the scale factor
            scaler.step(optimizer)
            scaler.update()
            train_losses.append(loss.item())
            if config.dataset == "BRATS2D":
                loss_per_class = expanded_loss.mean(0)
                train_losses_per_class.append(loss_per_class.detach().cpu())
                pbar.set_description(
                    f'Training loss: {loss.item():.4f} - {loss_per_class[0].item():.4f} - '
                    f'{loss_per_class[1].item():.4f} - {loss_per_class[2].item():.4f} - '
                    f'{loss_per_class[3].item():.4f}')
            else:
                pbar.set_description(f'Training loss: {loss.item():.4f}')
            if config.shared_weights_over_timesteps and config.experiment == 'datasetDM':
                loss_per_timestep = reduce(expanded_loss, '(b step) c -> step', 'mean', step=len(model.steps))
                train_losses_per_timestep.append(loss_per_timestep.detach().cpu())
            if step % config.log_freq == 0 or config.debug:
                avg_train_loss = sum(train_losses) / len(train_losses)
                print(f'Step {step} - Train loss: {avg_train_loss:.4f}')
                logger.log({'train/loss': avg_train_loss}, step=step)
                if config.dataset == "BRATS2D":
                    avg_train_loss_per_class = torch.stack(train_losses_per_class).mean(0)
                    logger.log({'train_loss/0': avg_train_loss_per_class[0].item()}, step=step)
                    logger.log({'train_loss/1': avg_train_loss_per_class[1].item()}, step=step)
                    logger.log({'train_loss/2': avg_train_loss_per_class[2].item()}, step=step)
                    logger.log({'train_loss/3': avg_train_loss_per_class[3].item()}, step=step)
                if config.shared_weights_over_timesteps and config.experiment == 'datasetDM':
                    avg_train_loss_per_timestep = torch.stack(train_losses_per_timestep).mean(0)
                    for i, model_step in enumerate(model.steps):
                        logger.log({'train_loss/step_' + str(model_step): avg_train_loss_per_timestep[i].item()}, step=step)
            if step % config.val_freq == 0 or config.debug:
                val_results = validate(config, model, val_dl)
                logger.log(val_results, step=step)
                if val_results['val/loss'] < best_val_loss and not config.debug:
                    print(f'Step {step} - New best validation loss: '
                          f'{val_results["val/loss"]:.4f}, saving model '
                          f'in {config.log_dir}')
                    best_val_loss = val_results['val/loss']
                    save(
                        model,
                        optimizer,
                        config,
                        config.log_dir / 'best_model.pt',
                        step
                    )
                elif val_results['val/loss'] > best_val_loss * 1.5 and config.early_stop:
                    print(f'Step {step} - Validation loss increased by more than 50%')
                    return model
            if step >= config.max_steps or config.debug:
                return model

def validate(config, model, val_dl):
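    """Run validation over val_dl and return a dict of averaged metrics:
    BCE loss, Dice, precision and recall (plus per-class values for
    multi-channel predictions)."""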
    model.eval()
    metrics = {
        'val/loss': [],
        'val/dice': [],
        'val/precision': [],
        'val/recall': [],
    }
    for i, (x, y) in tqdm(enumerate(val_dl), desc='Validating'):
        x = x.to(config.device)
        with autocast(device_type=config.device, enabled=config.mixed_precision):
            pred = model(x).detach().cpu()
        # label predictions
        if pred.shape[1] == 1:
            y_hat = torch.sigmoid(pred) > .5
        else:
            y_hat = torch.argmax(pred, dim=1)
            y_hat = torch.stack([y_hat == i for i in range(y.shape[1])], dim=1)
        # metrics
        if config.shared_weights_over_timesteps and config.experiment == 'datasetDM':
            y = repeat(y, 'b c h w -> (b step) c h w', step=len(model.steps))
        metrics['val/dice'].append(dice(y_hat, y))
        metrics['val/precision'].append(precision(y_hat, y))
        metrics['val/recall'].append(recall(y_hat, y))
        metrics['val/loss'].append(binary_cross_entropy_with_logits(pred, y, reduction='none'))
        if i + 1 == config.max_val_steps or config.debug:
            break
    # average metrics
    avg_loss = torch.cat(metrics['val/loss']).mean()
    print(f'Validation loss: {avg_loss:.4f}')
    if y_hat.shape[1] > 1:
        for i in range(1, y_hat.shape[1]):
            metrics[f'val_dice/{i}'] = torch.cat(metrics['val/dice'])[:, i].nanmean().item()
            metrics[f'val_precision/{i}'] = torch.cat(metrics['val/precision'])[:, i].nanmean().item()
            metrics[f'val_recall/{i}'] = torch.cat(metrics['val/recall'])[:, i].nanmean().item()
    metrics['val/loss'] = avg_loss.item()
    metrics['val/dice'] = torch.cat(metrics['val/dice']).nanmean().item()  # nanmean skips classes not represented in a sample
    metrics['val/precision'] = torch.cat(metrics['val/precision']).nanmean().item()
    metrics['val/recall'] = torch.cat(metrics['val/recall']).nanmean().item()
    model.train()
    return metrics

def dice(x_hat, x):
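    """Dice overlap between prediction x_hat and target x, per sample and channel."""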
    x_n_x_hat = torch.logical_and(x_hat, x)
    dice = 2 * reduce(x_n_x_hat, 'b c h w -> b c', 'sum') / (
        reduce(x_hat, 'b c h w -> b c', 'sum') + reduce(x, 'b c h w -> b c', 'sum'))
    return dice

def precision(x_hat, x):
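    """Precision TP / (TP + FP), per sample and channel."""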
    TP = reduce(torch.logical_and(x, x_hat), 'b c h w -> b c', 'sum')
    FP = reduce(torch.logical_and(1 - x, x_hat), 'b c h w -> b c', 'sum')
    _precision = TP / (TP + FP)
    return _precision

def recall(x_hat, x):
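    """Recall TP / (TP + FN), per sample and channel."""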
    TP = reduce(torch.logical_and(x, x_hat), 'b c h w -> b c', 'sum')
    FN = reduce(torch.logical_and(x, ~x_hat), 'b c h w -> b c', 'sum')
    _recall = TP / (TP + FN)
    return _recall

def main(config: Namespace) -> None:
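    """Build the model, optimizer, dataloaders and logger from config, then train."""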
    # create the log directory
    os.makedirs(config.log_dir, exist_ok=True)
    # save config namespace into logdir
    with open(config.log_dir / 'config.txt', 'w') as f:
        for k, v in vars(config).items():
            if type(v) not in [str, int, float, bool]:
                f.write(f'{k}: {str(v)}\n')
            else:
                f.write(f'{k}: {v}\n')

    # Random seed
    seed_everything(config.seed)

    model = Unet(
        config.dim,
        dim_mults=config.dim_mults,
        channels=config.channels,
        out_dim=config.out_channels
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    step = 0
    model.to(config.device)
    model.train()
    # gradient scaling is only needed when training in mixed precision
    scaler = GradScaler(enabled=config.mixed_precision)
    # Load data
    if config.dataset == "JSRT":
        build_dataloaders = build_dataloaders_JSRT
    else:
        raise ValueError(f"Unknown dataset: {config.dataset}")
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
        config.n_labelled_images,
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']
    print(f'Loaded {len(train_dl.dataset)} training and {len(val_dl.dataset)} validation images')

    # Logger
    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)

    train(config, model, optimizer, train_dl, val_dl, logger, scaler, step)