Create 이외.py
Browse files
이외.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
class MixerBlock(layers.Layer):
    """MLP-Mixer style block: a token-mixing MLP followed by a channel-mixing MLP.

    Each sub-block is pre-normalized with ``LayerNormalization`` and wrapped in
    a residual connection. The same ``Dropout`` layer is applied to the output
    of both MLPs before it is added back to the residual stream.

    Args:
        seq_len: Number of tokens ``L`` in the input. The token-mixing MLP
            projects back to this size, so the input's token axis must have
            exactly this length for the residual addition to be valid.
        dim: Channel size ``D`` of the input; the channel-mixing MLP projects
            back to this size.
        token_mlp_dim: Hidden width of the token-mixing MLP.
        channel_mlp_dim: Hidden width of the channel-mixing MLP.
        dropout: Dropout rate applied to each MLP output.
        **kwargs: Standard ``Layer`` keyword arguments (``name``,
            ``trainable``, ...), forwarded to the base class.
    """

    def __init__(self, seq_len, dim, token_mlp_dim, channel_mlp_dim, dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.dim = dim
        self.token_mlp_dim = token_mlp_dim
        self.channel_mlp_dim = channel_mlp_dim
        # Kept as a plain number so get_config() can report it.
        self.dropout_rate = dropout

        # NOTE(review): all sub-layers are pinned to float32 — presumably for
        # numerical stability under a mixed-precision policy; confirm.
        self.ln1 = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
        # token-mixing MLP: operate over tokens => apply Dense on transposed axis
        self.token_fc1 = layers.Dense(token_mlp_dim, activation='gelu', dtype=tf.float32)
        self.token_fc2 = layers.Dense(seq_len, dtype=tf.float32)

        self.ln2 = layers.LayerNormalization(epsilon=1e-6, dtype=tf.float32)
        # channel-mixing MLP: operate per-token over channels
        self.channel_fc1 = layers.Dense(channel_mlp_dim, activation='gelu', dtype=tf.float32)
        self.channel_fc2 = layers.Dense(dim, dtype=tf.float32)

        self.dropout = layers.Dropout(dropout)

    def call(self, x, training=None):
        """Apply token mixing then channel mixing to ``x``.

        Args:
            x: Input tensor of shape ``(B, L, D)``.
            training: Forwarded to the dropout layer.

        Returns:
            Tensor of shape ``(B, L, D)``.
        """
        # Token-mixing: Dense acts on the last axis, so swap tokens into it.
        y = self.ln1(x)                          # (B, L, D)
        y_t = tf.transpose(y, perm=[0, 2, 1])    # (B, D, L)
        y_t = self.token_fc1(y_t)                # (B, D, token_mlp_dim)
        y_t = self.token_fc2(y_t)                # (B, D, L)
        y = tf.transpose(y_t, perm=[0, 2, 1])    # (B, L, D)
        x = x + self.dropout(y, training=training)

        # Channel-mixing: channels are already the last axis.
        z = self.ln2(x)
        z = self.channel_fc1(z)
        z = self.channel_fc2(z)
        x = x + self.dropout(z, training=training)

        return x

    def get_config(self):
        """Return the constructor arguments so Keras can serialize/clone the layer."""
        config = super().get_config()
        config.update({
            'seq_len': self.seq_len,
            'dim': self.dim,
            'token_mlp_dim': self.token_mlp_dim,
            'channel_mlp_dim': self.channel_mlp_dim,
            'dropout': self.dropout_rate,
        })
        return config
|