shiftvit / utils /lr_schedule.py
shivi's picture
Committing all app files
cacf2d0
import tensorflow as tf
from tensorflow import keras
import numpy as np
"""
Below code is taken from the [ShiftViT keras example](https://keras.io/examples/vision/shiftvit/) by Aritra Roy Gosthipaty & Ritwik Raha
"""
# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule with a linear warmup followed by cosine decay.

    For steps below `warmup_steps` the learning rate follows a straight line
    from `lr_start` towards `lr_max`. From `warmup_steps` up to `total_steps`
    it follows half a cosine wave from `lr_max` down to 0. For steps past
    `total_steps` the learning rate is 0.
    """

    def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
        """
        Args:
            lr_start: The initial learning rate
            lr_max: The maximum learning rate to which lr should increase to in
                the warmup steps
            warmup_steps: The number of steps for which the model warms up
            total_steps: The total number of steps for the model training
        """
        super().__init__()
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.pi = tf.constant(np.pi)

    def __call__(self, step):
        """Return the learning rate for the given training step.

        Args:
            step: The current training step (a scalar number or tensor).

        Returns:
            A tensor holding the learning rate for `step`.

        Raises:
            ValueError: If `total_steps` is smaller than `warmup_steps`, or if
                warmup is enabled (`warmup_steps > 0`) and `lr_max` is smaller
                than `lr_start`.
        """
        # The decay phase spans `total_steps - warmup_steps` steps, so the
        # total number of steps must not be smaller than the warmup steps.
        if self.total_steps < self.warmup_steps:
            raise ValueError(
                f"Total number of steps {self.total_steps} must be "
                f"larger or equal to warmup steps {self.warmup_steps}."
            )

        # Cast once and reuse for all of the arithmetic below.
        step_f = tf.cast(step, tf.float32)
        decay_steps = tf.cast(self.total_steps - self.warmup_steps, tf.float32)

        # `cos_annealed_lr` is a curve that equals 1 at the warmup step and
        # decays to -1 at the final step. `divide_no_nan` returns 0 when
        # `decay_steps` is 0 (i.e. `total_steps == warmup_steps`), so the
        # cosine evaluates to 1 there instead of producing NaN from 0 / 0.
        cos_annealed_lr = tf.cos(
            tf.math.divide_no_nan(
                self.pi * (step_f - self.warmup_steps), decay_steps
            )
        )

        # Shift the curve up by 1 so that it ranges over [0, 2], halve it so
        # that it ranges over [0, 1], then scale by `lr_max` so that it ranges
        # over [0, lr_max].
        learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)

        # The warmup segment only exists when warmup_steps is more than 0.
        if self.warmup_steps > 0:
            # The warmup line must not slope downwards: lr_max has to be at
            # least lr_start.
            if self.lr_max < self.lr_start:
                raise ValueError(
                    f"lr_start {self.lr_start} must be smaller or "
                    f"equal to lr_max {self.lr_max}."
                )

            # Slope of the warmup line: m = (b - a) / steps.
            slope = (self.lr_max - self.lr_start) / self.warmup_steps

            # Straight line y = m * x + c through (0, lr_start).
            warmup_rate = slope * step_f + self.lr_start

            # Before the warmup step use the line; from the warmup step on use
            # the scaled cosine curve.
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )

        # Past the total number of steps the learning rate is 0; otherwise
        # return the schedule computed above.
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )

    def get_config(self):
        """Return the constructor arguments as a dict for serialization."""
        config = {
            "lr_start": self.lr_start,
            "lr_max": self.lr_max,
            "total_steps": self.total_steps,
            "warmup_steps": self.warmup_steps,
        }
        return config