|
|
|
import tensorflow as tf |
|
# Model width (embedding / hidden size); also scales the learning-rate schedule.
d_model: int = 512
|
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule with linear warmup followed by inverse-sqrt decay.

    For a given ``step`` the rate is::

        lr(step) = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    The second term grows linearly with ``step``; the first term shrinks as
    ``1 / sqrt(step)``. The two terms are equal at ``step == warmup_steps``.

    Args:
        d_model: Scale factor for the rate; cast to float32.
        warmup_steps: Step at which the linear ramp hands over to the decay;
            cast to float32. Defaults to 4000.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        # Keep the raw constructor arguments so get_config() can return them
        # unchanged (the float32 tensors below are not plain config values).
        self._config_d_model = d_model
        self._config_warmup_steps = warmup_steps

        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        """Return the learning rate for ``step`` as a float32 tensor.

        NOTE(review): at ``step == 0`` ``rsqrt(0)`` is ``inf`` and the warmup
        term is ``0``, so the minimum, and therefore the returned rate, is 0.
        If the caller's step counter starts at 0, the first update uses a zero
        learning rate. Confirm whether that is intended.
        """
        step = tf.cast(step, tf.float32)

        decay = tf.math.rsqrt(step)  # step ** -0.5
        warmup = step * (self.warmup_steps ** -1.5)  # linear ramp

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay, warmup)

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized.

        The keys match the ``__init__`` parameter names. The base class's
        ``from_config`` (``cls(**config)``) can therefore rebuild an
        equivalent schedule when a saved model/optimizer is loaded.
        """
        return {
            "d_model": self._config_d_model,
            "warmup_steps": self._config_warmup_steps,
        }
|
|
|
# Warmup + inverse-sqrt-decay schedule, scaled by the model width above
# (warmup_steps left at the schedule's default).
learning_rate = CustomSchedule(d_model=d_model)