import numpy as np
import tensorflow as tf
from tensorflow import keras


class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule that uses a warmup cosine decay schedule."""

    def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
        """
        Args:
            lr_start: The initial learning rate
            lr_max: The maximum learning rate to which the learning rate
                should increase during the warmup steps
            warmup_steps: The number of steps for which the model warms up
            total_steps: The total number of steps for the model training
        """
        super().__init__()
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        # Check whether the total number of steps is larger than the warmup
        # steps. If not, raise a value error.
        if self.total_steps < self.warmup_steps:
            raise ValueError(
                f"Total number of steps {self.total_steps} must be "
                f"larger than or equal to warmup steps {self.warmup_steps}."
            )

        # `cos_annealed_lr` is a graph that increases to 1 from the initial
        # step to the warmup step. After that, the graph decays to -1 at the
        # final step mark.
        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
        )

        # Shift the `cos_annealed_lr` graph up by 1 so that it goes from 0
        # to 2, scale it by 0.5 so that it goes from 0 to 1, and finally
        # scale it by `lr_max` so that it goes from 0 to `lr_max`.
        learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)

        # Check whether warmup_steps is more than 0.
        if self.warmup_steps > 0:
            # Check whether lr_max is larger than lr_start. If not, raise a
            # value error.
            if self.lr_max < self.lr_start:
                raise ValueError(
                    f"lr_start {self.lr_start} must be smaller than or "
                    f"equal to lr_max {self.lr_max}."
                )

            # Calculate the slope with which the learning rate should
            # increase during the warmup schedule. The formula for the
            # slope is m = (b - a) / steps.
            slope = (self.lr_max - self.lr_start) / self.warmup_steps

            # With the formula for a straight line (y = mx + c), build the
            # warmup schedule.
            warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start

            # When the current step is less than the warmup steps, use the
            # line graph. When the current step is greater than the warmup
            # steps, use the scaled cosine graph.
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )

        # When the current step is more than the total steps, return 0;
        # otherwise return the calculated graph.
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )
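

# ---------------------------------------------------------------------------
# Usage sketch (not part of the schedule class above). This shows one way the
# schedule might be wired into a Keras optimizer; the hyperparameter values
# below (step counts and learning rates) are illustrative assumptions, not
# values from the original source.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    TOTAL_STEPS = 10_000   # assumed total number of training steps
    WARMUP_STEPS = 1_000   # assumed number of linear warmup steps

    lr_schedule = WarmUpCosine(
        lr_start=1e-5,     # assumed starting learning rate
        lr_max=1e-3,       # assumed peak learning rate reached after warmup
        warmup_steps=WARMUP_STEPS,
        total_steps=TOTAL_STEPS,
    )

    # A LearningRateSchedule can be passed to a Keras optimizer in place of
    # a fixed learning rate; the optimizer calls it with the current step.
    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

    # Sample the schedule at a few steps to sanity-check its shape: it should
    # rise linearly during warmup, then decay along a cosine to 0 at the end.
    for step in [0, 500, 1_000, 5_000, 10_000]:
        print(step, float(lr_schedule(tf.constant(step))))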