Yuchan
committed on
Update Mo.py
Browse files
Mo.py
CHANGED
|
@@ -123,7 +123,7 @@ class SwiGLU(layers.Layer):
|
|
| 123 |
x_proj = self.proj(x)
|
| 124 |
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 125 |
return self.out(x_val * tf.nn.silu(x_gate))
|
| 126 |
-
|
| 127 |
class LoUScan(layers.Layer):
|
| 128 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
| 129 |
super().__init__()
|
|
@@ -137,7 +137,7 @@ class LoUScan(layers.Layer):
|
|
| 137 |
|
| 138 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 139 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 140 |
-
self.glu = SwiGLU(d_model,
|
| 141 |
|
| 142 |
def call(self, x):
|
| 143 |
x_f32 = tf.cast(x, tf.float32)
|
|
@@ -150,27 +150,32 @@ class LoUScan(layers.Layer):
|
|
| 150 |
|
| 151 |
g_q = (tf.nn.tanh(q) + 1.0) / 2.0
|
| 152 |
g_k = (tf.nn.tanh(k) + 1.0) / 2.0
|
| 153 |
-
score = g_q * g_k # gating
|
| 154 |
-
|
| 155 |
-
#
|
| 156 |
-
def
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
outputs = tf.clip_by_value(outputs, -self.clip_value, self.clip_value)
|
| 170 |
out = self.norm(outputs + residual)
|
| 171 |
out = self.glu(out)
|
| 172 |
return tf.cast(out, x.dtype)
|
| 173 |
|
|
|
|
| 174 |
class Lo(layers.Layer):
|
| 175 |
def __init__(self, d_model):
|
| 176 |
super().__init__()
|
|
@@ -240,7 +245,7 @@ def masked_perplexity(y_true, y_pred, eps=0.1):
|
|
| 240 |
# ๋ชจ๋ธ ์์ฑ & ์ปดํ์ผ
|
| 241 |
# =======================
|
| 242 |
with strategy.scope():
|
| 243 |
-
model = CumaLM(vocab_size=vocab_size, max_seq_len=max_len,
|
| 244 |
dummy_input = tf.zeros((batch_size, max_len), dtype=tf.int32)
|
| 245 |
_ = model(dummy_input, training=False)
|
| 246 |
model.summary()
|
|
|
|
| 123 |
x_proj = self.proj(x)
|
| 124 |
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
|
| 125 |
return self.out(x_val * tf.nn.silu(x_gate))
|
| 126 |
+
|
| 127 |
class LoUScan(layers.Layer):
|
| 128 |
def __init__(self, d_model, clip_value=5.0, eps=1e-6):
|
| 129 |
super().__init__()
|
|
|
|
| 137 |
|
| 138 |
self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 139 |
self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
|
| 140 |
+
self.glu = SwiGLU(d_model, 3500) # ์ฌ์ฉ์ ์ ์ GLU
|
| 141 |
|
| 142 |
def call(self, x):
|
| 143 |
x_f32 = tf.cast(x, tf.float32)
|
|
|
|
| 150 |
|
| 151 |
g_q = (tf.nn.tanh(q) + 1.0) / 2.0
|
| 152 |
g_k = (tf.nn.tanh(k) + 1.0) / 2.0
|
| 153 |
+
score = g_q * g_k # element-wise gating
|
| 154 |
+
|
| 155 |
+
# Apply a sequential, causal scan to each sequence in the batch.
def process_sequence(inputs):
    """Causally-normalized, score-weighted running average of one sequence.

    inputs: tuple (score_seq, v_seq), both time-major (seq_len, d) float32.
    Returns a (seq_len, d) tensor whose step t is
        sum_{i<=t}(s_i * v_i) / max(sum_{i<=t}(s_i), eps)
    so no future position can influence the output at step t.
    """
    score_seq, v_seq = inputs
    # Carry BOTH running sums so the denominator is causal.
    # The original version divided every step by the full-sequence score
    # sum (future leakage, contradicting the "causal" comment) and had
    # `step` return a (carry, out) pair — but tf.scan requires fn to
    # return exactly the carry structure of `initializer`, so the old
    # code also violated tf.scan's contract.
    init = (tf.zeros_like(v_seq[0]), tf.zeros_like(score_seq[0]))

    def step(carry, elems):
        num, den = carry
        s_t, v_t = elems
        # Accumulate score-weighted values and raw score mass up to t.
        return num + s_t * v_t, den + s_t

    # tf.scan stacks the carry at every step: nums/dens are (seq_len, d).
    nums, dens = tf.scan(step, (score_seq, v_seq), initializer=init)
    # eps floor guards against division by ~0 early in the sequence.
    return nums / tf.maximum(dens, self.eps)
|
| 169 |
+
|
| 170 |
+
# ๋ฐฐ์น ์ฐจ์ ์ฒ๋ฆฌ
|
| 171 |
+
outputs = tf.map_fn(lambda inp: process_sequence(inp), (score, v), dtype=tf.float32)
|
| 172 |
+
|
| 173 |
outputs = tf.clip_by_value(outputs, -self.clip_value, self.clip_value)
|
| 174 |
out = self.norm(outputs + residual)
|
| 175 |
out = self.glu(out)
|
| 176 |
return tf.cast(out, x.dtype)
|
| 177 |
|
| 178 |
+
|
| 179 |
class Lo(layers.Layer):
|
| 180 |
def __init__(self, d_model):
|
| 181 |
super().__init__()
|
|
|
|
| 245 |
# ๋ชจ๋ธ ์์ฑ & ์ปดํ์ผ
|
| 246 |
# =======================
|
| 247 |
with strategy.scope():
|
| 248 |
+
model = CumaLM(vocab_size=vocab_size, max_seq_len=max_len, d_model=256, n_layers=1)
|
| 249 |
dummy_input = tf.zeros((batch_size, max_len), dtype=tf.int32)
|
| 250 |
_ = model(dummy_input, training=False)
|
| 251 |
model.summary()
|