OpenLab-NLP committed on
Commit 27a17ff · verified · 1 Parent(s): 83db00c

Update V2.py

Files changed (1)
  1. V2.py +25 -32
V2.py CHANGED
@@ -128,20 +128,18 @@ ds = ds.batch(BATCH_SIZE, drop_remainder=True)
 ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
 ds = ds.prefetch(tf.data.AUTOTUNE)
 
+
+
 class MixerBlock(layers.Layer):
-    """
-    TPU / mixed-precision friendly MLP-Mixer block (token-mixing + channel-mixing).
-    Internal ops run in float32 for stability; the output is cast back to the input dtype.
-    """
     def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
         super().__init__()
         self.seq_len = seq_len
         self.dim = dim
+        self.token_mlp_dim = token_mlp_dim
+        self.channel_mlp_dim = channel_mlp_dim
 
-        # keep LayerNorm in float32 for stability
         self.ln1 = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
         # token-mixing MLP: operate over tokens => apply Dense on transposed axis
-        # force the Dense layers to float32 as well
         self.token_fc1 = layers.Dense(token_mlp_dim, activation='gelu', dtype=tf.float32)
         self.token_fc2 = layers.Dense(seq_len, dtype=tf.float32)
 
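A note on the pipeline context at the top of this hunk: each (v1, v2) batch is paired with an all-zeros label tensor, the usual pattern when the Keras loss scores the model's output directly and ignores y_true, presumably because the training signal here comes from the pair itself. A minimal, self-contained sketch of the same pattern; BATCH_SIZE and the tensors are stand-ins, not V2.py's actual values:

```python
import tensorflow as tf

BATCH_SIZE = 8  # stand-in value; V2.py defines its own

# Toy (v1, v2) pair dataset standing in for the real one.
pairs = (tf.random.normal([32, 16]), tf.random.normal([32, 16]))
ds = tf.data.Dataset.from_tensor_slices(pairs)
ds = ds.batch(BATCH_SIZE, drop_remainder=True)
# Same trick as in the hunk: attach dummy zero labels so model.fit()
# has a y to hand to the loss, even if the loss never reads it.
ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)),
            num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.prefetch(tf.data.AUTOTUNE)
```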
 
@@ -153,31 +151,26 @@ class MixerBlock(layers.Layer):
         self.dropout = layers.Dropout(dropout)
 
     def call(self, x, training=None):
-        """
-        x: (B, L, D); dtype can be bfloat16/float32 depending on policy.
-        Internal computation runs in float32; the result is cast back to the original x.dtype.
-        """
-        orig_dtype = x.dtype
-        # cast to float32 for numerically stable ops
-        x_f = tf.cast(x, tf.float32)                   # (B, L, D)
-
-        # ---- Token-mixing (Dense on token axis) ----
-        y = self.ln1(x_f)                              # (B, L, D) in float32
-        y_t = tf.transpose(y, perm=[0, 2, 1])          # (B, D, L)
-        y_t = self.token_fc1(y_t)                      # (B, D, token_mlp_dim)
-        y_t = self.token_fc2(y_t)                      # (B, D, L)
-        y = tf.transpose(y_t, perm=[0, 2, 1])          # (B, L, D)
-        x_f = x_f + self.dropout(y, training=training)
-
-        # ---- Channel-mixing (per-token MLP) ----
-        z = self.ln2(x_f)                              # (B, L, D)
-        z = self.channel_fc1(z)                        # (B, L, channel_mlp_dim)
-        z = self.channel_fc2(z)                        # (B, L, D)
-        x_f = x_f + self.dropout(z, training=training)
-
-        # finally: restore the original dtype (keeps the mixed-precision win)
-        return tf.cast(x_f, orig_dtype)
-
+        # x: (B, L, D)
+        B = tf.shape(x)[0]
+        L = tf.shape(x)[1]
+        D = tf.shape(x)[2]
+
+        # Token-mixing
+        y = self.ln1(x)                                # (B, L, D)
+        y_t = tf.transpose(y, perm=[0, 2, 1])          # (B, D, L)
+        y_t = self.token_fc1(y_t)                      # (B, D, token_mlp_dim)
+        y_t = self.token_fc2(y_t)                      # (B, D, L)
+        y = tf.transpose(y_t, perm=[0, 2, 1])          # (B, L, D)
+        x = x + self.dropout(y, training=training)
+
+        # Channel-mixing
+        z = self.ln2(x)
+        z = self.channel_fc1(z)
+        z = self.channel_fc2(z)
+        x = x + self.dropout(z, training=training)
+
+        return x
 
 class L2NormLayer(layers.Layer):
     def __init__(self, axis=1, epsilon=1e-10, **kwargs):
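The substantive change in this hunk: call() no longer round-trips through float32. The old version cast x up, ran both mixing steps in float32, and cast the result back to x.dtype; the new version feeds x straight through the sublayers, which are still individually pinned to float32. Under the default float32 policy the two are numerically identical; under a mixed-precision policy the output dtype can change. A minimal sketch of the difference, using a single float32-pinned LayerNormalization as a stand-in for the block's sublayers:

```python
import tensorflow as tf
from tensorflow.keras import layers

# Stand-in for the block's float32-pinned sublayers (ln1/ln2, Dense).
ln = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)

x = tf.cast(tf.random.normal([2, 4, 8]), tf.bfloat16)  # (B, L, D)

# Old call(): compute in float32, then restore the input dtype.
y_old = tf.cast(ln(tf.cast(x, tf.float32)), x.dtype)

# New call(): no explicit casts; the float32 layer upcasts its input
# internally and emits float32, so the dtype now follows the sublayer.
y_new = ln(x)

print(y_old.dtype, y_new.dtype)  # bfloat16 vs float32
```

If the script runs under the default float32 policy, behavior is unchanged, which is presumably why the cast scaffolding could be dropped. One small review note: B, L, and D in the new call() are assigned but never used.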
@@ -295,4 +288,4 @@ history = model.fit(ds, epochs=EPOCHS, steps_per_epoch=steps_per_epoch, verbose=
 
 # save the encoder weights
 encoder.save_weights("encoder_fit.weights.h5")
-print("Training finished and weights saved.")
+print("Training finished and weights saved.")