Update Llama2.py
Browse files
Llama2.py
CHANGED
@@ -24,10 +24,15 @@ class ModelArgs:
|
|
24 |
weight_decay: float = 0.1
|
25 |
|
26 |
|
27 |
-
class RMSNorm:
|
28 |
def __init__(self, dim: int, eps: float):
|
29 |
self.eps = eps
|
30 |
-
self.weight =
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def _norm(self, x):
|
33 |
return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), -1, keepdims=True) + self.eps)
|
|
|
24 |
weight_decay: float = 0.1
|
25 |
|
26 |
|
27 |
+
class RMSNorm(tf.keras.layers.Layer):
|
28 |
def __init__(self, dim: int, eps: float):
|
29 |
self.eps = eps
|
30 |
+
self.weight = self.add_weight(
|
31 |
+
name='weight',
|
32 |
+
shape=(self.dim,),
|
33 |
+
initializer=tf.keras.initializers.Ones(),
|
34 |
+
trainable=True
|
35 |
+
)
|
36 |
|
37 |
def _norm(self, x):
|
38 |
return x * tf.math.rsqrt(tf.reduce_mean(tf.math.pow(x, 2), -1, keepdims=True) + self.eps)
|