OpenLab-NLP committed (verified)
Commit dba50a9 · 1 Parent(s): 1311b89

Update openlem-tpu.py

Files changed (1):
  1. openlem-tpu.py +45 -31
openlem-tpu.py CHANGED
@@ -135,40 +135,48 @@ class DynamicConv(layers.Layer):
         self.k = k
         self.dense = layers.Dense(d_model, activation='gelu')
         self.proj = layers.Dense(d_model)
+        # generator should produce k weights per token; softmax in float32 for stability
         self.generator = layers.Dense(k, dtype='float32')
 
     def call(self, x):
         x_in = x
-        x = tf.cast(x, tf.float32)
-
-        B = tf.shape(x)[0]
-        L = tf.shape(x)[1]
-        D = tf.shape(x)[2]
-
-        kernels = self.generator(self.dense(x))
-        kernels = tf.nn.softmax(kernels, axis=-1)
+        x = tf.cast(x, tf.float32)  # float32 so the softmax is safe
 
+        # padding + extract patches
         pad = (self.k - 1) // 2
-        x_pad = tf.pad(x, [[0,0],[pad,pad],[0,0]])
-
-        x_pad_4d = tf.expand_dims(x_pad, axis=1)
+        x_pad = tf.pad(x, [[0, 0], [pad, pad], [0, 0]])  # [B, L+2*pad, D]
+        x_pad_4d = tf.expand_dims(x_pad, axis=1)         # [B, 1, L+2*pad, D]
+
         patches = tf.image.extract_patches(
             images=x_pad_4d,
-            sizes=[1,1,self.k,1],
-            strides=[1,1,1,1],
-            rates=[1,1,1,1],
+            sizes=[1, 1, self.k, 1],
+            strides=[1, 1, 1, 1],
+            rates=[1, 1, 1, 1],
             padding='VALID'
-        )
+        )  # shape: [B, 1, L, k*D]
+
+        # reshape -> [B, L, k, D]
+        B = tf.shape(patches)[0]
+        L = tf.shape(patches)[2]
+        D = tf.shape(x)[2]
         patches = tf.reshape(patches, [B, L, self.k, D])
 
-        kernels_exp = tf.expand_dims(kernels, axis=-1)
-        out = tf.reduce_sum(patches * kernels_exp, axis=2)
+        # generate kernels per token
+        kernels = self.generator(self.dense(x))    # [B, L, k], in float32
+        kernels = tf.nn.softmax(kernels, axis=-1)  # [B, L, k]
+
+        kernels_exp = tf.expand_dims(kernels, axis=-1)      # [B, L, k, 1]
+        out = tf.reduce_sum(patches * kernels_exp, axis=2)  # [B, L, D]
         out = self.proj(out)
 
-        # return to the original dtype
         return tf.cast(out, x_in.dtype)
 
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
 class MixerBlock(layers.Layer):
-    def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
+    def __init__(self, seq_len, dim, token_mlp_dim=None, channel_mlp_dim=None, dropout=0.0):
         super().__init__()
         self.dim = dim
 
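The core change in this hunk is ordering: B, L, and D are now read from `patches` after `tf.image.extract_patches`, instead of from the unpadded `x`, and kernel generation moves after patch extraction. Below is a minimal standalone sketch of the mechanism with dummy sizes (not the repo's real config): k-wide windows are gathered per token, then each window is averaged with a per-token softmax kernel that broadcasts across all D channels.

import tensorflow as tf

B, L, D, k = 2, 10, 8, 5
x = tf.random.normal([B, L, D])

pad = (k - 1) // 2
x_pad = tf.pad(x, [[0, 0], [pad, pad], [0, 0]])   # [B, L+2*pad, D]
x_pad_4d = tf.expand_dims(x_pad, axis=1)          # [B, 1, L+2*pad, D]

patches = tf.image.extract_patches(
    images=x_pad_4d,
    sizes=[1, 1, k, 1],
    strides=[1, 1, 1, 1],
    rates=[1, 1, 1, 1],
    padding='VALID')                              # [B, 1, L, k*D]
patches = tf.reshape(patches, [B, L, k, D])       # one k-wide window per token

# stand-in for self.generator(self.dense(x)): any [B, L, k] logits work here
kernels = tf.nn.softmax(tf.random.normal([B, L, k]), axis=-1)

out = tf.reduce_sum(patches * kernels[..., None], axis=2)
print(out.shape)  # (2, 10, 8) -- same [B, L, D] as the input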
@@ -176,37 +184,43 @@ class MixerBlock(layers.Layer):
         self.ln_local = layers.LayerNormalization(epsilon=1e-6)
         self.ln_channel = layers.LayerNormalization(epsilon=1e-6)
 
-        # __init__
-        self.token_fc1 = layers.Dense(seq_len)
+        # NOTE: token_fc1 must output 2 * seq_len to allow split()
+        self.token_fc1 = layers.Dense(seq_len * 2)
         self.token_fc2 = layers.Dense(seq_len)
 
-        # Channel Mixer
+        # Channel Mixer (GLU style)
         self.ch_fc1 = layers.Dense(self.dim * 4)
         self.ch_fc2 = layers.Dense(self.dim)
 
-        self.conv1 = DynamicConv(d_model=dim, k=7)
+        # local dynamic conv
+        self.conv1 = DynamicConv(d_model=dim, k=5)
+
     def call(self, x, training=None):
-        # 1. Local mixing first
+        # 1) Local mixing first
         y = self.ln_local(x)
         y = self.conv1(y)
         x = x + y
 
-        # 2. Weak global token mixing
+        # 2) (Weak) Global token mixing
         y = self.ln_token(x)
-        y_t = tf.transpose(y, [0,2,1])
-        # call
-        y_t = tf.nn.gelu(self.token_fc1(y_t))
-        y_t = self.token_fc2(y_t)
-
-        y = tf.transpose(y_t, [0,2,1])
+        y_t = tf.transpose(y, perm=[0, 2, 1])    # [B, D, L]
+        y_t = self.token_fc1(y_t)                # [B, D, 2*L]
+        a, b = tf.split(y_t, 2, axis=-1)         # split on last dim
+        y_t = self.token_fc2(a * tf.nn.gelu(b))  # [B, D, L]
+        y = tf.transpose(y_t, perm=[0, 2, 1])    # [B, L, D]
         x = x + y
 
-        # 3. Channel mixing
+        # 3) Channel mixer (GLU)
         y = self.ln_channel(x)
         a, b = tf.split(self.ch_fc1(y), 2, axis=-1)
         y = self.ch_fc2(a * tf.nn.gelu(b))
         x = x + y
 
+        return x
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
 
 class L2NormLayer(layers.Layer):
     def __init__(self, axis=1, epsilon=1e-10, **kwargs):
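This hunk swaps the token mixer from a plain GELU MLP to a gated (GLU-style) form, which is why token_fc1 must now emit 2 * seq_len before the split. A toy sketch of just that step, with made-up sizes:

import tensorflow as tf
from tensorflow.keras import layers

B, L, D = 2, 16, 8
fc1 = layers.Dense(L * 2)  # must output 2*L so tf.split yields two L-wide halves
fc2 = layers.Dense(L)

y = tf.random.normal([B, L, D])
y_t = tf.transpose(y, perm=[0, 2, 1])  # [B, D, L]: mix along the token axis
a, b = tf.split(fc1(y_t), 2, axis=-1)  # two [B, D, L] halves
y_t = fc2(a * tf.nn.gelu(b))           # gate one half with GELU of the other
y = tf.transpose(y_t, perm=[0, 2, 1])  # back to [B, L, D]
print(y.shape)                         # (2, 16, 8)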
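Both versions of DynamicConv.call up-cast to float32, run the softmax-weighted reduction there, and cast the result back to the caller's dtype. A small sketch of that round trip, assuming a bfloat16 mixed-precision input (the usual setup the file's TPU naming suggests, not something this diff states):

import tensorflow as tf

x_in = tf.cast(tf.random.normal([2, 8, 4]), tf.bfloat16)  # assumed bfloat16 activations

x = tf.cast(x_in, tf.float32)      # up-cast before the reduction
probs = tf.nn.softmax(x, axis=-1)  # numerically safe in float32
out = tf.cast(probs, x_in.dtype)   # back to the original dtype, as in DynamicConv.call

print(out.dtype)  # <dtype: 'bfloat16'>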