Spaces:
Runtime error
Runtime error
File size: 936 Bytes
9e08039 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from fastai.basic_train import Learner, LearnerCallback
from fastai.vision.gan import GANLearner
class GANSaveCallback(LearnerCallback):
    """A `LearnerCallback` that periodically saves the generator learner during GAN training.

    Every `save_iters` iterations (excluding iteration 0), `learn_gen.save` is
    called with the name ``'{filename}_{epoch}_{iteration}'``.

    Args:
        learn: The `GANLearner` this callback is attached to.
        learn_gen: The generator `Learner` whose state is saved.
        filename: Prefix for the saved checkpoint names.
        save_iters: Save interval in iterations; must be a positive integer.

    Raises:
        ValueError: If `save_iters` is not a positive integer.
    """

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        # Fail fast on a bad interval: 0 would otherwise only surface later as a
        # ZeroDivisionError in `on_batch_end`, in the middle of training.
        if save_iters <= 0:
            raise ValueError(
                f'save_iters must be a positive integer, got {save_iters!r}'
            )
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(self, iteration: int, epoch: int, **kwargs) -> None:
        """Save the generator when `iteration` is a nonzero multiple of `save_iters`.

        Args:
            iteration: Iteration counter supplied by the callback handler.
            epoch: Current epoch; used only to build the checkpoint name.
            **kwargs: Remaining callback state; ignored.
        """
        # 0 is a multiple of every interval; skip it explicitly so the first
        # save happens at iteration `save_iters`.
        if iteration == 0:
            return
        if iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int) -> None:
        """Save `learn_gen` under the name ``'{filename}_{epoch}_{iteration}'``."""
        # The final location/extension is decided by `learn_gen.save`, not here.
        self.learn_gen.save(f'{self.filename}_{epoch}_{iteration}')
|