import tensorflow as tf
from tensorflow import keras
import numpy as np

"""
The code below is taken from the [ShiftViT Keras example](https://keras.io/examples/vision/shiftvit/)
by Aritra Roy Gosthipaty & Ritwik Raha.
"""

# Some code is taken from:
# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.
class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
    """A LearningRateSchedule that uses a warmup cosine decay schedule."""

    def __init__(self, lr_start, lr_max, warmup_steps, total_steps):
        """
        Args:
            lr_start: The initial learning rate.
            lr_max: The maximum learning rate, reached at the end of the
                warmup steps.
            warmup_steps: The number of steps over which the learning rate
                warms up.
            total_steps: The total number of training steps.
        """
        super().__init__()
        self.lr_start = lr_start
        self.lr_max = lr_max
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.pi = tf.constant(np.pi)
    def __call__(self, step):
        # Check whether the total number of steps is larger than the warmup
        # steps. If not, raise a value error.
        if self.total_steps < self.warmup_steps:
            raise ValueError(
                f"Total number of steps {self.total_steps} must be "
                f"larger than or equal to warmup steps {self.warmup_steps}."
            )

        # `cos_annealed_lr` is a graph that increases to 1 from the initial
        # step to the warmup step. After that, the graph decays to -1 at the
        # final step.
        cos_annealed_lr = tf.cos(
            self.pi
            * (tf.cast(step, tf.float32) - self.warmup_steps)
            / tf.cast(self.total_steps - self.warmup_steps, tf.float32)
        )

        # Shift the mean of the `cos_annealed_lr` graph to 1, so it now goes
        # from 0 to 2. Scale it by 0.5 so that it goes from 0 to 1, and then
        # by `lr_max` so that it goes from 0 to `lr_max`.
        learning_rate = 0.5 * self.lr_max * (1 + cos_annealed_lr)

        # Check whether warmup_steps is more than 0.
        if self.warmup_steps > 0:
            # Check whether lr_max is larger than lr_start. If not, raise a
            # value error.
            if self.lr_max < self.lr_start:
                raise ValueError(
                    f"lr_start {self.lr_start} must be smaller than or "
                    f"equal to lr_max {self.lr_max}."
                )

            # Calculate the slope with which the learning rate should increase
            # during the warmup schedule. The formula for the slope is
            # m = (lr_max - lr_start) / warmup_steps.
            slope = (self.lr_max - self.lr_start) / self.warmup_steps

            # With the formula for a straight line (y = mx + c), build the
            # warmup schedule.
            warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start

            # When the current step is less than the warmup steps, use the
            # linear warmup rate. When the current step is greater than the
            # warmup steps, use the scaled cosine rate.
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )

        # When the current step is more than the total steps, return 0;
        # otherwise return the calculated learning rate.
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )
    def get_config(self):
        config = {
            "lr_start": self.lr_start,
            "lr_max": self.lr_max,
            "total_steps": self.total_steps,
            "warmup_steps": self.warmup_steps,
        }
        return config
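

# --- Minimal usage sketch (illustrative; not part of the original example). ---
# The epoch count, steps per epoch, and learning-rate values below are
# hypothetical placeholders; substitute your own training configuration.
# This relies on the `tf` and `keras` imports at the top of the file.
if __name__ == "__main__":
    EPOCHS = 30  # hypothetical number of training epochs
    STEPS_PER_EPOCH = 100  # hypothetical number of batches per epoch
    TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH
    WARMUP_STEPS = int(0.15 * TOTAL_STEPS)  # warm up over the first 15% of steps

    scheduled_lr = WarmUpCosine(
        lr_start=1e-5,
        lr_max=1e-3,
        warmup_steps=WARMUP_STEPS,
        total_steps=TOTAL_STEPS,
    )

    # Any Keras optimizer accepts a LearningRateSchedule in place of a fixed
    # learning rate; Adam is used here purely for illustration.
    optimizer = keras.optimizers.Adam(learning_rate=scheduled_lr)

    # Evaluate the schedule at a few representative steps to see the warmup,
    # the cosine decay, and the cutoff after the final step.
    for s in [0, WARMUP_STEPS, TOTAL_STEPS // 2, TOTAL_STEPS]:
        print(s, float(scheduled_lr(tf.constant(s))))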