TEDM-demo / trainers /finetune_glob_cl.py
anonymous
first commit without models
a2dba58
raw
history blame
7.95 kB
import argparse
import os
from pathlib import Path
import torch
from torch import autocast, Tensor
from torch.nn.functional import binary_cross_entropy_with_logits
from torch.cuda.amp import GradScaler
from tqdm import tqdm
from config import parser
from einops import rearrange, reduce, repeat
from dataloaders.JSRT import build_dataloaders
from models.unet_model import Unet
from trainers.train_baseline import validate, save
from trainers.utils import (TensorboardLogger, compare_configs, seed_everything, crop_batch)
def train(config, model, optimizer, train_dl, val_dl, logger, scaler, step):
    """Fine-tune ``model`` with a binary cross-entropy-with-logits loss.

    Loops over ``train_dl`` indefinitely. Every ``config.log_freq`` steps the
    running training loss is logged, every ``config.val_freq`` steps the model
    is validated and, if the validation loss improved, saved to
    ``config.log_dir / 'best_model.pt'``.

    Args:
        config: Run configuration namespace (reads ``dataset``, ``experiment``,
            ``shared_weights_over_timesteps``, ``augment_at_finetuning``,
            ``img_size``, ``batch_size``, ``device``, ``mixed_precision``,
            ``loss_weights``, ``unfreeze_weights_at_step``, ``log_freq``,
            ``val_freq``, ``max_steps``, ``early_stop``, ``debug``, ``log_dir``).
        model: Network to train; called as ``model(x)`` and expected to return
            logits with the same shape as the target.
        optimizer: Optimizer updating ``model``'s parameters.
        train_dl: Iterable yielding ``(x, y)`` training batches.
        val_dl: Validation dataloader, forwarded to ``validate``.
        logger: Object exposing ``log(dict, step=...)``.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        step: Global step counter to start from.

    Returns:
        The model, once ``config.max_steps`` is reached, early stopping
        triggers, or after a single step when ``config.debug`` is set.
    """
    best_val_loss = float('inf')
    train_losses = []
    # NOTE(review): the loss histories below are never cleared, so the logged
    # averages are running means over the whole run - confirm this is intended.
    train_losses_per_class = []
    train_losses_per_timestep = []

    is_brats = config.dataset == "BRATS2D"
    per_timestep = config.shared_weights_over_timesteps and config.experiment == 'datasetDM'
    encoder_prefixes = ('downs', 'init_conv', 'mid_')
    # autocast expects the bare device type ('cuda' / 'cpu'), not e.g. 'cuda:0'.
    device_type = torch.device(config.device).type
    # Path() is a no-op for Path inputs and makes a plain string log_dir work.
    log_dir = Path(config.log_dir)

    pbar = tqdm(total=config.val_freq, desc='Training')
    while True:
        for x, y in train_dl:
            if per_timestep:
                # One copy of the target per model step, step index varying fastest.
                y = repeat(y, 'b c h w -> (b step) c h w', step=len(model.steps))

            if config.augment_at_finetuning:
                x, y = crop_batch([x, y], config.img_size, config.batch_size)
                # random brightness adjustment between [-.3, .3]
                brightness = torch.rand((config.batch_size, 1, 1, 1), device=x.device) * .6 - .3
                # random contrast adjustment between [.7, 1.3]
                contrast = torch.rand((config.batch_size, 1, 1, 1), device=x.device) * .6 + .7
                x = (x + brightness) * contrast  # apply brightness and contrast

            x = x.to(config.device)
            y = y.to(config.device)

            optimizer.zero_grad()
            with autocast(device_type=device_type, enabled=config.mixed_precision):
                pred = model(x)
                if is_brats:
                    # Per-class loss weights broadcast to the target's shape.
                    weights = repeat(
                        torch.Tensor(config.loss_weights).to(config.device),
                        'c -> b c h w', b=y.shape[0], h=y.shape[2], w=y.shape[3])
                else:
                    weights = None
                # Mean BCE per (sample, channel); kept unreduced for the breakdowns below.
                expanded_loss = reduce(
                    binary_cross_entropy_with_logits(pred, y, weight=weights, reduction='none'),
                    'b c h w -> b c', 'mean')
                loss = expanded_loss.mean()

            scaler.scale(loss).backward()
            # Step through the scaler so gradients are unscaled, steps with
            # inf/NaN gradients are skipped, and the loss scale is adapted.
            scaler.step(optimizer)
            scaler.update()

            train_losses.append(loss.item())
            if per_timestep:
                # assumes pred keeps the '(b step)' batch layout of y - TODO confirm
                loss_per_timestep = reduce(
                    expanded_loss, '(b step) c -> step', 'mean', step=len(model.steps))
                train_losses_per_timestep.append(loss_per_timestep.detach().cpu())
            if is_brats:
                loss_per_class = expanded_loss.mean(0)
                train_losses_per_class.append(loss_per_class.detach().cpu())
                per_class_str = ' - '.join(f'{value:.4f}' for value in loss_per_class.tolist())
                pbar.set_description(f'Training loss: {loss.item():.4f} - {per_class_str}')
            else:
                pbar.set_description(f'Training loss: {loss.item():.4f}')
            pbar.update(1)
            step += 1

            if config.unfreeze_weights_at_step == step:
                # Re-enable gradients for the parameters matched by these name prefixes.
                for name, param in model.named_parameters():
                    if name.startswith(encoder_prefixes):
                        param.requires_grad = True

            if step % config.log_freq == 0 or config.debug:
                avg_train_loss = sum(train_losses) / len(train_losses)
                print(f'Step {step} - Train loss: {avg_train_loss:.4f}')
                logger.log({'train/loss': avg_train_loss}, step=step)
                if is_brats:
                    avg_train_loss_per_class = torch.stack(train_losses_per_class).mean(0)
                    for class_idx, class_loss in enumerate(avg_train_loss_per_class):
                        logger.log({f'train_loss/{class_idx}': class_loss.item()}, step=step)
                if per_timestep:
                    avg_train_loss_per_timestep = torch.stack(train_losses_per_timestep).mean(0)
                    for i, model_step in enumerate(model.steps):
                        logger.log(
                            {'train_loss/step_' + str(model_step): avg_train_loss_per_timestep[i].item()},
                            step=step)

            if step % config.val_freq == 0 or config.debug:
                val_results = validate(config, model, val_dl)
                logger.log(val_results, step=step)

                if val_results['val/loss'] < best_val_loss and not config.debug:
                    print(f'Step {step} - New best validation loss: '
                          f'{val_results["val/loss"]:.4f}, saving model '
                          f'in {config.log_dir}')
                    best_val_loss = val_results['val/loss']
                    save(
                        model,
                        optimizer,
                        config,
                        log_dir / 'best_model.pt',
                        step
                    )
                elif val_results['val/loss'] > best_val_loss * 1.5 and config.early_stop:
                    print(f'Step {step} - Validation loss increased by more than 50%')
                    return model

            if step >= config.max_steps or config.debug:
                return model
