OpenLab-NLP committed on
Commit
303d1df
·
verified ·
1 Parent(s): b37bc5a

Update V2.py

Browse files
Files changed (1) hide show
  1. V2.py +32 -64
V2.py CHANGED
@@ -129,54 +129,34 @@ ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)),
129
  ds = ds.prefetch(tf.data.AUTOTUNE)
130
 
131
 
132
- class HyperConv1D(layers.Layer):
133
- def __init__(self, d_model, k=7, hyper_dim=128, dropout=0.0):
134
  super().__init__()
135
  assert k % 2 == 1
136
  self.k = k
137
- self.d_model = d_model
138
-
139
- # Input projection
140
- self.input_proj = layers.Dense(d_model, name="input_proj")
141
-
142
- # Dynamic kernel conv
143
- self.dynamic_dense = layers.Dense(d_model, activation='silu')
144
- self.dynamic_proj = layers.Dense(d_model)
145
- self.kernel_generator = layers.Dense(k, dtype='float32')
146
-
147
- # Hypernetwork for token-wise features
148
- self.hyper = tf.keras.Sequential([
149
- layers.Dense(hyper_dim, activation='gelu'),
150
- layers.Dense(d_model)
151
- ], name="hyper")
152
-
153
- # Attention pooling for global context
154
- self.attn_pool = layers.Dense(1)
155
-
156
- # LayerNorm + Dropout
157
- self.norm = layers.LayerNormalization()
158
- self.dropout = layers.Dropout(dropout)
159
- self.dense = layers.Dense(d_model)
160
-
161
- def call(self, x, training=None):
162
  x_in = x
163
- x_dtype = x.dtype # 입력 dtype 저장
 
164
 
165
- # 1) Input projection
166
- x_proj = self.input_proj(x) # (B, L, D)
167
- B = tf.shape(x_proj)[0]
168
- L = tf.shape(x_proj)[1]
169
- D = self.d_model
170
- pad = (self.k - 1) // 2
171
 
172
- # ------------------------------
173
- # 2) DynamicConv local mixing
174
- # ------------------------------
175
- kernels = self.kernel_generator(self.dynamic_dense(x_proj))
176
- kernels = tf.cast(kernels, x_proj.dtype)
177
  kernels = tf.nn.softmax(kernels, axis=-1)
178
 
179
- x_pad = tf.pad(x_proj, [[0,0],[pad,pad],[0,0]])
 
 
180
  x_pad_4d = tf.expand_dims(x_pad, axis=1)
181
  patches = tf.image.extract_patches(
182
  images=x_pad_4d,
@@ -186,30 +166,15 @@ class HyperConv1D(layers.Layer):
186
  padding='VALID'
187
  )
188
  patches = tf.reshape(patches, [B, L, self.k, D])
 
189
  kernels_exp = tf.expand_dims(kernels, axis=-1)
190
- out_local = tf.reduce_sum(patches * kernels_exp, axis=2)
191
- out_local = self.dynamic_proj(out_local)
192
-
193
- # ------------------------------
194
- # 3) Global context via attention pooling (scale 제거)
195
- # ------------------------------
196
- h = self.hyper(x_proj)
197
- scores = tf.nn.softmax(self.attn_pool(h), axis=1) # (B, L, 1)
198
- global_context = tf.reduce_sum(h * scores, axis=1) # (B, D)
199
- # token-wise concat
200
- global_context_exp = tf.expand_dims(global_context, 1) * tf.ones([B, L, 1], dtype=x_proj.dtype)
201
- out_local = tf.concat([out_local, global_context_exp], axis=-1)
202
- out_local = self.dense(out_local) # dimension 맞춤
203
-
204
- # ------------------------------
205
- # 4) Residual + SiLU + LayerNorm
206
- # ------------------------------
207
- out = x_proj + out_local
208
- out = tf.nn.silu(out)
209
- out = self.norm(out)
210
- out = self.dropout(out, training=training)
211
-
212
- return tf.cast(out, x_dtype)
213
 
214
 
215
  class L2NormLayer(layers.Layer):
@@ -227,11 +192,14 @@ class SentenceEncoder(Model):
227
  self.embed = layers.Embedding(vocab_size, embed_dim)
228
  self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
229
  self.dropout = layers.Dropout(dropout_rate)
230
- self.blocks = [HyperConv1D(d_model=embed_dim, k=7, hyper_dim=256) for _ in range(4)]
231
  self.attn_pool = layers.Dense(1)
 
232
  self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
 
233
  self.latent = layers.Dense(latent_dim, activation=None)
234
  self.l2norm = L2NormLayer(axis=1)
 
235
  self.fc1 = layers.Dense(1152)
236
  self.fc2 = layers.Dense(embed_dim)
237
 
 
129
  ds = ds.prefetch(tf.data.AUTOTUNE)
130
 
131
 
132
+ class DynamicConv(layers.Layer):
133
+ def __init__(self, d_model, k=7):
134
  super().__init__()
135
  assert k % 2 == 1
136
  self.k = k
137
+ self.dense = layers.Dense(d_model, activation='silu')
138
+ self.proj = layers.Dense(d_model)
139
+ self.generator = layers.Dense(k, dtype='float32')
140
+
141
+ self.ln1 = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
142
+ self.ln2 = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
143
+
144
+
145
+ def call(self, x):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  x_in = x
147
+ x = tf.cast(x, tf.float32)
148
+ x = self.ln1(x)
149
 
150
+ B = tf.shape(x)[0]
151
+ L = tf.shape(x)[1]
152
+ D = tf.shape(x)[2]
 
 
 
153
 
154
+ kernels = self.generator(self.dense(x))
 
 
 
 
155
  kernels = tf.nn.softmax(kernels, axis=-1)
156
 
157
+ pad = (self.k - 1) // 2
158
+ x_pad = tf.pad(x, [[0,0],[pad,pad],[0,0]])
159
+
160
  x_pad_4d = tf.expand_dims(x_pad, axis=1)
161
  patches = tf.image.extract_patches(
162
  images=x_pad_4d,
 
166
  padding='VALID'
167
  )
168
  patches = tf.reshape(patches, [B, L, self.k, D])
169
+
170
  kernels_exp = tf.expand_dims(kernels, axis=-1)
171
+ out = tf.reduce_sum(patches * kernels_exp, axis=2)
172
+ out = self.proj(out)
173
+ out = tf.nn.gelu(out)
174
+ out = x + self.ln2(out)
175
+
176
+ # 🔥 원래 dtype으로 돌려줌
177
+ return tf.cast(out, x_in.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
 
180
  class L2NormLayer(layers.Layer):
 
192
  self.embed = layers.Embedding(vocab_size, embed_dim)
193
  self.pos_embed = layers.Embedding(input_dim=max_len, output_dim=embed_dim)
194
  self.dropout = layers.Dropout(dropout_rate)
195
+ self.blocks = [DynamicConv(d_model=embed_dim, k=7) for _ in range(4)]
196
  self.attn_pool = layers.Dense(1)
197
+
198
  self.ln_f = layers.LayerNormalization(epsilon=1e-5, dtype=tf.float32)
199
+
200
  self.latent = layers.Dense(latent_dim, activation=None)
201
  self.l2norm = L2NormLayer(axis=1)
202
+
203
  self.fc1 = layers.Dense(1152)
204
  self.fc2 = layers.Dense(embed_dim)
205