multimodalart (HF Staff) committed on
Commit 5d89594 · verified · 1 Parent(s): 93eee8b

Update app.py

Files changed (1)
  1. app.py +6 -5
app.py CHANGED
@@ -144,15 +144,16 @@ class DDiTBlock(nn.Module):
        self.attn = SelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)
-       self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd, bias=True)
+
+       self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()
+
    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
-       x_skip = x
-       x = modulate(self.ln_1(x), shift_msa, scale_msa)
-       x = self.attn(x)
-       x = bias_add_scale(x, None, gate_msa, x_skip)
+       x_skip = x
+       modulated_x = modulate(self.ln_1(x), shift_msa, scale_msa)
+       x = bias_add_scale(self.attn(self.ln_1(x)), None, gate_msa, x_skip)
        x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
        return x
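For context, below is a minimal, self-contained sketch of the adaLN-zero modulation pattern this block follows. The real app.py defines modulate, bias_add_scale, SelfAttention, and MLP elsewhere; the stand-in definitions and the TinyAdaLNBlock name here are illustrative assumptions, not the Space's actual code.

import torch
import torch.nn as nn


def modulate(x, shift, scale):
    # Assumed helper: shift and scale normalized activations with the conditioning signal.
    return x * (1 + scale) + shift


def bias_add_scale(x, bias, scale, residual):
    # Assumed helper: optionally add a bias, gate the branch output by `scale`, then add the residual.
    out = x if bias is None else x + bias
    return residual + scale * out


class TinyAdaLNBlock(nn.Module):
    # Hypothetical, simplified stand-in for DDiTBlock.
    def __init__(self, n_embd, cond_dim):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.ln_2 = nn.LayerNorm(n_embd)
        self.attn = nn.Identity()  # stand-in for SelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd)
        )
        # Zero-initialised projection of the conditioning vector into six
        # modulation tensors (shift/scale/gate for attention and for the MLP),
        # as in the commit above.
        self.adaLN_modulation = nn.Linear(cond_dim, 6 * n_embd)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
        )
        x_skip = x
        # Attention branch: modulate, attend, gate, add residual.
        x = bias_add_scale(
            self.attn(modulate(self.ln_1(x), shift_msa, scale_msa)), None, gate_msa, x_skip
        )
        # MLP branch: modulate, transform, gate, add residual.
        x = bias_add_scale(
            self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x
        )
        return x


if __name__ == "__main__":
    block = TinyAdaLNBlock(n_embd=32, cond_dim=16)
    x = torch.randn(2, 10, 32)  # (batch, sequence, n_embd)
    c = torch.randn(2, 16)      # (batch, cond_dim)
    print(block(x, c).shape)    # torch.Size([2, 10, 32])

Because the conditioning projection is zero-initialised, every shift, scale, and gate starts at zero, so the block acts as an identity mapping at initialisation and the conditioning pathway is learned from scratch (the adaLN-zero idea from DiT).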