""" Support log functions TODO: log model using mlflow.pytorch in parallel / addition to checkpointing """ import numpy as np import h5py import os import argparse import torch import torchvision.utils as vutils import pytorch_lightning as pl class ImgCB(pl.Callback): def __init__(self, **kwargs): parser = ImgCB.add_argparse_args() for action in parser._actions: if action.dest in kwargs: action.default = kwargs[action.dest] args = parser.parse_args([]) self.__dict__.update(vars(args)) @staticmethod def add_argparse_args(parent_parser=None): parser = argparse.ArgumentParser( prog='ImgCB', usage=ImgCB.__doc__, parents=[parent_parser] if parent_parser is not None else [], add_help=False) parser.add_argument('--img_ranges', default=[1300, 1800], nargs='*', help='Scaling range on output image, either pair, or set of pairs') parser.add_argument('--err_ranges', default=[0, 50], nargs='*', help='Scaling range on error images, either pair, or set of pairs') return parser def log_images(self, mfl_logger, y, z, prefix): img_ranges = tuple(self.img_ranges) err_ranges = tuple(self.err_ranges) # for i in range(y.shape[1]): if y.shape[1] > 1: tag = f'_{i}_' if len(self.img_ranges) > 2: img_ranges = tuple(self.img_ranges[2*i, 2*i + 1]) if len(self.err_ranges) > 2: err_ranges = tuple(self.err_ranges[2*i, 2*i + 1]) else: tag = '' mfl_logger.experiment.log_image( mfl_logger.run_id, (np.array(vutils.make_grid( y[:, [i], ...].detach(), normalize=True, value_range=img_ranges, nrow=6).cpu())[0, ...] * 255.).astype(np.int), prefix + tag + '_labels.png') mfl_logger.experiment.log_image( mfl_logger.run_id, (np.array(vutils.make_grid( z[:, [i], ...].detach(), normalize=True, value_range=img_ranges, nrow=6).cpu())[0, ...] * 255.).astype(np.int), prefix + tag + '_outputs.png') mfl_logger.experiment.log_image( mfl_logger.run_id, (np.array(vutils.make_grid( torch.abs(y[:, [i], ...].detach() - z[:, [i], ...].detach()), normalize=True, value_range=err_ranges, nrow=6).cpu())[0, ...] * 255.).astype(np.int), prefix + tag + '_errors.png') def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if batch_idx == 0: with torch.no_grad(): x, y = batch if pl_module.hparams.rand_output_crop: x = x[..., :-pl_module.hparams.rand_output_crop, :] y = y[..., :-pl_module.hparams.rand_output_crop * 2, :] z = pl_module(x.to(pl_module.device)) if isinstance(z, tuple) or isinstance(z, list): z = z[0] self.log_images(pl_module.logger, y.to(pl_module.device), z, 'train_') def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): if batch_idx == 0: with torch.no_grad(): x, y = batch if pl_module.hparams.rand_output_crop: x = x[..., :-pl_module.hparams.rand_output_crop, :] y = y[..., :-pl_module.hparams.rand_output_crop * 2, :] z = pl_module(x.to(pl_module.device)) if isinstance(z, tuple) or isinstance(z, list): z = z[0] self.log_images(pl_module.logger, y.to(pl_module.device), z, 'validate_') class TestLogger(pl.Callback): """ pytorch_lightning Data saving logger for testing output Warning !!! : this function is not multi GPU / multi device safe -- only run on a single gpu / device """ def __init__(self, fname: str = 'output.h5'): self.fname = fname def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): with h5py.File(self.fname, 'a') as f: f[f'batch_{batch_idx:05}'] = outputs.to('cpu').numpy() if len(batch) > 1: f[f'labels_{batch_idx:05}'] = batch[1].to('cpu').numpy()