OpenLab-NLP committed on
Commit 83db00c · verified · 1 Parent(s): 93ae417

Update V2.py

Files changed (1)
  1. V2.py +31 -24
V2.py CHANGED
@@ -128,18 +128,20 @@ ds = ds.batch(BATCH_SIZE, drop_remainder=True)
 ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
 ds = ds.prefetch(tf.data.AUTOTUNE)
 
-
-
 class MixerBlock(layers.Layer):
+    """
+    TPU / mixed-precision friendly MLP-Mixer block (token-mixing + channel-mixing).
+    Internal computation runs in float32 for stability; the output is cast back to the input dtype.
+    """
     def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
         super().__init__()
         self.seq_len = seq_len
         self.dim = dim
-        self.token_mlp_dim = token_mlp_dim
-        self.channel_mlp_dim = channel_mlp_dim
 
+        # Keep LayerNorm in float32 for numerical stability
         self.ln1 = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
         # token-mixing MLP: operate over tokens => apply Dense on transposed axis
+        # Force the Dense layers to float32 as well
         self.token_fc1 = layers.Dense(token_mlp_dim, activation='gelu', dtype=tf.float32)
         self.token_fc2 = layers.Dense(seq_len, dtype=tf.float32)
 
@@ -151,26 +153,31 @@ class MixerBlock(layers.Layer):
         self.dropout = layers.Dropout(dropout)
 
     def call(self, x, training=None):
-        # x: (B, L, D)
-        B = tf.shape(x)[0]
-        L = tf.shape(x)[1]
-        D = tf.shape(x)[2]
-
-        # Token-mixing
-        y = self.ln1(x) # (B, L, D)
-        y_t = tf.transpose(y, perm=[0,2,1]) # (B, D, L)
-        y_t = self.token_fc1(y_t) # (B, D, token_mlp_dim)
-        y_t = self.token_fc2(y_t) # (B, D, L)
-        y = tf.transpose(y_t, perm=[0,2,1]) # (B, L, D)
-        x = x + self.dropout(y, training=training)
-
-        # Channel-mixing
-        z = self.ln2(x)
-        z = self.channel_fc1(z)
-        z = self.channel_fc2(z)
-        x = x + self.dropout(z, training=training)
-
-        return x
+        """
+        x: (B, L, D); dtype may be bfloat16 or float32 depending on the policy.
+        Internal computation is done in float32; the result is cast back to the original x.dtype.
+        """
+        orig_dtype = x.dtype
+        # Cast to float32 for numerically stable computation
+        x_f = tf.cast(x, tf.float32)  # (B, L, D)
+
+        # ---- Token-mixing (Dense over the token axis) ----
+        y = self.ln1(x_f)  # (B, L, D) in float32
+        y_t = tf.transpose(y, perm=[0, 2, 1])  # (B, D, L)
+        y_t = self.token_fc1(y_t)  # (B, D, token_mlp_dim)
+        y_t = self.token_fc2(y_t)  # (B, D, L)
+        y = tf.transpose(y_t, perm=[0, 2, 1])  # (B, L, D)
+        x_f = x_f + self.dropout(y, training=training)
+
+        # ---- Channel-mixing (per-token MLP) ----
+        z = self.ln2(x_f)  # (B, L, D)
+        z = self.channel_fc1(z)  # (B, L, channel_mlp_dim)
+        z = self.channel_fc2(z)  # (B, L, D)
+        x_f = x_f + self.dropout(z, training=training)
+
+        # Cast back to the original dtype (preserves the mixed-precision benefit)
+        return tf.cast(x_f, orig_dtype)
+
 
 class L2NormLayer(layers.Layer):
     def __init__(self, axis=1, epsilon=1e-10, **kwargs):
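
For context, a minimal smoke test of the rewritten block (not part of this commit). It assumes the MixerBlock class from the diff above, together with its `import tensorflow as tf` and `layers` imports, is already defined in the session; the batch and model sizes below are made-up illustration values, not hyperparameters from V2.py.

import tensorflow as tf

# Hypothetical sizes for illustration only; the real settings in V2.py may differ.
BATCH, SEQ_LEN, DIM = 8, 64, 256

# MixerBlock is assumed to be the class defined in the diff above.
block = MixerBlock(seq_len=SEQ_LEN, dim=DIM, token_mlp_dim=128, channel_mlp_dim=512, dropout=0.1)

x = tf.random.normal([BATCH, SEQ_LEN, DIM])  # (B, L, D), float32
y = block(x, training=True)

# Both mixing steps are residual, so shape and dtype are preserved and the
# block can be stacked without any reshaping between layers.
assert y.shape == x.shape and y.dtype == x.dtype

Because the LayerNorm and Dense sublayers are pinned to float32 and the result is cast back to orig_dtype, a lower-precision policy elsewhere in the model should only change the dtype at the block boundary; the arithmetic inside the block stays in float32.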