Yuchan committed
Update Inference.py

Inference.py  CHANGED  (+9 -13)
@@ -27,7 +27,7 @@ def text_to_ids(text):
 def ids_to_text(ids):
     return sp.decode(ids)

-max_len =
+max_len = 230
 batch_size = 128

 class Lo(layers.Layer):
@@ -118,7 +118,7 @@ class LoSoU(layers.Layer):
         # x: (B, L, d_model) maybe bfloat16 or float32
         # cast to float32 for all internal computations
         x_f32 = tf.cast(x, tf.float32)
-
+

         # Q, K, V
         q = self.Q(x_f32)  # (B, L, 96)
@@ -127,7 +127,7 @@ class LoSoU(layers.Layer):

         # gating signals in (0,1)
         g_q = tf.nn.sigmoid(q)
-        g_k = tf.nn.
+        g_k = tf.nn.tanh(k)

         # elementwise product -> bounded roughly [0,1]
         score = g_q * g_k
@@ -162,12 +162,11 @@ class LoSoU(layers.Layer):
         gated = tf.nn.silu(a) * b
         out = self.O(gated)

-        out = self.norm(out
+        out = self.norm(out)

         # cast back to original dtype for downstream layers
         return tf.cast(out, x.dtype)

-
 class Block(layers.Layer):
     def __init__(self, d_model, hyper_n):
         super().__init__()
@@ -181,23 +180,20 @@ class Block(layers.Layer):
 class ReLaM(tf.keras.Model):
     def __init__(self, vocab_size, max_seq_len, d_model, n_layers, dropout_rate=0.1):
         super().__init__()
-        self.token_embedding = layers.Embedding(vocab_size,
-        self.pos_embedding = layers.Embedding(max_seq_len,
-        self.blocks = [Block(d_model, hyper_n=
-
-        # keep LayerNormalization in float32 to avoid precision issues
+        self.token_embedding = layers.Embedding(vocab_size, 128)
+        self.pos_embedding = layers.Embedding(max_seq_len, 128)
+        self.blocks = [Block(d_model, hyper_n=1) for _ in range(n_layers)]
+        self.proj = layers.Dense(128)
         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype="float32")

     def call(self, x, training=False):
         batch_size, seq_len = tf.shape(x)[0], tf.shape(x)[1]
         positions = tf.range(seq_len)[tf.newaxis, :]
-
         x = self.token_embedding(x) + self.pos_embedding(positions)
         for block in self.blocks:
             x = block(x)
-
+        x = self.proj(x)
         x = self.ln_f(x)
-
         embedding_matrix = tf.cast(self.token_embedding.embeddings, x.dtype)
         logits = tf.matmul(x, embedding_matrix, transpose_b=True)
         return tf.cast(logits, tf.float32)
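A note on the new gate: tf.nn.sigmoid(q) stays in (0, 1), while the newly introduced tf.nn.tanh(k) ranges over (-1, 1), so the elementwise score g_q * g_k can now be negative and is bounded roughly by (-1, 1) rather than [0, 1] as the surrounding comments still say. A standalone range check, purely illustrative and not part of the commit:

import tensorflow as tf

q = tf.random.normal([8])
k = tf.random.normal([8])
score = tf.nn.sigmoid(q) * tf.nn.tanh(k)   # the gate product used after this change
# sigmoid in (0, 1) times tanh in (-1, 1) keeps every entry strictly inside (-1, 1)
print(float(tf.reduce_min(score)), float(tf.reduce_max(score)))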
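The added self.proj = layers.Dense(128), applied as x = self.proj(x) just before the final LayerNormalization, appears to be what makes the tied output projection line up: the token embedding is now hard-coded to width 128, and logits = tf.matmul(x, embedding_matrix, transpose_b=True) only works when the final hidden width equals the embedding width, so a d_model-wide block output has to be mapped back to 128 first. A minimal shape sketch; vocab_size, d_model, and the batch and sequence sizes below are illustrative assumptions, not values from the file:

import tensorflow as tf
from tensorflow.keras import layers

vocab_size, d_model, emb_dim = 8000, 256, 128            # assumed sizes for illustration

token_embedding = layers.Embedding(vocab_size, emb_dim)
proj = layers.Dense(emb_dim)

_ = token_embedding(tf.zeros([1, 1], dtype=tf.int32))    # build the layer so .embeddings exists

x = tf.random.normal([2, 230, d_model])                  # stand-in for the block output
x = proj(x)                                              # (2, 230, 128): back to the embedding width
logits = tf.matmul(x, token_embedding.embeddings, transpose_b=True)
print(logits.shape)                                      # (2, 230, 8000)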
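For completeness, a hypothetical way to exercise the updated model after this commit. Only max_len = 230 and batch_size = 128 come from the file; vocab_size, d_model, and n_layers are placeholders, and the import assumes Inference.py exposes ReLaM at module level without side effects:

import tensorflow as tf
from Inference import ReLaM                 # assumption: the class is importable from this script

vocab_size = 8000                           # placeholder; the SentencePiece model (sp) defines the real size
d_model = 256                               # placeholder
n_layers = 4                                # placeholder
max_len = 230                               # from the file
batch_size = 128                            # from the file

model = ReLaM(vocab_size, max_len, d_model, n_layers)
dummy_ids = tf.zeros([batch_size, max_len], dtype=tf.int32)
logits = model(dummy_ids)                   # expected shape: (128, 230, vocab_size), float32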