# test/cldm/loss_weight_scheduler.py
# Author: Tu Bui
# Commit 6142a25: "first commit"
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: Tu Bui @University of Surrey
"""
class SimpleLossWeightScheduler(object):
    """Piecewise-linear schedule for the weight of the "simple" loss term.

    Calling the scheduler with a training step returns::

        w(step) = 1 + clip((max - 1) * (step - wait_steps) / ramp, 0, max - 1)

    where ``max`` is ``simple_loss_weight_max``. In other words, the weight
    stays at 1 up to ``wait_steps``, rises linearly over the next ``ramp``
    steps, and then stays at ``simple_loss_weight_max``.

    Args:
        simple_loss_weight_max: weight reached at the end of the ramp.
        wait_steps: step at which the ramp starts.
        ramp: number of steps the weight takes to go from 1 to
            ``simple_loss_weight_max``. ``0`` gives an instantaneous jump
            for every ``step > wait_steps``.
    """

    def __init__(self, simple_loss_weight_max=10., wait_steps=50000, ramp=100000) -> None:
        self.simple_loss_weight_max = simple_loss_weight_max
        self.wait_steps = wait_steps
        self.ramp = ramp

    def __repr__(self):
        return (f"{type(self).__name__}(simple_loss_weight_max={self.simple_loss_weight_max!r}, "
                f"wait_steps={self.wait_steps!r}, ramp={self.ramp!r})")

    def __call__(self, step):
        """Return the loss weight for training step ``step``.

        Args:
            step: current training step (int or float).

        Returns:
            The weight, clamped to the range between 1 and
            ``simple_loss_weight_max``.
        """
        # Amount the weight can grow above its base value of 1.
        max_weight = self.simple_loss_weight_max - 1
        if self.ramp == 0:
            # A zero-length ramp would divide by zero. Use the limit of the
            # linear formula as ramp -> 0+: an immediate jump past wait_steps.
            raw = max_weight if step > self.wait_steps else 0.
        else:
            raw = max_weight * (step - self.wait_steps) / self.ramp
        # Clamp the ramp contribution to [0, max_weight] before adding the base.
        return 1 + min(max_weight, max(0., raw))