def load(config, path):
    """Placeholder for restoring a training run from ``path``.

    Resuming is not supported by this trainer yet, so every call fails.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def main(config):
    """Set up a fine-tuning run and hand over to ``train``.

    Creates ``config.log_dir``, dumps the configuration into ``config.txt``
    there, seeds the RNGs, builds the U-Net (initialised from the checkpoint at
    ``config.global_model_path`` unless ``config.resume_path`` is given),
    builds the dataloaders and logger, and starts training.

    Args:
        config: Run configuration namespace.
    """
    # Path() is a no-op for Path inputs and makes a plain string log_dir work.
    log_dir = Path(config.log_dir)
    os.makedirs(log_dir, exist_ok=True)

    # save config namespace into logdir
    with open(log_dir / 'config.txt', 'w', encoding='utf-8') as f:
        # '!s' forces str(), which is what both original branches wrote.
        f.writelines(f'{k}: {v!s}\n' for k, v in vars(config).items())

    # Random seed
    seed_everything(config.seed)

    # Init model and optimizer
    if config.resume_path is not None:
        print('Loading model from', config.resume_path)
        model, optimizer, step = load(config, config.resume_path)
    else:
        model = Unet(
            img_size=config.img_size,
            dim=config.dim,
            dim_mults=config.dim_mults,
            channels=config.channels,
            out_dim=config.out_channels)
        state_dict = torch.load(config.global_model_path, map_location='cpu')['model_state_dict']
        # strict=False: tolerate keys that differ between checkpoint and model;
        # the mismatches are reported right below.
        out = model.load_state_dict(state_dict=state_dict, strict=False)
        print("Loaded state dict. \n\tMissing keys: {}\n\tUnexpected keys: {}".format(out.missing_keys, out.unexpected_keys))
        print('Note that although the state dict of the decoder is loaded, its values are random.')
        if config.unfreeze_weights_at_step != 0:
            # Freeze the parameters matched by these name prefixes; train()
            # unfreezes them once step == config.unfreeze_weights_at_step.
            encoder_prefixes = ('downs', 'init_conv', 'mid_')
            for name, param in model.named_parameters():
                if name.startswith(encoder_prefixes):
                    param.requires_grad = False
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)  # , betas=config.adam_betas)
        step = 0

    model.to(config.device)
    model.train()

    # Only scale the loss when mixed precision is actually in use; a disabled
    # scaler passes the loss and optimizer step through unchanged.
    scaler = GradScaler(enabled=config.mixed_precision)

    # Load data
    dataloaders = build_dataloaders(
        config.data_dir,
        config.img_size,
        config.batch_size,
        config.num_workers,
        n_labelled_images=config.n_labelled_images,
    )
    train_dl = dataloaders['train']
    val_dl = dataloaders['val']
    print('Train dataset size:', len(train_dl.dataset))
    print('Validation dataset size:', len(val_dl.dataset))

    # Logger
    logger = TensorboardLogger(config.log_dir, enabled=not config.debug)

    train(config, model, optimizer, train_dl, val_dl, logger, scaler, step)