import csv
import os

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.transforms import AsDiscrete, Activations, Compose, EnsureType

from models.SegTranVAE.SegTranVAE import SegTransVAE
from loss.loss import Loss_VAE, DiceScore


class BRATS(pl.LightningModule):
    """LightningModule that trains and evaluates ``SegTransVAE`` segmentation.

    Batches are dicts with ``'image'`` and ``'label'`` entries. Output/label
    channels 0, 1 and 2 are scored as TC, WT and ET respectively.

    Args:
        train_loader: returned by ``train_dataloader``.
        val_loader: returned by ``val_dataloader``.
        test_loader: returned by ``test_dataloader``.
        use_VAE: if True the model also returns ``(recon, mu, sigma)`` during
            training and a weighted VAE term is added to the Dice loss.
        lr: learning rate for Adam.
    """

    # Weight applied to the VAE loss term (same value as 1 / (4 * 128**3)).
    # NOTE(review): presumably 4 input channels x a 128^3 crop -- confirm.
    VAE_LOSS_WEIGHT = 1 / (4 * 128 * 128 * 128)
    # Sliding-window inference settings used for validation and test.
    ROI_SIZE = (128, 128, 128)
    SW_BATCH_SIZE = 1
    SW_OVERLAP = 0.5
    # Depth index of the slice shown in the training visualisation.
    VIS_SLICE = 80

    def __init__(self, train_loader, val_loader, test_loader, use_VAE=True, lr=1e-4):
        super().__init__()
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.use_vae = use_VAE
        self.lr = lr
        # Positional arguments are forwarded unchanged; their meaning is
        # defined by SegTransVAE (not visible in this file).
        self.model = SegTransVAE(
            (128, 128, 128), 8, 4, 3, 768, 8, 4, 3072,
            in_channels_vae=128, use_VAE=use_VAE,
        )
        self.loss_vae = Loss_VAE()
        self.dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True)
        # Logits -> sigmoid probabilities -> binary mask at 0.5.
        # `threshold=0.5` replaces the deprecated `threshold_values=True`,
        # whose implied cut-off was 0.5.
        self.post_trans_images = Compose(
            [EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)]
        )
        self.best_val_dice = 0.0
        # Initialised up front so the validation summary never reads an unset
        # attribute when no epoch has beaten the initial best score.
        self.best_val_epoch = 0

        # Per-step values collected during an epoch; each list is cleared by
        # the matching on_*_epoch_end hook.
        self.training_step_outputs = []
        self.val_step_loss = []
        self.val_step_dice = []
        self.val_step_dice_tc = []
        self.val_step_dice_wt = []
        self.val_step_dice_et = []
        self.test_step_loss = []
        self.test_step_dice = []
        self.test_step_dice_tc = []
        self.test_step_dice_wt = []
        self.test_step_dice_et = []

    def forward(self, x, is_validation=True):
        """Run the wrapped model, forwarding the ``is_validation`` flag."""
        return self.model(x, is_validation)

    def training_step(self, batch, batch_index):
        """Compute and log the training loss for one batch.

        Without the VAE branch the loss is the Dice loss. With it, the loss is
        ``dice_loss + VAE_LOSS_WEIGHT * vae_loss``. A slice figure is also sent
        to the logger when ``batch_index == 10``.
        """
        inputs, labels = batch['image'], batch['label']

        if not self.use_vae:
            outputs = self.forward(inputs, is_validation=False)
            loss = self.dice_loss(outputs, labels)
        else:
            outputs, recon_batch, mu, sigma = self.forward(inputs, is_validation=False)
            vae_loss = self.loss_vae(recon_batch, inputs, mu, sigma)
            dice_loss = self.dice_loss(outputs, labels)
            loss = dice_loss + self.VAE_LOSS_WEIGHT * vae_loss
            self.log('train/vae_loss', vae_loss)
            self.log('train/dice_loss', dice_loss)
            if batch_index == 10:
                self._log_train_figure(inputs, recon_batch, labels, outputs)

        # Record the loss in BOTH branches. Previously only the VAE branch did,
        # so use_VAE=False made on_train_epoch_end stack an empty list. Detach
        # so the epoch buffer does not hold on to autograd state.
        self.training_step_outputs.append(loss.detach())
        self.log('train/loss', loss)
        return loss

    def _log_train_figure(self, inputs, recon_batch, labels, outputs):
        """Send a 6-panel slice figure (input, reconstruction, TC/ET labels vs predictions) to the logger."""
        tensorboard = self.logger.experiment
        probs = outputs.sigmoid().detach().cpu().float()
        label_vol = labels.detach().cpu().float()
        panels = [
            (inputs.detach().cpu()[0][0], "Input"),
            (recon_batch.detach().cpu().float()[0][0], "Reconstruction"),
            (label_vol[0][0], "Labels TC"),
            (probs[0][0], "TC"),
            (label_vol[0][2], "Labels ET"),
            (probs[0][2], "ET"),
        ]
        fig, ax = plt.subplots(nrows=1, ncols=len(panels), figsize=(10, 5))
        for axis, (volume, title) in zip(ax, panels):
            axis.imshow(volume[:, :, self.VIS_SLICE], cmap='gray')
            axis.set_title(title)
        tensorboard.add_figure('train_visualize', fig, self.current_epoch)

    def on_train_epoch_end(self):
        """Log the mean training loss of the epoch and reset the buffer."""
        if self.training_step_outputs:
            epoch_average = torch.stack(self.training_step_outputs).mean()
            self.log("training_epoch_average", epoch_average)
        self.training_step_outputs.clear()  # free memory

    def _dice_per_region(self, logits, labels):
        """Binarise ``logits`` and return ``(mean, tc, wt, et)`` Dice scores against ``labels``."""
        preds = self.post_trans_images(logits)
        metric_tc = DiceScore(y_pred=preds[:, 0:1], y=labels[:, 0:1], include_background=True)
        metric_wt = DiceScore(y_pred=preds[:, 1:2], y=labels[:, 1:2], include_background=True)
        metric_et = DiceScore(y_pred=preds[:, 2:3], y=labels[:, 2:3], include_background=True)
        mean_dice = (metric_tc + metric_wt + metric_et) / 3
        return mean_dice, metric_tc, metric_wt, metric_et

    def validation_step(self, batch, batch_index):
        """Sliding-window inference on one batch; buffer the loss and Dice scores."""
        inputs, labels = batch['image'], batch['label']
        # NOTE(review): validation passes self.model (the model's own default
        # for is_validation), while test_step passes self.forward
        # (is_validation=True). Confirm this difference is intended.
        outputs = sliding_window_inference(
            inputs, self.ROI_SIZE, self.SW_BATCH_SIZE, self.model, overlap=self.SW_OVERLAP
        )
        loss = self.dice_loss(outputs, labels)
        mean_val_dice, metric_tc, metric_wt, metric_et = self._dice_per_region(outputs, labels)

        self.val_step_loss.append(loss)
        self.val_step_dice.append(mean_val_dice)
        self.val_step_dice_tc.append(metric_tc)
        self.val_step_dice_wt.append(metric_wt)
        self.val_step_dice_et.append(metric_et)
        return {'val_loss': loss, 'val_mean_dice': mean_val_dice, 'val_dice_tc': metric_tc,
                'val_dice_wt': metric_wt, 'val_dice_et': metric_et}

    def on_validation_epoch_end(self):
        """Aggregate validation metrics, log them, append to metric_log.csv and track the best epoch."""
        loss = torch.stack(self.val_step_loss).mean()
        mean_val_dice = torch.stack(self.val_step_dice).mean()
        metric_tc = torch.stack(self.val_step_dice_tc).mean()
        metric_wt = torch.stack(self.val_step_dice_wt).mean()
        metric_et = torch.stack(self.val_step_dice_et).mean()

        self.log('val/Loss', loss)
        self.log('val/MeanDiceScore', mean_val_dice)
        self.log('val/DiceTC', metric_tc)
        self.log('val/DiceWT', metric_wt)
        self.log('val/DiceET', metric_et)

        log_dir = self.logger.log_dir
        os.makedirs(log_dir, exist_ok=True)
        # Epoch 0 starts a fresh file with a header; later epochs append.
        is_first_epoch = self.current_epoch == 0
        with open(os.path.join(log_dir, 'metric_log.csv'),
                  'w' if is_first_epoch else 'a', newline='') as f:
            writer = csv.writer(f)
            if is_first_epoch:
                writer.writerow(['Epoch', 'Mean Dice Score', 'Dice TC', 'Dice WT', 'Dice ET'])
            writer.writerow([self.current_epoch, mean_val_dice.item(), metric_tc.item(),
                             metric_wt.item(), metric_et.item()])

        current_dice = mean_val_dice.item()
        if current_dice > self.best_val_dice:
            self.best_val_dice = current_dice
            self.best_val_epoch = self.current_epoch
        print(
            f"\n Current epoch: {self.current_epoch} Current mean dice: {current_dice:.4f}"
            f" tc: {metric_tc:.4f} wt: {metric_wt:.4f} et: {metric_et:.4f}"
            f"\n Best mean dice: {self.best_val_dice:.4f}"
            f" at epoch: {self.best_val_epoch}"
        )

        self.val_step_loss.clear()
        self.val_step_dice.clear()
        self.val_step_dice_tc.clear()
        self.val_step_dice_wt.clear()
        self.val_step_dice_et.clear()
        return {'val_MeanDiceScore': mean_val_dice}

    def test_step(self, batch, batch_index):
        """Sliding-window inference on one test batch; buffer the loss and Dice scores."""
        inputs, labels = batch['image'], batch['label']
        test_outputs = sliding_window_inference(
            inputs, self.ROI_SIZE, self.SW_BATCH_SIZE, self.forward, overlap=self.SW_OVERLAP
        )
        loss = self.dice_loss(test_outputs, labels)
        mean_test_dice, metric_tc, metric_wt, metric_et = self._dice_per_region(test_outputs, labels)

        self.test_step_loss.append(loss)
        self.test_step_dice.append(mean_test_dice)
        self.test_step_dice_tc.append(metric_tc)
        self.test_step_dice_wt.append(metric_wt)
        self.test_step_dice_et.append(metric_et)
        return {'test_loss': loss, 'test_mean_dice': mean_test_dice, 'test_dice_tc': metric_tc,
                'test_dice_wt': metric_wt, 'test_dice_et': metric_et}

    def on_test_epoch_end(self):
        """Aggregate test metrics, log them and write test_log.csv.

        Named ``on_test_epoch_end`` (not ``test_epoch_end``) so Lightning
        actually invokes it, matching the other ``on_*_epoch_end`` hooks here.
        """
        loss = torch.stack(self.test_step_loss).mean()
        mean_test_dice = torch.stack(self.test_step_dice).mean()
        metric_tc = torch.stack(self.test_step_dice_tc).mean()
        metric_wt = torch.stack(self.test_step_dice_wt).mean()
        metric_et = torch.stack(self.test_step_dice_et).mean()

        self.log('test/Loss', loss)
        self.log('test/MeanDiceScore', mean_test_dice)
        self.log('test/DiceTC', metric_tc)
        self.log('test/DiceWT', metric_wt)
        self.log('test/DiceET', metric_et)

        log_dir = self.logger.log_dir
        os.makedirs(log_dir, exist_ok=True)
        with open(os.path.join(log_dir, 'test_log.csv'), 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["Mean Test Dice", "Dice TC", "Dice WT", "Dice ET"])
            # .item() so the CSV holds numbers, not "tensor(..., device=...)" reprs.
            writer.writerow([mean_test_dice.item(), metric_tc.item(),
                             metric_wt.item(), metric_et.item()])

        self.test_step_loss.clear()
        self.test_step_dice.clear()
        self.test_step_dice_tc.clear()
        self.test_step_dice_wt.clear()
        self.test_step_dice_et.clear()
        return {'test_MeanDiceScore': mean_test_dice}

    def configure_optimizers(self):
        """Adam (AMSGrad, weight decay 1e-5) with cosine annealing over 200 scheduler steps."""
        optimizer = torch.optim.Adam(
            self.model.parameters(), self.lr, weight_decay=1e-5, amsgrad=True
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader