File size: 1,892 Bytes
cc9dfd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from ..core import *
from ..callback import *
from ..basic_train import Learner, LearnerCallback

# Names exported by `from <this module> import *`: the scheduler callback and the phase type it consumes.
__all__ = ['GeneralScheduler', 'TrainingPhase']

@dataclass
class TrainingPhase:
    "One phase of training lasting `length` iterations, holding one schedule per hyper-parameter."
    length:int

    def __post_init__(self):
        # Maps a hyper-parameter name to the `Scheduler` registered for it by `schedule_hp`.
        self.scheds = {}

    def schedule_hp(self, name, vals, anneal=None):
        "Register a `Scheduler` built from `vals`, this phase's `length` and `anneal` under `name`; returns `self`."
        scheduler = Scheduler(vals, self.length, anneal)
        self.scheds[name] = scheduler
        # Returning the phase itself lets several `schedule_hp` calls be chained.
        return self

class GeneralScheduler(LearnerCallback):
    "Schedule multiple `TrainingPhase` for a `Learner`, stepping through them one after another."
    def __init__(self, learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None):
        # `phases` are consumed in order; `start_epoch`, when given, is handed back from `on_train_begin`.
        super().__init__(learn)
        self.phases,self.start_epoch = phases,start_epoch

    def on_train_begin(self, epoch:int, **kwargs:Any)->None:
        "Reset every phase, apply the first phase's start values and report `start_epoch` if one was set."
        # NOTE: the `None` annotation is kept unchanged for interface compatibility, but a
        # `{'epoch': start_epoch}` dict is returned when `start_epoch` was given.
        # NOTE(review): presumably the callback handler uses it to resume from that epoch - confirm.
        res = {'epoch':self.start_epoch} if self.start_epoch is not None else None
        self.start_epoch = ifnone(self.start_epoch, epoch)
        self.scheds = [p.scheds for p in self.phases]
        self.opt = self.learn.opt
        # Restart the schedulers of *all* phases, not just the first one: otherwise a callback
        # used for a second training run would enter its later phases in whatever state the
        # previous run left them.
        for sched in self.scheds:
            for v in sched.values(): v.restart()
        # With no phases there is nothing to apply; `on_batch_end` then asks for a stop right away
        # instead of this method failing on `self.scheds[0]`.
        if self.scheds:
            for k,v in self.scheds[0].items(): self.opt.set_stat(k, v.start)
        self.idx_s = 0
        return res

    def jump_to_epoch(self, epoch:int)->None:
        "Replay `epoch` epochs worth of training batches so the schedules reach that point."
        # One `on_batch_end` call per batch of the training dataloader, per skipped epoch.
        for _ in range(len(self.learn.data.train_dl) * epoch):
            self.on_batch_end(True)

    def on_batch_end(self, train, **kwargs:Any)->None:
        "Take a step in lr,mom sched, start next stepper when the current one is complete."
        # Only training batches advance the schedules.
        if not train: return
        # Every phase has been consumed: return the stop flags (a dict, despite the `None` annotation).
        if self.idx_s >= len(self.scheds): return {'stop_training': True, 'stop_epoch': True}
        sched = self.scheds[self.idx_s]
        for k,v in sched.items(): self.opt.set_stat(k, v.step())
        # The first scheduler of the phase decides when the phase is over; this assumes each
        # phase registered at least one schedule (an empty phase raises IndexError here).
        if list(sched.values())[0].is_done: self.idx_s += 1