Yuchan committed
Commit 7cdfb1b · verified · 1 Parent(s): 1065dfd

Update Inference.py

Files changed (1)
  Inference.py  +9 -13
Inference.py CHANGED
@@ -27,7 +27,7 @@ def text_to_ids(text):
 def ids_to_text(ids):
     return sp.decode(ids)
 
-max_len = 100
+max_len = 230
 batch_size = 128
 
 class Lo(layers.Layer):
@@ -118,7 +118,7 @@ class LoSoU(layers.Layer):
         # x: (B, L, d_model) maybe bfloat16 or float32
         # cast to float32 for all internal computations
         x_f32 = tf.cast(x, tf.float32)
-        residual = x_f32
+
 
         # Q, K, V
         q = self.Q(x_f32)  # (B, L, 96)
@@ -127,7 +127,7 @@ class LoSoU(layers.Layer):
 
         # gating signals in (0,1)
         g_q = tf.nn.sigmoid(q)
-        g_k = tf.nn.sigmoid(k)
+        g_k = tf.nn.tanh(k)
 
         # elementwise product -> bounded roughly [0,1]
         score = g_q * g_k
@@ -162,12 +162,11 @@ class LoSoU(layers.Layer):
         gated = tf.nn.silu(a) * b
         out = self.O(gated)
 
-        out = self.norm(out + residual)
+        out = self.norm(out)
 
         # cast back to original dtype for downstream layers
         return tf.cast(out, x.dtype)
 
-
 class Block(layers.Layer):
     def __init__(self, d_model, hyper_n):
         super().__init__()
@@ -181,23 +180,20 @@ class Block(layers.Layer):
 class ReLaM(tf.keras.Model):
     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
         super().__init__()
-        self.token_embedding = layers.Embedding(vocab_size, d_model)
-        self.pos_embedding = layers.Embedding(max_seq_len, d_model)
-        self.blocks = [Block(d_model, hyper_n=3) for _ in range(n_layers)]
-
-        # keep LayerNormalization in float32 to avoid precision issues
+        self.token_embedding = layers.Embedding(vocab_size, 128)
+        self.pos_embedding = layers.Embedding(max_seq_len, 128)
+        self.blocks = [Block(d_model, hyper_n=1) for _ in range(n_layers)]
+        self.proj = layers.Dense(128)
         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")
 
     def call(self, x, training=False):
         batch_size, seq_len = tf.shape(x)[0], tf.shape(x)[1]
         positions = tf.range(seq_len)[tf.newaxis, :]
-
         x = self.token_embedding(x) + self.pos_embedding(positions)
         for block in self.blocks:
             x = block(x)
-
+        x = self.proj(x)
         x = self.ln_f(x)
-
         embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
         logits = tf.matmul(x, embedding_matrix, transpose_b=True)
         return tf.cast(logits, tf.float32)
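For reference, a minimal standalone sketch of what the gating change does numerically: with the old sigmoid key gate the element-wise score g_q * g_k stays in (0, 1), while the new tanh key gate makes it signed, in (-1, 1), so the retained "bounded roughly [0,1]" comment now describes the old behaviour. The shapes and the 96-unit projection are taken from the diff; the random tensors and everything else below are illustrative assumptions, not code from this repository.

import tensorflow as tf

# Stand-ins for self.Q(x_f32) and self.K(x_f32); (B, L, 96) as in the diff.
q = tf.random.normal((2, 10, 96))
k = tf.random.normal((2, 10, 96))

g_q = tf.nn.sigmoid(q)

# Old gating: both gates in (0, 1), so the product is also in (0, 1).
score_old = g_q * tf.nn.sigmoid(k)

# New gating: the tanh gate lies in (-1, 1), so the product can be negative.
score_new = g_q * tf.nn.tanh(k)

print(float(tf.reduce_min(score_old)), float(tf.reduce_max(score_old)))  # ~0.0 .. ~1.0
print(float(tf.reduce_min(score_new)), float(tf.reduce_max(score_new)))  # ~-1.0 .. ~1.0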
 
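The new Dense(128) projection exists because the token embedding is now fixed at 128 dimensions while the blocks still run at d_model, and the output head reuses the embedding matrix (weight tying) via matmul(..., transpose_b=True); projecting back to 128 makes the inner dimensions match. A minimal sketch of that output path, assuming illustrative values (vocab_size=1000, d_model=256, a batch of 2 sequences of length 16) that are not from the repository:

import tensorflow as tf
from tensorflow.keras import layers

vocab_size, d_model = 1000, 256          # illustrative; only the 128-dim embedding comes from the diff

token_embedding = layers.Embedding(vocab_size, 128)
proj = layers.Dense(128)                 # maps block outputs (d_model) back to the embedding width
ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")

# Stand-in for the output of the Block stack.
x = tf.random.normal((2, 16, d_model))

# Build the embedding layer so its weight matrix (.embeddings) exists.
_ = token_embedding(tf.zeros((1, 1), dtype=tf.int32))

h = ln_f(proj(x))                                                 # (2, 16, 128)
embedding_matrix = tf.cast(token_embedding.embeddings, h.dtype)   # (1000, 128)
logits = tf.matmul(h, embedding_matrix, transpose_b=True)         # (2, 16, 1000)
print(logits.shape)

The ordering mirrors the diff: proj, then ln_f, then the tied matmul against the embedding table.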