# Hugging Face file-page residue (not part of the program):
# chhetri123's picture
# Upload 27 files
# 340d736 verified
# raw
# history blame contribute delete
# 695 Bytes
import tensorflow as tf
# Value passed as the d_model argument of CustomSchedule at the bottom of the
# file (presumably the Transformer model width -- confirm against the model).
d_model = 512
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warmup-then-decay learning-rate schedule.

    For a given step the rate is
    ``rsqrt(d_model) * min(rsqrt(step), step * warmup_steps ** -1.5)``.
    The two branches are equal at ``step == warmup_steps``: before that the
    rate grows linearly with ``step``, afterwards it decays as the inverse
    square root of ``step``.

    Args:
        d_model: Scaling dimension; stored as a float32 tensor.
        warmup_steps: Step at which the schedule switches from growth to
            decay; stored as a float32 tensor.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        # Keep the arguments as given so get_config() can hand them back to
        # the constructor; the float32 tensors below are not directly
        # serializable.
        self._d_model_config = d_model
        self._warmup_steps_config = warmup_steps
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        """Return the learning rate for ``step`` as a float32 tensor."""
        # The optimizer may pass an integer step; rsqrt needs a float.
        step = tf.cast(step, tf.float32)
        decay = tf.math.rsqrt(step)
        warmup = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay, warmup)

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized.

        Without this override, saving an optimizer or model that uses the
        schedule fails because the base class does not provide a config.
        """
        return {
            "d_model": self._d_model_config,
            "warmup_steps": self._warmup_steps_config,
        }
# Schedule instance built from the module-level d_model; warmup_steps is not
# passed, so the constructor default (4000) applies.
learning_rate = CustomSchedule(d_model)