from fastai.basic_train import Learner, LearnerCallback
from fastai.vision.gan import GANLearner


class GANSaveCallback(LearnerCallback):
    """A `LearnerCallback` that periodically checkpoints the generator learner.

    While the GAN learner `learn` trains, every `save_iters` iterations this
    callback calls `learn_gen.save(...)` with a name of the form
    ``'{filename}_{epoch}_{iteration}'``. Where the file ends up on disk is
    decided by `Learner.save`, not by this callback.
    """

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        """Create the callback.

        Args:
            learn: The GAN learner whose training loop drives this callback.
            learn_gen: The generator learner to checkpoint.
            filename: Prefix for checkpoint names; epoch and iteration are
                appended to it.
            save_iters: Save a checkpoint whenever the iteration counter is a
                positive multiple of this value. Must be a positive integer.

        Raises:
            ValueError: If `save_iters` is not positive. Zero would otherwise
                cause a ZeroDivisionError mid-training, and a negative value
                would make the modulo test misbehave.
        """
        if save_iters <= 0:
            raise ValueError(
                f"save_iters must be a positive integer, got {save_iters!r}"
            )
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
        """Save the generator when `iteration` is a positive multiple of `save_iters`."""
        # 0 % save_iters == 0, so without this guard a checkpoint would be
        # written at iteration 0, before any training step has happened.
        if iteration == 0:
            return
        if iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int) -> None:
        """Save `learn_gen` under ``'{filename}_{epoch}_{iteration}'``."""
        name = '{}_{}_{}'.format(self.filename, epoch, iteration)
        self.learn_gen.save(name)