"""AdaLN-Zero modules for shared-base + low-rank-delta conditioning."""
from __future__ import annotations

from torch import Tensor, nn
class AdaLNZeroProjector(nn.Module):
    """Shared base AdaLN-Zero projection.

    Computes ``SiLU -> Linear(d_cond -> 4*d_model)`` and returns the packed
    modulation tensor of shape ``[B, 4*d_model]``. Both weight and bias of
    the projection are zero-initialized, so the modulation output is exactly
    zero at the start of training (the "Zero" in AdaLN-Zero).
    """

    def __init__(self, d_model: int, d_cond: int) -> None:
        super().__init__()
        self.d_model = int(d_model)
        self.d_cond = int(d_cond)
        self.act = nn.SiLU()
        self.proj = nn.Linear(self.d_cond, 4 * self.d_model)
        # Zero-init every projection parameter (weight and bias) so the
        # module emits all-zeros until training moves it.
        for param in self.proj.parameters():
            nn.init.zeros_(param)

    def forward(self, cond: Tensor) -> Tensor:
        """Map conditioning ``[B, d_cond]`` to packed modulation ``[B, 4*d_model]``."""
        return self.proj(self.act(cond))

    def forward_activated(self, act_cond: Tensor) -> Tensor:
        """Like :meth:`forward`, but the SiLU has already been applied by the caller."""
        return self.proj(act_cond)
class AdaLNZeroLowRankDelta(nn.Module):
    """Per-layer low-rank AdaLN delta.

    Factorized map ``down(d_cond -> rank)`` followed by ``up(rank -> 4*d_model)``,
    both bias-free. The up-projection is zero-initialized, so the delta
    contributes nothing at init and the shared AdaLN-Zero "zero output"
    property is preserved; the down-projection gets a small normal init.
    """

    def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
        super().__init__()
        self.d_model, self.d_cond, self.rank = int(d_model), int(d_cond), int(rank)
        self.down = nn.Linear(self.d_cond, self.rank, bias=False)
        self.up = nn.Linear(self.rank, 4 * self.d_model, bias=False)
        # Small random down-projection, zeroed up-projection => zero delta at init.
        nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.up.weight)

    def forward(self, act_cond: Tensor) -> Tensor:
        """Map activated conditioning ``[B, d_cond]`` to a packed delta ``[B, 4*d_model]``."""
        low_rank = self.down(act_cond)
        return self.up(low_rank)