Yuchan committed on
Commit
fff43f5
·
verified ·
1 Parent(s): ebb1511

Update Mo.py

Browse files
Files changed (1) hide show
  1. Mo.py +24 -19
Mo.py CHANGED
@@ -123,7 +123,7 @@ class SwiGLU(layers.Layer):
123
  x_proj = self.proj(x)
124
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
125
  return self.out(x_val * tf.nn.silu(x_gate))
126
-
127
  class LoUScan(layers.Layer):
128
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
129
  super().__init__()
@@ -137,7 +137,7 @@ class LoUScan(layers.Layer):
137
 
138
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
139
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
140
- self.glu = SwiGLU(d_model, 320)
141
 
142
  def call(self, x):
143
  x_f32 = tf.cast(x, tf.float32)
@@ -150,27 +150,32 @@ class LoUScan(layers.Layer):
150
 
151
  g_q = (tf.nn.tanh(q) + 1.0) / 2.0
152
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
153
- score = g_q * g_k # gating
154
-
155
- # tf.scan으로 순차 누적합 (인과적)
156
- def step(carry, inputs):
157
- prev_sum = carry
158
- s, v_t = inputs
159
- new_sum = prev_sum + s * v_t
160
- # 정규화
161
- out = new_sum / tf.maximum(tf.reduce_sum(score[:tf.shape(prev_sum)[0]], axis=0, keepdims=True), self.eps)
162
- return new_sum, out
163
-
164
- # 초기값
165
- init = tf.zeros_like(v[0])
166
- _, outputs = tf.scan(step, (score, v), initializer=init, axis=0)
167
-
168
- # 안정화
 
 
 
 
169
  outputs = tf.clip_by_value(outputs, -self.clip_value, self.clip_value)
170
  out = self.norm(outputs + residual)
171
  out = self.glu(out)
172
  return tf.cast(out, x.dtype)
173
 
 
174
  class Lo(layers.Layer):
175
  def __init__(self, d_model):
176
  super().__init__()
@@ -240,7 +245,7 @@ def masked_perplexity(y_true, y_pred, eps=0.1):
240
  # ๋ชจ๋ธ ์ƒ์„ฑ & ์ปดํŒŒ์ผ
241
  # =======================
242
  with strategy.scope():
243
- model = CumaLM(vocab_size=vocab_size, max_seq_len=max_len, d_ff=256, n_layers=1)
244
  dummy_input = tf.zeros((batch_size, max_len), dtype=tf.int32)
245
  _ = model(dummy_input, training=False)
246
  model.summary()
 
123
  x_proj = self.proj(x)
124
  x_val, x_gate = tf.split(x_proj, 2, axis=-1)
125
  return self.out(x_val * tf.nn.silu(x_gate))
126
+
127
  class LoUScan(layers.Layer):
128
  def __init__(self, d_model, clip_value=5.0, eps=1e-6):
129
  super().__init__()
 
137
 
138
  self.norm = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
139
  self.norm1 = layers.LayerNormalization(epsilon=1e-5, dtype='float32')
140
+ self.glu = SwiGLU(d_model, 3500) # 사용자 정의 GLU
141
 
142
  def call(self, x):
143
  x_f32 = tf.cast(x, tf.float32)
 
150
 
151
  g_q = (tf.nn.tanh(q) + 1.0) / 2.0
152
  g_k = (tf.nn.tanh(k) + 1.0) / 2.0
153
+ score = g_q * g_k # element-wise gating
154
+
155
+ # 배치별 순차적 scan 적용 (인과적)
156
+ def process_sequence(inputs):
157
+ score_seq, v_seq = inputs
158
+ seq_len = tf.shape(v_seq)[0]
159
+ init = tf.zeros_like(v_seq[0])
160
+
161
+ def step(carry, elems):
162
+ s_t, v_t = elems
163
+ new_sum = carry + s_t * v_t # 현재까지 누적
164
+ out = new_sum / tf.maximum(tf.reduce_sum(score_seq[:tf.shape(v_seq)[0]], axis=0, keepdims=True), self.eps)
165
+ return new_sum, out
166
+
167
+ _, outputs = tf.scan(step, (score_seq, v_seq), initializer=init)
168
+ return outputs
169
+
170
+ # 배치 차원 처리
171
+ outputs = tf.map_fn(lambda inp: process_sequence(inp), (score, v), dtype=tf.float32)
172
+
173
  outputs = tf.clip_by_value(outputs, -self.clip_value, self.clip_value)
174
  out = self.norm(outputs + residual)
175
  out = self.glu(out)
176
  return tf.cast(out, x.dtype)
177
 
178
+
179
  class Lo(layers.Layer):
180
  def __init__(self, d_model):
181
  super().__init__()
 
245
  # ๋ชจ๋ธ ์ƒ์„ฑ & ์ปดํŒŒ์ผ
246
  # =======================
247
  with strategy.scope():
248
+ model = CumaLM(vocab_size=vocab_size, max_seq_len=max_len, d_model=256, n_layers=1)
249
  dummy_input = tf.zeros((batch_size, max_len), dtype=tf.int32)
250
  _ = model(dummy_input, training=False)
251
  model.summary()