File size: 936 Bytes
9e08039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from fastai.basic_train import Learner, LearnerCallback
from fastai.vision.gan import GANLearner


class GANSaveCallback(LearnerCallback):
    """Periodically save a generator `Learner` while a GAN `learn` trains.

    Every `save_iters` training iterations (iteration 0 is skipped), this
    calls ``learn_gen.save('{filename}_{epoch}_{iteration}')``.

    Args:
        learn: The `GANLearner` being trained; the callback is attached to it.
        learn_gen: The `Learner` whose ``save`` method writes the checkpoint
            (presumably it wraps the same generator model as `learn` --
            confirm against the caller).
        filename: Checkpoint name prefix; ``_{epoch}_{iteration}`` is appended.
        save_iters: Save interval, in iterations. Must be a positive integer.

    Raises:
        ValueError: If `save_iters` is not a positive integer.
    """

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        # Reject a bad interval up front instead of failing with a
        # ZeroDivisionError (``iteration % 0``) partway through training.
        if save_iters <= 0:
            raise ValueError(
                'save_iters must be a positive integer, got {!r}'.format(
                    save_iters
                )
            )
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(
        self, iteration: int, epoch: int, train: bool = True, **kwargs
    ) -> None:
        """Save the generator when a training batch lands on a save boundary.

        Args:
            iteration: Iteration counter supplied by the callback handler.
            epoch: Current epoch supplied by the callback handler.
            train: False for validation batches. fastai v1 does not advance
                `iteration` on validation batches, so without this check the
                same checkpoint would be rewritten once per validation batch
                whenever validation begins on a multiple of `save_iters`.
                Defaults to True so callers that omit it keep saving.
            **kwargs: Remaining handler state, ignored.
        """
        if not train or iteration == 0:
            return

        if iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int) -> None:
        """Save `learn_gen` under the name ``'{filename}_{epoch}_{iteration}'``."""
        filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
        self.learn_gen.save(filename)