from typing import List

from utility.schedule.linear_lr import LinearLr


def multi_scheduler_wrapper(optimizer, args, steps_per_epoch):
    # Expects the optimizer to hold 2 * n_layers encoder parameter groups
    # (two blocks of n_layers, e.g. a decay / no-decay split per layer)
    # followed by two decoder groups at the end.
    n_layers = (len(optimizer.param_groups) - 2) // 2
    total_steps = args.epochs * steps_per_epoch

    return MultiScheduler(
        # Encoder groups: layer i runs at encoder_learning_rate scaled by
        # layerwise_lr_decay ** i; both blocks of n_layers groups share the
        # same per-layer schedule.
        [
            LinearLr(optimizer.param_groups[i],
                     args.encoder_learning_rate * (args.layerwise_lr_decay ** i),
                     total_steps, False, args.lr_decay_multiplier)
            for i in range(n_layers)
        ]
        + [
            LinearLr(optimizer.param_groups[n_layers + i],
                     args.encoder_learning_rate * (args.layerwise_lr_decay ** i),
                     total_steps, False, args.lr_decay_multiplier)
            for i in range(n_layers)
        ]
        # Decoder groups: both run at the flat decoder learning rate.
        + [
            LinearLr(optimizer.param_groups[-2], args.decoder_learning_rate,
                     total_steps, False, args.lr_decay_multiplier),
            LinearLr(optimizer.param_groups[-1], args.decoder_learning_rate,
                     total_steps, False, args.lr_decay_multiplier),
        ]
    )
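
# Hypothetical sketch (not from this repo) of a param-group layout that
# matches the indexing above: one weight-decay group per encoder layer, one
# no-decay group per encoder layer, then two decoder groups. The helper
# names (decay_params / no_decay_params, encoder.layers, decoder_*) are
# illustrative assumptions, not this codebase's API.
#
#     groups = (
#         [{"params": decay_params(layer)} for layer in encoder.layers]
#         + [{"params": no_decay_params(layer), "weight_decay": 0.0}
#            for layer in encoder.layers]
#         + [{"params": decoder_decay_params},
#            {"params": decoder_no_decay_params}]
#     )
#     optimizer = torch.optim.AdamW(groups)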


class MultiScheduler:
    """Steps a collection of per-group schedulers together and reports
    their current learning rates."""

    def __init__(self, schedulers):
        self.schedulers = schedulers

    def __call__(self, epoch):
        # Advance every underlying scheduler to the same point in training.
        for scheduler in self.schedulers:
            scheduler(epoch)

    def lr(self) -> List[float]:
        # One current learning rate per wrapped scheduler / parameter group.
        return [scheduler.lr() for scheduler in self.schedulers]
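
# Usage sketch, assuming one scheduler step per optimizer update (the
# schedule durations above are computed in steps: args.epochs *
# steps_per_epoch). model, loader, and optimizer are placeholders, not
# part of this module.
#
#     schedulers = multi_scheduler_wrapper(optimizer, args, steps_per_epoch)
#     step = 0
#     for epoch in range(args.epochs):
#         for batch in loader:
#             loss = model(batch)
#             loss.backward()
#             optimizer.step()
#             optimizer.zero_grad()
#             step += 1
#             schedulers(step)   # update every group's learning rate
#     print(schedulers.lr())     # current lr per parameter group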
|
|