Update openlem-tpu.py
openlem-tpu.py (+45 −31)
@@ -135,40 +135,48 @@ class DynamicConv(layers.Layer):
         self.k = k
         self.dense = layers.Dense(d_model, activation='gelu')
         self.proj = layers.Dense(d_model)
+        # generator should produce k weights per token; softmax in float32 for stability
         self.generator = layers.Dense(k, dtype='float32')
+
     def call(self, x):
         x_in = x
-        x = tf.cast(x, tf.float32)
-
-        B = tf.shape(x)[0]
-        L = tf.shape(x)[1]
-        D = tf.shape(x)[2]
-
-        kernels = self.generator(self.dense(x))
-        kernels = tf.nn.softmax(kernels, axis=-1)
+        x = tf.cast(x, tf.float32)  # float32 so the softmax is numerically safe

+        # padding + extract patches
         pad = (self.k - 1) // 2
-        x_pad = tf.pad(x, [[0,0],[pad,pad],[0,0]])
-
-        x_pad_4d = tf.expand_dims(x_pad, axis=1)
+        x_pad = tf.pad(x, [[0,0],[pad,pad],[0,0]])  # [B, L+2pad, D]
+        x_pad_4d = tf.expand_dims(x_pad, axis=1)    # [B, 1, L+2pad, D]
+
         patches = tf.image.extract_patches(
             images=x_pad_4d,
-            sizes=[1,1,self.k,1],
-            strides=[1,1,1,1],
-            rates=[1,1,1,1],
+            sizes=[1, 1, self.k, 1],
+            strides=[1, 1, 1, 1],
+            rates=[1, 1, 1, 1],
             padding='VALID'
-        )
+        )  # shape: [B, 1, L, k*D]
+
+        # reshape -> [B, L, k, D]
+        B = tf.shape(patches)[0]
+        L = tf.shape(patches)[2]
+        D = tf.shape(x)[2]
         patches = tf.reshape(patches, [B, L, self.k, D])

-
-
+        # generate kernels per token
+        kernels = self.generator(self.dense(x))    # [B, L, k], in float32
+        kernels = tf.nn.softmax(kernels, axis=-1)  # [B, L, k]
+
+        kernels_exp = tf.expand_dims(kernels, axis=-1)      # [B, L, k, 1]
+        out = tf.reduce_sum(patches * kernels_exp, axis=2)  # [B, L, D]
         out = self.proj(out)

-        # 🔥 cast back to the original dtype
         return tf.cast(out, x_in.dtype)

+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
 class MixerBlock(layers.Layer):
-    def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
+    def __init__(self, seq_len, dim, token_mlp_dim=None, channel_mlp_dim=None, dropout=0.0):
         super().__init__()
         self.dim = dim

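For reference, the patched `call` does three things: `tf.image.extract_patches` gathers the `k` neighbors of every position from the padded input, the generator emits a softmax-normalized kernel per token, and the weighted sum collapses the `k` axis to produce `out`. Below is a minimal standalone sketch of that shape flow, with arbitrary sizes and a random tensor standing in for `softmax(generator(dense(x)))`; the `tf.einsum` is just another spelling of the diff's `reduce_sum(patches * kernels_exp, axis=2)`:

import tensorflow as tf

B, L, D, k = 2, 7, 4, 3                             # arbitrary smoke-test sizes
x = tf.random.normal([B, L, D])

pad = (k - 1) // 2
x_pad = tf.pad(x, [[0, 0], [pad, pad], [0, 0]])     # [B, L+2pad, D]
x_pad_4d = tf.expand_dims(x_pad, axis=1)            # [B, 1, L+2pad, D]
patches = tf.image.extract_patches(
    images=x_pad_4d,
    sizes=[1, 1, k, 1],
    strides=[1, 1, 1, 1],
    rates=[1, 1, 1, 1],
    padding='VALID')                                # [B, 1, L, k*D]
patches = tf.reshape(patches, [B, L, k, D])         # k neighbors per position

# random stand-in for the learned kernels: softmax(generator(dense(x)))
kernels = tf.nn.softmax(tf.random.normal([B, L, k]), axis=-1)
out = tf.einsum('blk,blkd->bld', kernels, patches)  # weighted average over the window
assert out.shape == (B, L, D)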
@@ -176,37 +184,43 @@ class MixerBlock(layers.Layer):
         self.ln_local = layers.LayerNormalization(epsilon=1e-6)
         self.ln_channel = layers.LayerNormalization(epsilon=1e-6)

-        #
-        self.token_fc1 = layers.Dense(seq_len)
+        # NOTE: token_fc1 must output 2 * seq_len to allow split()
+        self.token_fc1 = layers.Dense(seq_len * 2)
         self.token_fc2 = layers.Dense(seq_len)

-        # Channel Mixer
+        # Channel Mixer (GLU style)
         self.ch_fc1 = layers.Dense(self.dim * 4)
         self.ch_fc2 = layers.Dense(self.dim)

-
+        # local dynamic conv
+        self.conv1 = DynamicConv(d_model=dim, k=5)
+
     def call(self, x, training=None):
-        # 1
+        # 1) Local mixing first
         y = self.ln_local(x)
         y = self.conv1(y)
         x = x + y

-        # 2
+        # 2) (Weak) Global token mixing
         y = self.ln_token(x)
-        y_t = tf.transpose(y, [0,2,1])
-        #
-
-        y_t = self.token_fc2(
-
-        y = tf.transpose(y_t, [0,2,1])
+        y_t = tf.transpose(y, perm=[0, 2, 1])    # [B, D, L]
+        y_t = self.token_fc1(y_t)                # [B, D, 2*L]
+        a, b = tf.split(y_t, 2, axis=-1)         # split on last dim
+        y_t = self.token_fc2(a * tf.nn.gelu(b))  # [B, D, L]
+        y = tf.transpose(y_t, perm=[0, 2, 1])    # [B, L, D]
         x = x + y

-        # 3
+        # 3) Channel mixer (GLU)
         y = self.ln_channel(x)
         a, b = tf.split(self.ch_fc1(y), 2, axis=-1)
         y = self.ch_fc2(a * tf.nn.gelu(b))
         x = x + y

+        return x
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+

 class L2NormLayer(layers.Layer):
     def __init__(self, axis=1, epsilon=1e-10, **kwargs):
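The old token-mixing branch could not have executed as written (the `self.token_fc2(` call is left unclosed and the gate operands are never built), so the rewritten path is worth sanity-checking in isolation. A minimal sketch, with fresh `Dense` layers standing in for the block's `token_fc1`/`token_fc2` members and arbitrary sizes; the transposes put the sequence axis last so the `Dense` layers mix across tokens rather than channels:

import tensorflow as tf
from tensorflow.keras import layers

B, L, D = 2, 16, 8                         # arbitrary smoke-test sizes
token_fc1 = layers.Dense(L * 2)            # 2*L outputs so tf.split yields two L-wide halves
token_fc2 = layers.Dense(L)

y = tf.random.normal([B, L, D])
y_t = tf.transpose(y, perm=[0, 2, 1])      # [B, D, L]
y_t = token_fc1(y_t)                       # [B, D, 2*L]
a, b = tf.split(y_t, 2, axis=-1)           # gate halves, each [B, D, L]
y_t = token_fc2(a * tf.nn.gelu(b))         # [B, D, L]
y = tf.transpose(y_t, perm=[0, 2, 1])      # back to [B, L, D]
assert y.shape == (B, L, D)

Note that `self.ln_token` is used in `call` but declared outside this hunk; presumably it is another `LayerNormalization` defined alongside `ln_local` and `ln_channel`.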