Update DiT.py
DiT.py CHANGED
@@ -55,13 +55,18 @@ class TimestepEmbedder:
         return t_emb
 
 
-class LabelEmbedder:
+class LabelEmbedder(tf.keras.layers.Layer):
     """
     Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
     """
     def __init__(self, num_classes, hidden_size, dropout_prob):
         use_cfg_embedding = dropout_prob > 0
-        self.embedding_table =
+        self.embedding_table = self.add_weight(
+            name='embedding_table',
+            shape=(num_classes + use_cfg_embedding, hidden_size),
+            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
+            trainable=True
+        )
         self.num_classes = num_classes
         self.dropout_prob = dropout_prob
 
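The `num_classes + use_cfg_embedding` row count reserves one extra embedding row as a learned "null" class for classifier-free guidance. The dropout path itself is outside this hunk; as a rough sketch (modeled on the original PyTorch DiT's `token_drop`; the helper and names below are hypothetical, not code from this file), dropped labels would be remapped to that spare index before the table lookup:

```python
import tensorflow as tf

# Hypothetical helper, not part of this commit: remap randomly dropped labels
# to index num_classes, i.e. the spare "null" row of the embedding table.
def token_drop(labels, num_classes, dropout_prob):
    drop = tf.random.uniform(tf.shape(labels)) < dropout_prob
    return tf.where(drop, tf.fill(tf.shape(labels), num_classes), labels)

num_classes, hidden_size = 10, 16
table = tf.random.normal((num_classes + 1, hidden_size), stddev=0.02)
labels = tf.constant([3, 7, 1, 9])
dropped = token_drop(labels, num_classes, dropout_prob=0.5)
emb = tf.gather(table, dropped)  # (4, 16); dropped entries hit the null row
```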
@@ -156,7 +161,12 @@ class DiT(Model):
         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
         num_patches = self.x_embedder.num_patches
         # Will use fixed sin-cos embedding:
-        self.pos_embed =
+        self.pos_embed = self.add_weight(
+            name='pos_embed',
+            shape=(1, num_patches, hidden_size),
+            initializer=tf.keras.initializers.Zeros(),
+            trainable=False  # To freeze this variable
+        )
 
         self.blocks = [
             DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
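Using `add_weight(..., trainable=False)` rather than a bare `tf.Variable` keeps the positional table out of `trainable_weights` while leaving it assignable and tracked for checkpointing. A tiny self-contained check of that behavior (the `Toy` layer is illustrative only, not from this file):

```python
import tensorflow as tf

class Toy(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        # Same pattern as pos_embed above: a zero-initialized, frozen weight.
        self.pe = self.add_weight(name='pe', shape=(1, 4, 8),
                                  initializer='zeros', trainable=False)

layer = Toy()
layer.pe.assign(tf.ones((1, 4, 8)))      # frozen weights can still be assigned
print(len(layer.trainable_weights))      # 0 -> the optimizer never touches it
print(len(layer.non_trainable_weights))  # 1 -> but it is still tracked and saved
```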
@@ -167,8 +177,7 @@ class DiT(Model):
     def initialize_weights(self):
         # Initialize (and freeze) pos_embed by sin-cos embedding:
         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
-        self.pos_embed
-        tf.Variable(self.pos_embed)
+        self.pos_embed.assign(tf.convert_to_tensor(pos_embed, dtype=tf.float32)[tf.newaxis, :])
 
     def unpatchify(self, x):
         """
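`get_2d_sincos_pos_embed` is assumed to be defined elsewhere in DiT.py. For orientation, a minimal sketch of the standard 2D sin-cos embedding (the MAE/DiT recipe), assuming a square patch grid and `embed_dim` divisible by 4:

```python
import numpy as np

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    # Half the channels are sines, half cosines, over log-spaced frequencies.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega = 1.0 / 10000 ** (omega / (embed_dim / 2))
    out = np.einsum('m,d->md', pos.reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, embed_dim)

def get_2d_sincos_pos_embed(embed_dim, grid_size):
    # First half of the channels encodes the row coordinate, second half the column.
    grid_w, grid_h = np.meshgrid(np.arange(grid_size, dtype=np.float64),
                                 np.arange(grid_size, dtype=np.float64))
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w)
    return np.concatenate([emb_h, emb_w], axis=1)  # (grid_size**2, embed_dim)

pe = get_2d_sincos_pos_embed(embed_dim=8, grid_size=2)
print(pe.shape)  # (4, 8)
```

The returned array is rank 2, which is why the `assign` above adds the leading batch axis with `[tf.newaxis, :]` to match the `(1, num_patches, hidden_size)` shape of the frozen variable.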