File size: 542 Bytes
6142a25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""

@author: Tu Bui @University of Surrey
"""

class SimpleLossWeightScheduler:
    """Step-dependent weight: hold at 1, ramp up linearly, then plateau.

    Calling an instance with a step number returns:

    * ``1`` while ``step <= wait_steps``;
    * a value growing linearly from ``1`` towards
      ``simple_loss_weight_max`` during the following ``ramp`` steps;
    * ``simple_loss_weight_max`` once ``step >= wait_steps + ramp``.

    The description above assumes ``simple_loss_weight_max >= 1`` and
    ``ramp >= 0``; other values are accepted but not validated.

    Args:
        simple_loss_weight_max: Weight reached at the end of the ramp.
        wait_steps: Last step at which the weight is still 1.
        ramp: Number of steps the linear increase lasts. ``0`` makes the
            weight jump straight to the maximum after ``wait_steps``.
    """

    def __init__(self, simple_loss_weight_max=10., wait_steps=50000, ramp=100000) -> None:
        self.simple_loss_weight_max = simple_loss_weight_max
        self.wait_steps = wait_steps
        self.ramp = ramp

    def __call__(self, step: float) -> float:
        """Return the weight for ``step``."""
        # The schedule starts at 1, so only the excess above 1 is ramped.
        span = self.simple_loss_weight_max - 1
        elapsed = step - self.wait_steps
        if self.ramp == 0:
            # A zero-length ramp would divide by zero; use the limiting
            # behaviour of the linear ramp instead (a step function).
            raw = span if elapsed > 0 else 0.
        else:
            raw = span * elapsed / self.ramp
        # Clamp below at 0 (before the ramp) and above at span (after it).
        return 1 + min(span, max(0., raw))