NoteDance committed on
Commit
efa1c71
1 Parent(s): d4e672c

Update DiT.py

Browse files
Files changed (1) hide show
  1. DiT.py +14 -5
DiT.py CHANGED
@@ -55,13 +55,18 @@ class TimestepEmbedder:
55
  return t_emb
56
 
57
 
58
- class LabelEmbedder:
59
  """
60
  Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
61
  """
62
  def __init__(self, num_classes, hidden_size, dropout_prob):
63
  use_cfg_embedding = dropout_prob > 0
64
- self.embedding_table = tf.Variable(tf.random.normal((num_classes + use_cfg_embedding, hidden_size), stddev=0.02))
 
 
 
 
 
65
  self.num_classes = num_classes
66
  self.dropout_prob = dropout_prob
67
 
@@ -156,7 +161,12 @@ class DiT(Model):
156
  self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
157
  num_patches = self.x_embedder.num_patches
158
  # Will use fixed sin-cos embedding:
159
- self.pos_embed = tf.zeros((1, num_patches, hidden_size))
 
 
 
 
 
160
 
161
  self.blocks = [
162
  DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
@@ -167,8 +177,7 @@ class DiT(Model):
167
  def initialize_weights(self):
168
  # Initialize (and freeze) pos_embed by sin-cos embedding:
169
  pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
170
- self.pos_embed = tf.convert_to_tensor(pos_embed, dtype=tf.float32)[tf.newaxis, :]
171
- tf.Variable(self.pos_embed)
172
 
173
  def unpatchify(self, x):
174
  """
 
55
  return t_emb
56
 
57
 
58
+ class LabelEmbedder(tf.keras.layers.Layer):
59
  """
60
  Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
61
  """
62
  def __init__(self, num_classes, hidden_size, dropout_prob):
63
  use_cfg_embedding = dropout_prob > 0
64
+ self.embedding_table = self.add_weight(
65
+ name='embedding_table',
66
+ shape=(num_classes + use_cfg_embedding, hidden_size),
67
+ initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
68
+ trainable=True
69
+ )
70
  self.num_classes = num_classes
71
  self.dropout_prob = dropout_prob
72
 
 
161
  self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
162
  num_patches = self.x_embedder.num_patches
163
  # Will use fixed sin-cos embedding:
164
+ self.pos_embed = self.add_weight(
165
+ name='pos_embed',
166
+ shape=(1, num_patches, hidden_size),
167
+ initializer=tf.keras.initializers.Zeros(),
168
+ trainable=False # To freeze this variable
169
+ )
170
 
171
  self.blocks = [
172
  DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
 
177
  def initialize_weights(self):
178
  # Initialize (and freeze) pos_embed by sin-cos embedding:
179
  pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
180
+ self.pos_embed.assign(tf.convert_to_tensor(pos_embed, dtype=tf.float32)[tf.newaxis, :])
 
181
 
182
  def unpatchify(self, x):
183
  """