Update V2.py
V2.py (CHANGED)
@@ -138,10 +138,12 @@ class HyperConv1D(layers.Layer):
         # Input projection
         self.input_proj = layers.Dense(d_model, name="input_proj")

-        # Dynamic kernel conv
-        self.
+        # Dynamic kernel conv (intermediate dim)
+        self.d_mid = max(64, d_model // 8)
+        self.dynamic_dense = layers.Dense(self.d_mid, activation='silu')
         self.dynamic_proj = layers.Dense(d_model)
         self.kernel_generator = layers.Dense(k, dtype='float32')
+        self.kernel_temp = self.add_weight("kernel_temp", shape=(), initializer=tf.constant_initializer(1.0), trainable=True)

         # Hypernetwork: token-wise transform before pooling
         self.hyper = tf.keras.Sequential([
@@ -150,13 +152,13 @@ class HyperConv1D(layers.Layer):
         ], name="hyper")

         self.attn_pool = layers.Dense(1)
-        self.scale_dense = layers.Dense(d_model)
+        self.scale_dense = layers.Dense(d_model, bias_initializer=tf.keras.initializers.Constant(0.0))

+        # Normalization + dropout
         self.norm = layers.LayerNormalization()
         self.dropout = layers.Dropout(dropout)

     def call(self, x, training=None):
-        x_in = x
         x_dtype = x.dtype

         # 1) input projection
@@ -170,8 +172,9 @@ class HyperConv1D(layers.Layer):
         # ------------------------------
         # 2) DynamicConv local mixing
         # ------------------------------
-        kernels = self.kernel_generator(self.dynamic_dense(x_proj))  # (B,
-        kernels = tf.
+        kernels = self.kernel_generator(self.dynamic_dense(x_proj))  # (B,L,k)
+        kernels = tf.cast(kernels, tf.float32)
+        kernels = tf.nn.softmax(kernels / tf.maximum(self.kernel_temp, 1e-6), axis=-1)

         x_pad = tf.pad(x_proj, [[0,0],[pad,pad],[0,0]])
         x_pad_4d = tf.expand_dims(x_pad, axis=1)  # (B,1,L+k-1,D)
@@ -183,7 +186,7 @@ class HyperConv1D(layers.Layer):
             padding='VALID'
         )
         patches = tf.reshape(patches, [B, L, self.k, D])
-        kernels_exp = tf.expand_dims(kernels, axis=-1)
+        kernels_exp = tf.cast(tf.expand_dims(kernels, axis=-1), x_proj.dtype)
         out_local = tf.reduce_sum(patches * kernels_exp, axis=2)  # (B,L,D)
         out_local = self.dynamic_proj(out_local)

@@ -195,11 +198,12 @@ class HyperConv1D(layers.Layer):
         global_z = tf.nn.softmax(global_z, axis=1)
         global_z = tf.reduce_sum(h * global_z, axis=1)

-        scale
-
+        # residual-gate style scale: 1 + α*tanh(...)
+        scale = 1.0 + 0.1 * tf.tanh(self.scale_dense(global_z))
+        out_local = out_local * tf.expand_dims(scale, 1)

         # ------------------------------
-        # 4) Residual + SiLU + LayerNorm
+        # 4) Residual + SiLU + LayerNorm + Dropout
         # ------------------------------
         out = x_proj + out_local
         out = tf.nn.silu(out)
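
As a reading aid, here is a minimal standalone sketch of the DynamicConv local-mixing path that the @@ -170 and @@ -183 hunks now implement: per-token kernel logits from a small Dense stack, a temperature-scaled softmax over the k taps computed in float32, and a weighted sum over k-wide patches pulled from the padded sequence. The sizes, the pad = (k - 1) // 2 choice, and the free-standing variable names are illustrative assumptions, not the repository's exact code.

import tensorflow as tf
from tensorflow.keras import layers

B, L, D, k = 2, 16, 64, 5                   # illustrative sizes
d_mid = max(64, D // 8)
pad = (k - 1) // 2                           # assumes 'same'-style padding for odd k

x_proj = tf.random.normal([B, L, D])         # stands in for the input_proj output
dynamic_dense = layers.Dense(d_mid, activation=tf.nn.silu)
kernel_generator = layers.Dense(k, dtype='float32')
dynamic_proj = layers.Dense(D)
kernel_temp = tf.Variable(1.0, dtype=tf.float32)

# Per-token kernel logits -> temperature-scaled softmax over the k taps (in float32).
kernels = kernel_generator(dynamic_dense(x_proj))                       # (B, L, k)
kernels = tf.nn.softmax(tf.cast(kernels, tf.float32) / tf.maximum(kernel_temp, 1e-6), axis=-1)

# Extract the k-wide neighborhood around every position.
x_pad = tf.pad(x_proj, [[0, 0], [pad, pad], [0, 0]])                    # (B, L+k-1, D)
x_pad_4d = tf.expand_dims(x_pad, axis=1)                                # (B, 1, L+k-1, D)
patches = tf.image.extract_patches(
    x_pad_4d, sizes=[1, 1, k, 1], strides=[1, 1, 1, 1],
    rates=[1, 1, 1, 1], padding='VALID')                                # (B, 1, L, k*D)
patches = tf.reshape(patches, [B, L, k, D])

# Weighted sum of the k taps with the per-token kernels, then project back to D.
kernels_exp = tf.cast(tf.expand_dims(kernels, axis=-1), x_proj.dtype)   # (B, L, k, 1)
out_local = dynamic_proj(tf.reduce_sum(patches * kernels_exp, axis=2))  # (B, L, D)
print(out_local.shape)                                                  # (2, 16, 64)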
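
In the same spirit, a sketch of the attention pooling and the new residual-gate scale from the @@ -195 hunk. The lines that produce h and the raw global_z logits sit above the hunk and are not part of this diff, so h = hyper(x_proj) and global_z = attn_pool(h) below are assumptions chosen to match the shapes the hunk uses, and the one-layer hyper stack is a stand-in.

import tensorflow as tf
from tensorflow.keras import layers

B, L, D = 2, 16, 64
x_proj = tf.random.normal([B, L, D])
out_local = tf.random.normal([B, L, D])       # output of the local-mixing step above

hyper = tf.keras.Sequential([layers.Dense(D, activation=tf.nn.silu)])   # stand-in for self.hyper
attn_pool = layers.Dense(1)
scale_dense = layers.Dense(D, bias_initializer=tf.keras.initializers.Constant(0.0))

# Attention-pooled global summary of the hyper features.
h = hyper(x_proj)                              # (B, L, D)
global_z = attn_pool(h)                        # (B, L, 1) attention logits
global_z = tf.nn.softmax(global_z, axis=1)     # normalize over the sequence axis
global_z = tf.reduce_sum(h * global_z, axis=1)             # (B, D)

# Residual-gate style scale, bounded to (0.9, 1.1) by the 0.1 * tanh term.
scale = 1.0 + 0.1 * tf.tanh(scale_dense(global_z))         # (B, D)
out_local = out_local * tf.expand_dims(scale, 1)           # broadcast over L
print(out_local.shape)                                     # (2, 16, 64)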
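
Two details in the new lines are worth noting. The kernel logits are produced and softmaxed in float32 (kernel_generator is built with dtype='float32' and the result is only cast back to x_proj.dtype when it multiplies the patches), which keeps the normalization in full precision if the surrounding layer runs under mixed precision, and tf.maximum(self.kernel_temp, 1e-6) keeps the learnable temperature from collapsing to zero. The 1.0 + 0.1 * tf.tanh(...) gate bounds each per-channel scale to the interval (0.9, 1.1), so the globally modulated local branch stays close in magnitude to the x_proj residual it is added to.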