# Spaces:
# Runtime error
# Runtime error
from fastai.basic_train import Learner, LearnerCallback | |
from fastai.vision.gan import GANLearner | |
class GANSaveCallback(LearnerCallback):
    """A `LearnerCallback` that periodically saves `learn_gen` while `learn` trains.

    Every `save_iters` iterations (iteration 0 excluded) the callback calls
    `learn_gen.save` with the name `{filename}_{epoch}_{iteration}`, so each
    checkpoint gets its own name instead of overwriting the previous one.
    """

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        """Set up the periodic checkpointing.

        Args:
            learn: The GAN learner this callback is attached to.
            learn_gen: The learner whose `save` method is called for each
                checkpoint (presumably the GAN's generator learner; confirm
                against the caller).
            filename: Prefix of every checkpoint name.
            save_iters: Number of iterations between two checkpoints; must be
                a positive integer.

        Raises:
            ValueError: If `save_iters` is not positive.
        """
        # Fail fast, before the parent registers the callback on `learn`:
        # a zero interval would otherwise surface as a ZeroDivisionError in
        # the middle of training (`iteration % 0` in `on_batch_end`).
        if save_iters <= 0:
            raise ValueError(
                'save_iters must be a positive integer, got {!r}'.format(save_iters)
            )
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
        """Save a checkpoint when `iteration` is a non-zero multiple of `save_iters`."""
        # Iteration 0 is a multiple of every interval but is skipped on purpose.
        if iteration != 0 and iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int) -> None:
        """Save `learn_gen` under the name `{filename}_{epoch}_{iteration}`."""
        filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
        self.learn_gen.save(filename)