Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -144,15 +144,16 @@ class DDiTBlock(nn.Module):
|
|
| 144 |
self.attn = SelfAttention(config)
|
| 145 |
self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
|
| 146 |
self.mlp = MLP(config)
|
| 147 |
-
|
|
|
|
| 148 |
self.adaLN_modulation.weight.data.zero_()
|
| 149 |
self.adaLN_modulation.bias.data.zero_()
|
|
|
|
| 150 |
def forward(self, x, c):
|
| 151 |
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
|
| 152 |
-
x_skip = x
|
| 153 |
-
|
| 154 |
-
x = self.attn(x)
|
| 155 |
-
x = bias_add_scale(x, None, gate_msa, x_skip)
|
| 156 |
x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
|
| 157 |
return x
|
| 158 |
|
|
|
|
| 144 |
self.attn = SelfAttention(config)
|
| 145 |
self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
|
| 146 |
self.mlp = MLP(config)
|
| 147 |
+
|
| 148 |
+
self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd)
|
| 149 |
self.adaLN_modulation.weight.data.zero_()
|
| 150 |
self.adaLN_modulation.bias.data.zero_()
|
| 151 |
+
|
| 152 |
def forward(self, x, c):
|
| 153 |
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
|
| 154 |
+
x_skip = x
|
| 155 |
+
modulated_x = modulate(self.ln_1(x), shift_msa, scale_msa)
|
| 156 |
+
x = bias_add_scale(self.attn(self.ln_1(x)), None, gate_msa, x_skip)
|
|
|
|
| 157 |
x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
|
| 158 |
return x
|
| 159 |
|