| | """DiCo block: conv path (1x1 -> depthwise -> SiLU -> CCA -> 1x1) + GELU MLP.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import Tensor, nn |
| |
|
| | from .compact_channel_attention import CompactChannelAttention |
| | from .conv_mlp import ConvMLP |
| | from .norms import ChannelWiseRMSNorm |
| |
|
| |
|
class DiCoBlock(nn.Module):
    """DiCo-style conv block with optional external AdaLN conditioning.

    Structure (pre-norm, two gated residual sub-blocks):

    1. Conv attention path: RMSNorm -> 1x1 conv -> depthwise conv -> SiLU
       -> compact channel attention -> 1x1 conv.
    2. MLP path: RMSNorm -> ConvMLP (GELU).

    Two conditioning modes:

    - Unconditioned (encoder): each residual branch is scaled by a learned
      per-channel gate initialized to zero, so the block starts as identity.
    - External AdaLN (decoder): per-sample modulation arrives packed in
      ``adaln_m`` of shape ``[B, 4*C]``, chunked into
      ``(scale_attn, gate_attn, scale_mlp, gate_mlp)``.
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
    ) -> None:
        """Build the block.

        Args:
            channels: Number of input/output channels ``C``.
            mlp_ratio: ConvMLP hidden width as a multiple of ``C``
                (rounded, floored at 1).
            depthwise_kernel_size: Kernel size of the depthwise conv. Must be
                odd so ``k // 2`` padding preserves spatial size.
            use_external_adaln: If True, ``forward`` requires packed AdaLN
                modulation instead of using learned residual gates.
            norm_eps: Epsilon for both RMS norms.

        Raises:
            ValueError: If ``depthwise_kernel_size`` is even. An even kernel
                with ``k // 2`` padding grows each spatial dim by one, which
                would break the residual additions in ``forward`` with an
                opaque shape-mismatch error.
        """
        super().__init__()
        if depthwise_kernel_size % 2 == 0:
            raise ValueError(
                "depthwise_kernel_size must be odd to preserve spatial size, "
                f"got {depthwise_kernel_size}"
            )
        self.channels = int(channels)
        self.use_external_adaln = bool(use_external_adaln)

        # Pre-norms for the two sub-blocks; affine is off because scaling is
        # provided either by AdaLN modulation or by the convs themselves.
        self.norm1 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)
        self.norm2 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)

        # Conv attention path: pointwise -> depthwise -> (SiLU, CCA) -> pointwise.
        self.conv1 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.conv2 = nn.Conv2d(
            self.channels,
            self.channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,  # "same" padding for odd k
            groups=self.channels,  # depthwise: one filter per channel
            bias=True,
        )
        self.conv3 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.cca = CompactChannelAttention(self.channels)

        hidden_channels = max(int(round(float(self.channels) * mlp_ratio)), 1)
        self.mlp = ConvMLP(self.channels, hidden_channels, norm_eps=norm_eps)

        if not self.use_external_adaln:
            # Zero-init gates: the block is an identity mapping at init,
            # which stabilizes early training of deep stacks.
            self.gate_attn = nn.Parameter(torch.zeros(self.channels))
            self.gate_mlp = nn.Parameter(torch.zeros(self.channels))

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        """Apply the block.

        Args:
            x: Input feature map of shape ``[B, C, H, W]``.
            adaln_m: Packed modulation of shape ``[B, 4*C]``. Required iff
                ``use_external_adaln``; chunked along the last dim into
                ``(scale_attn, gate_attn, scale_mlp, gate_mlp)``.

        Returns:
            Tensor of shape ``[B, C, H, W]``.

        Raises:
            ValueError: If ``adaln_m`` presence or last-dim size does not
                match the configured conditioning mode.
        """
        b, c = x.shape[:2]

        if self.use_external_adaln:
            if adaln_m is None:
                raise ValueError(
                    "adaln_m required for externally-conditioned DiCoBlock"
                )
            if adaln_m.shape[-1] != 4 * c:
                # Fail loudly here instead of letting chunk/view produce a
                # confusing downstream shape error.
                raise ValueError(
                    f"adaln_m last dim must be 4*C={4 * c}, "
                    f"got {adaln_m.shape[-1]}"
                )
            adaln_m_cast = adaln_m.to(device=x.device, dtype=x.dtype)
            scale_a, gate_a, scale_m, gate_m = adaln_m_cast.chunk(4, dim=-1)
        elif adaln_m is not None:
            raise ValueError("adaln_m must be None for unconditioned DiCoBlock")

        residual = x

        # --- Conv attention sub-block (pre-norm, gated residual) ---
        x_att = self.norm1(x)
        if self.use_external_adaln:
            # AdaLN scale: identity when scale_a == 0.
            x_att = x_att * (1.0 + scale_a.view(b, c, 1, 1))
        y = self.conv1(x_att)
        y = self.conv2(y)
        y = F.silu(y)
        y = y * self.cca(y)  # per-channel re-weighting
        y = self.conv3(y)

        if self.use_external_adaln:
            # tanh bounds the per-sample gate to (-1, 1); zero input => no-op.
            gate_a_view = torch.tanh(gate_a).view(b, c, 1, 1)
            x = residual + gate_a_view * y
        else:
            gate = self.gate_attn.view(1, self.channels, 1, 1).to(
                dtype=y.dtype, device=y.device
            )
            x = residual + gate * y

        # --- MLP sub-block (pre-norm, gated residual) ---
        residual_mlp = x
        x_mlp = self.norm2(x)
        if self.use_external_adaln:
            x_mlp = x_mlp * (1.0 + scale_m.view(b, c, 1, 1))
        y_mlp = self.mlp(x_mlp)

        if self.use_external_adaln:
            gate_m_view = torch.tanh(gate_m).view(b, c, 1, 1)
            x = residual_mlp + gate_m_view * y_mlp
        else:
            gate = self.gate_mlp.view(1, self.channels, 1, 1).to(
                dtype=y_mlp.dtype, device=y_mlp.device
            )
            x = residual_mlp + gate * y_mlp

        return x