NoteDance commited on
Commit
c462c28
1 Parent(s): 69a6acd

Update Llama2.py

Browse files
Files changed (1) hide show
  1. Llama2.py +7 -2
Llama2.py CHANGED
@@ -24,10 +24,15 @@ class ModelArgs:
24
  weight_decay: float = 0.1
25
 
26
 
27
- class RMSNorm:
28
  def __init__(self, dim: int, eps: float):
29
  self.eps = eps
30
- self.weight = tf.Variable(tf.ones((dim,)))
 
 
 
 
 
31
 
32
  def _norm(self, x):
33
  return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), -1, keepdims=True) + self.eps)
 
24
  weight_decay: float = 0.1
25
 
26
 
27
+ class RMSNorm(tf.keras.layers.Layer):
28
  def __init__(self, dim: int, eps: float):
29
  self.eps = eps
30
+ self.weight = self.add_weight(
31
+ name='weight',
32
+ shape=(self.dim,),
33
+ initializer=tf.keras.initializers.Ones(),
34
+ trainable=True
35
+ )
36
 
37
  def _norm(self, x):
38
  return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), -1, keepdims=True) + self.eps)