OpenLab-NLP committed (verified)
Commit: b37bc5a · 1 Parent(s): f80b20b

Update V2.py

Files changed (1)
  1. V2.py +30 -33
V2.py CHANGED
@@ -128,6 +128,7 @@ ds = ds.batch(BATCH_SIZE, drop_remainder=True)
 ds = ds.map(lambda v1, v2: ((v1, v2), tf.zeros([BATCH_SIZE], dtype=tf.float32)), num_parallel_calls=tf.data.AUTOTUNE)
 ds = ds.prefetch(tf.data.AUTOTUNE)
 
+
 class HyperConv1D(layers.Layer):
     def __init__(self, d_model, k=7, hyper_dim=128, dropout=0.0):
         super().__init__()
@@ -143,70 +144,66 @@ class HyperConv1D(layers.Layer):
         self.dynamic_proj = layers.Dense(d_model)
         self.kernel_generator = layers.Dense(k, dtype='float32')
 
-        # Hypernetwork: token-wise transform before pooling
+        # Hypernetwork for token-wise features
         self.hyper = tf.keras.Sequential([
             layers.Dense(hyper_dim, activation='gelu'),
             layers.Dense(d_model)
         ], name="hyper")
 
+        # Attention pooling for global context
         self.attn_pool = layers.Dense(1)
-        self.scale_dense = layers.Dense(d_model)
 
+        # LayerNorm + Dropout
         self.norm = layers.LayerNormalization()
         self.dropout = layers.Dropout(dropout)
+        self.dense = layers.Dense(d_model)
 
     def call(self, x, training=None):
         x_in = x
         x_dtype = x.dtype  # store the input dtype
 
-        # 1) input projection
+        # 1) Input projection
         x_proj = self.input_proj(x)  # (B, L, D)
-
         B = tf.shape(x_proj)[0]
         L = tf.shape(x_proj)[1]
         D = self.d_model
         pad = (self.k - 1) // 2
 
-        # ------------------------------
-        # 2) DynamicConv local mixing
-        # ------------------------------
-        # generate kernels, then match x_proj's dtype
+        # ------------------------------
+        # 2) DynamicConv local mixing
+        # ------------------------------
         kernels = self.kernel_generator(self.dynamic_dense(x_proj))
         kernels = tf.cast(kernels, x_proj.dtype)
         kernels = tf.nn.softmax(kernels, axis=-1)
 
-        # padding & patch extraction
         x_pad = tf.pad(x_proj, [[0,0],[pad,pad],[0,0]])
-        x_pad_4d = tf.expand_dims(x_pad, axis=1)  # (B,1,L+k-1,D)
+        x_pad_4d = tf.expand_dims(x_pad, axis=1)
         patches = tf.image.extract_patches(
-            images=x_pad_4d,
-            sizes=[1,1,self.k,1],
-            strides=[1,1,1,1],
-            rates=[1,1,1,1],
-            padding='VALID'
-        )
+            images=x_pad_4d,
+            sizes=[1,1,self.k,1],
+            strides=[1,1,1,1],
+            rates=[1,1,1,1],
+            padding='VALID'
+        )
         patches = tf.reshape(patches, [B, L, self.k, D])
-
-        # match kernels' shape
         kernels_exp = tf.expand_dims(kernels, axis=-1)
-        out_local = tf.reduce_sum(patches * kernels_exp, axis=2)  # (B,L,D)
+        out_local = tf.reduce_sum(patches * kernels_exp, axis=2)
        out_local = self.dynamic_proj(out_local)
 
-        # ------------------------------
-        # 3) Hyper scaling
-        # ------------------------------
+        # ------------------------------
+        # 3) Global context via attention pooling (scale path removed)
+        # ------------------------------
         h = self.hyper(x_proj)
-        global_z = self.attn_pool(h)
-        global_z = tf.nn.softmax(global_z, axis=1)
-        global_z = tf.reduce_sum(h * global_z, axis=1)
-
-        scale = tf.expand_dims(tf.nn.sigmoid(self.scale_dense(global_z)), 1)
-        scale = tf.cast(scale, x_proj.dtype)  # match dtype
-        out_local = out_local * scale
-
-        # ------------------------------
-        # 4) Residual + SiLU + LayerNorm
-        # ------------------------------
+        scores = tf.nn.softmax(self.attn_pool(h), axis=1)   # (B, L, 1)
+        global_context = tf.reduce_sum(h * scores, axis=1)  # (B, D)
+        # token-wise concat
+        global_context_exp = tf.expand_dims(global_context, 1) * tf.ones([B, L, 1], dtype=x_proj.dtype)
+        out_local = tf.concat([out_local, global_context_exp], axis=-1)
+        out_local = self.dense(out_local)  # match dimensions
+
+        # ------------------------------
+        # 4) Residual + SiLU + LayerNorm
+        # ------------------------------
         out = x_proj + out_local
         out = tf.nn.silu(out)
         out = self.norm(out)
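The least obvious part of the hunk above is the local-mixing path: `tf.image.extract_patches` is applied to a dummy-height 4-D view of the sequence to gather every length-k sliding window, and a per-token softmax kernel then mixes each window (a DynamicConv-style operation). Below is a minimal, self-contained sketch, not part of the commit, that checks the patch trick against an explicit loop over window offsets; the random tensors are stand-ins for V2.py's real activations.

import tensorflow as tf

B, L, D, k = 2, 10, 8, 7
pad = (k - 1) // 2

x_proj = tf.random.normal([B, L, D])
# per-token mixing weights: one length-k kernel per position
kernels = tf.nn.softmax(tf.random.normal([B, L, k]), axis=-1)

# patch extraction exactly as in call(): pad, lift to 4-D, slide a 1xk window
x_pad = tf.pad(x_proj, [[0, 0], [pad, pad], [0, 0]])
x_pad_4d = tf.expand_dims(x_pad, axis=1)                    # (B, 1, L+k-1, D)
patches = tf.image.extract_patches(
    images=x_pad_4d, sizes=[1, 1, k, 1],
    strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='VALID')
patches = tf.reshape(patches, [B, L, k, D])                 # (B, L, k, D)

out = tf.reduce_sum(patches * kernels[..., None], axis=2)   # (B, L, D)

# reference: the same mixing written as an explicit sum over window offsets
ref = tf.zeros([B, L, D])
for j in range(k):
    ref += x_pad[:, j:j + L, :] * kernels[:, :, j:j + 1]

print(float(tf.reduce_max(tf.abs(out - ref))))  # ~1e-7 (float rounding)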
 
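The other substantive change replaces the old sigmoid gate (`scale_dense`, removed above) with attention pooling: a Dense(1) head scores every token, a softmax over the length axis turns the scores into weights, and the weighted sum gives one (B, D) summary that is broadcast back to all L positions and concatenated channel-wise. Here is a sketch of just that path, again with stand-in tensors rather than the commit's real ones; `tf.broadcast_to` is used here and is equivalent to the `tf.ones` multiply in the diff:

import tensorflow as tf
from tensorflow.keras import layers

B, L, D = 2, 10, 8
h = tf.random.normal([B, L, D])          # stand-in for self.hyper(x_proj)
out_local = tf.random.normal([B, L, D])  # stand-in for the DynamicConv output
attn_pool = layers.Dense(1)              # same shape as self.attn_pool

scores = tf.nn.softmax(attn_pool(h), axis=1)        # (B, L, 1), sums to 1 over L
global_context = tf.reduce_sum(h * scores, axis=1)  # (B, D)

# broadcast the summary to every token, then concatenate channel-wise
ctx_tokens = tf.broadcast_to(global_context[:, None, :], [B, L, D])
fused = tf.concat([out_local, ctx_tokens], axis=-1)  # (B, L, 2D)
print(fused.shape)  # (2, 10, 16); self.dense then maps 2D back to D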
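Putting the hunks together: a condensed, standalone sketch of the post-commit layer for a quick smoke test. Two loud caveats: `input_proj` and `dynamic_dense` are defined outside the hunks shown here, so plain Dense(d_model) layers are assumed for both; and the tail of call() after the LayerNorm is not shown either, so the sketch simply returns the normalized output (the unused x_in/x_dtype/dropout bookkeeping is dropped).

import tensorflow as tf
from tensorflow.keras import layers

class HyperConv1DSketch(layers.Layer):
    def __init__(self, d_model, k=7, hyper_dim=128):
        super().__init__()
        self.d_model, self.k = d_model, k
        self.input_proj = layers.Dense(d_model)     # assumption: defined outside the diff
        self.dynamic_dense = layers.Dense(d_model)  # assumption: defined outside the diff
        self.dynamic_proj = layers.Dense(d_model)
        self.kernel_generator = layers.Dense(k, dtype='float32')
        self.hyper = tf.keras.Sequential(
            [layers.Dense(hyper_dim, activation='gelu'), layers.Dense(d_model)])
        self.attn_pool = layers.Dense(1)
        self.norm = layers.LayerNormalization()
        self.dense = layers.Dense(d_model)          # fuses [local, context] back to D

    def call(self, x, training=None):
        x_proj = self.input_proj(x)                               # (B, L, D)
        B, L = tf.shape(x_proj)[0], tf.shape(x_proj)[1]
        pad = (self.k - 1) // 2

        # 2) DynamicConv local mixing
        kernels = tf.nn.softmax(tf.cast(
            self.kernel_generator(self.dynamic_dense(x_proj)), x_proj.dtype), axis=-1)
        x_pad_4d = tf.expand_dims(tf.pad(x_proj, [[0, 0], [pad, pad], [0, 0]]), 1)
        patches = tf.image.extract_patches(
            images=x_pad_4d, sizes=[1, 1, self.k, 1],
            strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='VALID')
        patches = tf.reshape(patches, [B, L, self.k, self.d_model])
        out_local = self.dynamic_proj(
            tf.reduce_sum(patches * kernels[..., None], axis=2))  # (B, L, D)

        # 3) Global context via attention pooling, concatenated token-wise
        h = self.hyper(x_proj)
        scores = tf.nn.softmax(self.attn_pool(h), axis=1)         # (B, L, 1)
        ctx = tf.reduce_sum(h * scores, axis=1)                   # (B, D)
        ctx = tf.expand_dims(ctx, 1) * tf.ones([B, L, 1], x_proj.dtype)
        out_local = self.dense(tf.concat([out_local, ctx], axis=-1))

        # 4) Residual + SiLU + LayerNorm
        return self.norm(tf.nn.silu(x_proj + out_local))

layer = HyperConv1DSketch(d_model=64)
print(layer(tf.random.normal([2, 16, 32])).shape)  # (2, 16, 64)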