# import imageio
import logging
import os
from collections import defaultdict
from timeit import default_timer

import torch
from torch.nn import functional as F
from tqdm import trange

from disvae.utils.modelIO import save_model

TRAIN_LOSSES_LOGFILE = "train_losses.log"


class Trainer:
    """
    Class to handle training of model.

    Parameters
    ----------
    model: disvae.vae.VAE

    optimizer: torch.optim.Optimizer

    loss_f: disvae.models.BaseLoss
        Loss function.

    device: torch.device, optional
        Device on which to run the code. The model is moved to it.

    logger: logging.Logger, optional
        Logger used for progress messages.

    save_dir : str, optional
        Directory in which the losses log file and the model checkpoints
        are written.

    gif_visualizer : viz.Visualizer, optional
        Gif visualizer. When given, it is called with no arguments after
        every epoch and its `save_reset()` method is called once training
        ends.

    is_progress_bar: bool, optional
        Whether to use a progress bar for training.
    """

    def __init__(
        self,
        model,
        optimizer,
        loss_f,
        device=torch.device("cpu"),
        logger=logging.getLogger(__name__),
        save_dir="results",
        gif_visualizer=None,
        is_progress_bar=True,
    ):
        self.device = device
        self.model = model.to(self.device)
        self.loss_f = loss_f
        self.optimizer = optimizer
        self.save_dir = save_dir
        self.is_progress_bar = is_progress_bar
        self.logger = logger
        self.losses_logger = LossesLogger(
            os.path.join(self.save_dir, TRAIN_LOSSES_LOGFILE)
        )
        self.gif_visualizer = gif_visualizer
        self.logger.info("Training Device: {}".format(self.device))

    def __call__(self, data_loader, epochs=10, checkpoint_every=10):
        """
        Trains the model.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader

        epochs: int, optional
            Number of epochs to train the model for.

        checkpoint_every: int, optional
            Save a checkpoint of the trained model every n epoch. Must be
            non-zero because it is used as a modulus.
        """
        start = default_timer()
        self.model.train()

        for epoch in range(epochs):
            # Fresh storer each epoch so logged values are per-epoch means.
            storer = defaultdict(list)
            mean_epoch_loss = self._train_epoch(data_loader, storer, epoch)
            self.logger.info(
                "Epoch: {} Average loss per image: {:.2f}".format(
                    epoch + 1, mean_epoch_loss
                )
            )
            self.losses_logger.log(epoch, storer)

            if self.gif_visualizer is not None:
                self.gif_visualizer()

            # `epoch` is 0-based, so a checkpoint is written after the very
            # first epoch and then every `checkpoint_every` epochs.
            if epoch % checkpoint_every == 0:
                save_model(
                    self.model, self.save_dir, filename="model-{}.pt".format(epoch)
                )

        if self.gif_visualizer is not None:
            self.gif_visualizer.save_reset()

        self.model.eval()

        delta_time = (default_timer() - start) / 60
        self.logger.info("Finished training after {:.1f} min.".format(delta_time))

    def _train_epoch(self, data_loader, storer, epoch):
        """
        Trains the model for one epoch.

        Parameters
        ----------
        data_loader: torch.utils.data.DataLoader
            Must yield 2-item batches; only the first item is used.

        storer: dict
            Dictionary in which to store important variables for
            visualization.

        epoch: int
            Epoch number (0-based; displayed as `epoch + 1`).

        Return
        ------
        mean_epoch_loss: float
            Sum of the per-iteration losses divided by `len(data_loader)`.

        Raises
        ------
        ZeroDivisionError
            If `data_loader` has length zero.
        """
        epoch_loss = 0.0
        kwargs = dict(
            desc="Epoch {}".format(epoch + 1),
            leave=False,
            disable=not self.is_progress_bar,
        )
        with trange(len(data_loader), **kwargs) as t:
            # The second element of each batch is discarded.
            for data, _ in data_loader:
                iter_loss = self._train_iteration(data, storer)
                epoch_loss += iter_loss

                t.set_postfix(loss=iter_loss)
                t.update()

        mean_epoch_loss = epoch_loss / len(data_loader)
        return mean_epoch_loss

    def _train_iteration(self, data, storer):
        """
        Trains the model for one iteration on a batch of data.

        Parameters
        ----------
        data: torch.Tensor
            A batch of data. Shape : (batch_size, channel, height, width).

        storer: dict
            Dictionary in which to store important variables for
            visualization.

        Return
        ------
        loss: float
            The loss of this iteration as a Python number (`loss.item()`).

        Raises
        ------
        ValueError
            If `data` does not have exactly four dimensions.
        """
        # The values are unused; the unpacking only enforces a 4-D batch and
        # raises ValueError otherwise (outside the `try` below, so it is not
        # swallowed by the fallback).
        _batch_size, _channel, _height, _width = data.size()
        data = data.to(self.device)

        try:
            recon_batch, latent_dist, latent_sample = self.model(data)
            loss = self.loss_f(
                data,
                recon_batch,
                latent_dist,
                self.model.training,
                storer,
                latent_sample=latent_sample,
            )
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        except ValueError:
            # for losses that use multiple optimizers (e.g. Factor)
            # NOTE(review): any ValueError raised inside the `try` (model
            # forward, backward, optimizer step) also lands here - confirm
            # that this is intended.
            loss = self.loss_f.call_optimize(data, self.model, self.optimizer, storer)

        return loss.item()


class LossesLogger(object):
    """Class definition for objects to write data to log files in a
    form which is then easy to be plotted.

    The file is comma separated: a header line `Epoch,Loss,Value` followed
    by one line per logged loss and epoch.
    """

    def __init__(self, file_path_name):
        """Create a logger to store information for plotting.

        Parameters
        ----------
        file_path_name: str
            Path of the log file. An existing file at that path is removed
            first.
        """
        # `getLogger` returns the same object for the same name, so a
        # previous LossesLogger may have left its file handler attached.
        # Detach and close those handlers; otherwise every record would also
        # be written to the earlier files and their handles would leak.
        self.logger = logging.getLogger("losses_logger")
        for stale_handler in list(self.logger.handlers):
            if isinstance(stale_handler, logging.FileHandler):
                self.logger.removeHandler(stale_handler)
                stale_handler.close()

        # Remove the old file only after stale handlers are closed, so no
        # open handle still points at it.
        if os.path.isfile(file_path_name):
            os.remove(file_path_name)

        self.logger.setLevel(1)  # always store
        file_handler = logging.FileHandler(file_path_name)
        file_handler.setLevel(1)
        self.logger.addHandler(file_handler)

        header = ",".join(["Epoch", "Loss", "Value"])
        self.logger.debug(header)

    def log(self, epoch, losses_storer):
        """Write to the log file.

        Parameters
        ----------
        epoch: int
            Epoch number written in the first column.

        losses_storer: dict
            Maps a loss name to the list of values recorded during the
            epoch; the mean of each list is written.
        """
        for k, v in losses_storer.items():
            log_string = ",".join(str(item) for item in [epoch, k, mean(v)])
            self.logger.debug(log_string)


# HELPERS
def mean(l):
    """Compute the mean of a list.

    Raises ZeroDivisionError if the list is empty.
    """
    return sum(l) / len(l)