OpenLab-NLP committed
Commit 93ae417 · verified · 1 Parent(s): ab4d9a8

Update V2.py

Files changed (1)
  1. V2.py +35 -40
V2.py CHANGED
@@ -129,53 +129,48 @@ ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)),
 ds = ds.prefetch(tf.data.AUTOTUNE)


-class DynamicConv(layers.Layer):
-    def __init__(self, d_model, k=7):
+
+class MixerBlock(layers.Layer):
+    def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0):
         super().__init__()
-        assert k % 2 == 1
-        self.k = k
-        self.dense = layers.Dense(d_model, activation='silu')
-        self.proj = layers.Dense(d_model)
-        self.generator = layers.Dense(k, dtype='float32')
-
-        self.ln1 = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
-        self.ln2 = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
-
-
-    def call(self, x):
-        x_in = x
-        x = tf.cast(x, tf.float32)
-        x = self.ln1(x)
-
-        B = tf.shape(x)[0]
-        L = tf.shape(x)[1]
-        D = tf.shape(x)[2]
-
-        kernels = self.generator(self.dense(x))
-        kernels = tf.nn.softmax(kernels, axis=-1)
-
-        pad = (self.k - 1) // 2
-        x_pad = tf.pad(x, [[0,0],[pad,pad],[0,0]])
-
-        x_pad_4d = tf.expand_dims(x_pad, axis=1)
-        patches = tf.image.extract_patches(
-            images=x_pad_4d,
-            sizes=[1,1,self.k,1],
-            strides=[1,1,1,1],
-            rates=[1,1,1,1],
-            padding='VALID'
-        )
-        patches = tf.reshape(patches, [B, L, self.k, D])
-
-        kernels_exp = tf.expand_dims(kernels, axis=-1)
-        out = tf.reduce_sum(patches * kernels_exp, axis=2)
-        out = self.proj(out)
-        out = tf.nn.gelu(out)
-        out = x + self.ln2(out)
-
-        # 🔥 cast back to the original dtype
-        return tf.cast(out, x_in.dtype)
+        self.seq_len = seq_len
+        self.dim = dim
+        self.token_mlp_dim = token_mlp_dim
+        self.channel_mlp_dim = channel_mlp_dim
+
+        self.ln1 = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
+        # token-mixing MLP: operate over tokens => apply Dense on transposed axis
+        self.token_fc1 = layers.Dense(token_mlp_dim, activation='gelu', dtype=tf.float32)
+        self.token_fc2 = layers.Dense(seq_len, dtype=tf.float32)
+
+        self.ln2 = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
+        # channel-mixing MLP: operate per-token over channels
+        self.channel_fc1 = layers.Dense(channel_mlp_dim, activation='gelu', dtype=tf.float32)
+        self.channel_fc2 = layers.Dense(dim, dtype=tf.float32)
+
+        self.dropout = layers.Dropout(dropout)
+
+    def call(self, x, training=None):
+        # x: (B, L, D)
+        B = tf.shape(x)[0]
+        L = tf.shape(x)[1]
+        D = tf.shape(x)[2]
+
+        # Token-mixing
+        y = self.ln1(x)                      # (B, L, D)
+        y_t = tf.transpose(y, perm=[0,2,1])  # (B, D, L)
+        y_t = self.token_fc1(y_t)            # (B, D, token_mlp_dim)
+        y_t = self.token_fc2(y_t)            # (B, D, L)
+        y = tf.transpose(y_t, perm=[0,2,1])  # (B, L, D)
+        x = x + self.dropout(y, training=training)
+
+        # Channel-mixing
+        z = self.ln2(x)
+        z = self.channel_fc1(z)
+        z = self.channel_fc2(z)
+        x = x + self.dropout(z, training=training)
+
+        return x

 class L2NormLayer(layers.Layer):
     def __init__(self, axis=1, epsilon=1e-10, **kwargs):
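For context on what this hunk drops: DynamicConv predicted a softmax kernel per position and applied it over a k-token window gathered with tf.image.extract_patches. A minimal sketch, assuming TensorFlow 2.x and toy sizes (B, L, D, k below are illustrative, not V2.py's real dimensions), confirming that the extract_patches call yields exactly the k-token sliding windows the weighted sum consumes:

import tensorflow as tf

# Toy sizes (placeholders, not the values used in V2.py)
B, L, D, k = 2, 5, 4, 3
pad = (k - 1) // 2

x = tf.random.normal([B, L, D])
x_pad = tf.pad(x, [[0, 0], [pad, pad], [0, 0]])     # same-pad along the token axis

# Treat the sequence as a 1-pixel-tall image and slice k-wide patches
patches = tf.image.extract_patches(
    images=tf.expand_dims(x_pad, axis=1),           # (B, 1, L + 2*pad, D)
    sizes=[1, 1, k, 1],
    strides=[1, 1, 1, 1],
    rates=[1, 1, 1, 1],
    padding='VALID',
)
patches = tf.reshape(patches, [B, L, k, D])         # one k-token window per position

# Reference: the window at position t is simply x_pad[:, t:t+k, :]
ref = tf.stack([x_pad[:, t:t + k, :] for t in range(L)], axis=1)
print(bool(tf.reduce_all(tf.abs(patches - ref) < 1e-6)))  # True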
@@ -192,7 +187,7 @@ class SentenceEncoder(Model):
         self.embed = layers.Embedding(vocab_size, embed_dim)
         self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
         self.dropout = layers.Dropout(dropout_rate)
-        self.blocks = [DynamicConv(d_model=embed_dim, k=7) for _ in range(4)]
+        self.blocks = [MixerBlock(seq_len=MAX_LEN, dim=embed_dim, token_mlp_dim=256, channel_mlp_dim=embed_dim, dropout=0.1) for _ in range(3)]
         self.attn_pool = layers.Dense(1)

         self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
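One caveat the swap introduces: token_fc2 projects back to a fixed seq_len, so the token-mixing residual only lines up when every input is padded to exactly MAX_LEN, whereas the removed DynamicConv slid its window over any sequence length. A minimal sketch of the token-mixing step in isolation, assuming TensorFlow 2.x (MAX_LEN, EMBED_DIM, TOKEN_MLP_DIM below are placeholder sizes, not the values in V2.py):

import tensorflow as tf
from tensorflow.keras import layers

MAX_LEN, EMBED_DIM, TOKEN_MLP_DIM = 64, 128, 256    # placeholders

token_fc1 = layers.Dense(TOKEN_MLP_DIM, activation='gelu')
token_fc2 = layers.Dense(MAX_LEN)                   # output width must equal seq_len

x = tf.random.normal([2, MAX_LEN, EMBED_DIM])       # (B, L, D)
y = tf.transpose(x, perm=[0, 2, 1])                 # (B, D, L): tokens become the last axis
y = token_fc2(token_fc1(y))                         # Dense now mixes across token positions
y = tf.transpose(y, perm=[0, 2, 1])                 # back to (B, L, D)
print((x + y).shape)                                # (2, 64, 128): residual add needs L == MAX_LEN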
 
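For a rough sense of the size of the swap (4 DynamicConv blocks out, 3 MixerBlocks in), the parameter counts follow directly from the Dense and LayerNorm shapes in the diff. A back-of-the-envelope sketch; the concrete d and seq_len values are placeholders, since embed_dim and MAX_LEN are defined outside this diff:

def dynamic_conv_params(d, k=7):
    dense = d * d + d               # Dense(d_model, activation='silu')
    proj = d * d + d                # Dense(d_model)
    generator = d * k + k           # Dense(k) kernel generator
    layernorms = 2 * (2 * d)        # ln1 + ln2, gamma and beta each
    return dense + proj + generator + layernorms

def mixer_block_params(seq_len, d, token_mlp_dim, channel_mlp_dim):
    token = (seq_len * token_mlp_dim + token_mlp_dim) + (token_mlp_dim * seq_len + seq_len)
    channel = (d * channel_mlp_dim + channel_mlp_dim) + (channel_mlp_dim * d + d)
    layernorms = 2 * (2 * d)        # ln1 + ln2
    return token + channel + layernorms

d, seq_len = 256, 128               # placeholders, not V2.py's actual sizes
print("old:", 4 * dynamic_conv_params(d))
print("new:", 3 * mixer_block_params(seq_len, d, token_mlp_dim=256, channel_mlp_dim=d))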