| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| from einops import rearrange |
| from torch import Tensor, nn |
|
|
try:
    from apex.contrib.group_norm import GroupNorm

    OPT_GROUP_NORM = True
except Exception:
    # NOTE(review): the broad catch is kept as-is; presumably a broken apex
    # build can fail with more than ImportError -- confirm before narrowing.
    print('Fused optimized group norm has not been installed.')
    OPT_GROUP_NORM = False

    class GroupNorm(torch.nn.GroupNorm):
        """Pure-PyTorch stand-in bound to ``GroupNorm`` when the apex import fails.

        Without this fallback the name ``GroupNorm`` would be undefined after a
        failed import and any later attempt to build a normalization layer
        would raise ``NameError``.

        The learnable parameters are the ``weight`` and ``bias`` inherited from
        ``torch.nn.GroupNorm``.
        NOTE(review): assumed to match the parameter names of the fused layer;
        verify before loading checkpoints saved with the fused implementation.

        Args:
            num_groups (int): Number of groups the channels are split into.
            num_channels (int): Number of channels of the input.
            eps (float, optional): Value added to the variance. Defaults to 1e-5.
            affine (bool, optional): Whether to learn per-channel scale and shift.
            device: Forwarded to ``torch.nn.GroupNorm``.
            dtype: Forwarded to ``torch.nn.GroupNorm``.
            act (str, optional): ``""`` for no activation, ``"silu"`` or
                ``"swish"`` to apply SiLU after the normalization.

        Raises:
            ValueError: If ``act`` is not one of the supported names.
        """

        def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, device=None, dtype=None, act=""):
            super().__init__(num_groups, num_channels, eps=eps, affine=affine, device=device, dtype=dtype)
            act = act.lower()
            if act not in ("", "silu", "swish"):
                raise ValueError(f"Unsupported activation for GroupNorm fallback: {act!r}")
            self.act = act

        def forward(self, x):
            """Apply group normalization, then SiLU when an activation was requested."""
            out = super().forward(x)
            if self.act:
                # "silu" and "swish" are two names for the same function.
                out = torch.nn.functional.silu(out)
            return out
|
|
|
|
| |
def Normalize(in_channels, num_groups=32, act=""):
    """Build the group-normalization layer used by the blocks in this module.

    Args:
        in_channels (int): Channel count of the tensor that will be normalized.
        num_groups (int, optional): Number of channel groups. Defaults to 32.
        act (str, optional): Activation name forwarded to ``GroupNorm`` through
            its ``act`` keyword. Defaults to "".

    Returns:
        GroupNorm: Layer constructed with ``eps=1e-6`` and ``affine=True``.
    """
    norm_kwargs = dict(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
        act=act,
    )
    return GroupNorm(**norm_kwargs)
|
|
|
|
def nonlinearity(x):
    """Apply the SiLU (Swish) activation, ``x * sigmoid(x)``, element-wise.

    Args:
        x (Tensor): Input tensor of any shape.

    Returns:
        Tensor: Tensor of the same shape as ``x`` holding the activated values.
    """
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
|
class ResnetBlock(nn.Module):
    """Residual block: two norm -> 3x3 conv stages plus a skip connection.

    An optional embedding vector can be projected and added between the two
    convolutions. When the input and output channel counts differ, the skip
    path is projected with either a 3x3 or a 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, temb_channels=0):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int, optional): Number of output channels. Defaults to in_channels.
            conv_shortcut (bool, optional): Use a 3x3 convolution (instead of 1x1) on the
                skip path when the channel count changes. Defaults to False.
            dropout (float, optional): Dropout probability. Defaults to 0.0.
            temb_channels (int, optional): Width of the embedding passed to ``forward``.
                The projection layer is only created when this is > 0. Defaults to 0.
        """
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        # Both norm layers are built with act="silu", so no separate
        # activation call appears in forward().
        self.norm1 = Normalize(in_channels, act="silu")
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels, act="silu")
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb=None):
        """Run the residual block.

        Args:
            x (Tensor): Input feature map of shape (B, C, H, W).
            temb (Tensor, optional): Embedding of shape (B, temb_channels). May be
                omitted or ``None``, in which case no embedding is added. A
                non-``None`` value requires the block to have been built with
                ``temb_channels > 0``; otherwise ``temb_proj`` does not exist.

        Returns:
            Tensor: Output feature map of shape (B, out_channels, H, W).
        """
        h = self.norm1(x)
        h = self.conv1(h)

        if temb is not None:
            # Project the embedding to out_channels and broadcast over H and W.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Match the skip path's channel count to the residual branch.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
|
|
|
|
class Upsample(nn.Module):
    """Double the spatial size of a feature map with nearest-neighbour interpolation.

    When ``with_conv`` is set, a 3x3 convolution is applied to the upsampled map.
    """

    def __init__(self, in_channels, with_conv):
        """
        Args:
            in_channels (int): Channel count of the input (and of the output).
            with_conv (bool): Whether to follow the interpolation with a convolution.
        """
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Upsample ``x`` by a factor of two.

        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: Feature map of shape (B, C, 2H, 2W) in the dtype of ``x``.
        """
        input_dtype = x.dtype
        is_bf16 = input_dtype == torch.bfloat16
        # bfloat16 inputs make a round trip through float32 for the
        # interpolation only (presumably to sidestep missing bfloat16 support
        # in interpolate on some backends -- TODO confirm).
        source = x.to(torch.float32) if is_bf16 else x
        upsampled = nn.functional.interpolate(source, scale_factor=2.0, mode="nearest")
        if is_bf16:
            upsampled = upsampled.to(input_dtype)
        return self.conv(upsampled) if self.with_conv else upsampled
|
|
|
|
class Downsample(nn.Module):
    """Halve the spatial size of a feature map.

    Uses a strided 3x3 convolution when ``with_conv`` is set and 2x2 average
    pooling otherwise.
    """

    def __init__(self, in_channels, with_conv):
        """
        Args:
            in_channels (int): Channel count of the input (and of the output).
            with_conv (bool): Downsample with a learned convolution instead of pooling.
        """
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # The convolution itself is unpadded; forward() pads the input
            # by hand before calling it.
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Downsample ``x`` by a factor of two.

        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: Feature map of shape (B, C, H/2, W/2) for even H and W.
        """
        if not self.with_conv:
            return nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        # One column of zeros on the right and one row at the bottom.
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
|
|
|
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is normalized, projected to queries, keys and values with 1x1
    convolutions, attended with scaled dot-product attention, projected back
    and added to the input.
    """

    def __init__(self, in_channels: int):
        """
        Args:
            in_channels (int): Number of input/output channels.
        """
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels, act="silu")

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalize the input and attend over its H*W positions.

        Args:
            h_ (Tensor): Feature map of shape (B, C, H, W); normalization is
                applied inside this method.

        Returns:
            Tensor: Attention output of shape (B, C, H, W), before ``proj_out``.
        """
        normed = self.norm(h_)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)

        b, c, h, w = query.shape
        # Flatten the spatial grid into a sequence and add a singleton head axis.
        query = rearrange(query, "b c h w -> b 1 (h w) c").contiguous()
        key = rearrange(key, "b c h w -> b 1 (h w) c").contiguous()
        value = rearrange(value, "b c h w -> b 1 (h w) c").contiguous()
        attended = nn.functional.scaled_dot_product_attention(query, key, value)

        return rearrange(attended, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        """Apply self-attention with a residual connection.

        Args:
            x (Tensor): Input feature map (B, C, H, W).

        Returns:
            Tensor: ``x`` plus the projected attention output, shape (B, C, H, W).
        """
        attn_out = self.attention(x)
        return x + self.proj_out(attn_out)
|
|
|
|
class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a feature map.

    Keys are softmax-normalized across positions and contracted with the
    values into a per-head context matrix, which is then applied to the
    queries. No position-by-position attention map is formed.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        """
        Args:
            dim (int): Number of input/output channels.
            heads (int, optional): Number of attention heads. Defaults to 4.
            dim_head (int, optional): Channels per head. Defaults to 32.
        """
        super().__init__()
        self.heads = heads
        inner_dim = dim_head * heads
        # A single bias-free 1x1 convolution produces q, k and v stacked on the channel axis.
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        """Apply linear attention.

        Args:
            x (Tensor): Input feature map (B, C, H, W).

        Returns:
            Tensor: Output feature map (B, C, H, W).
        """
        _, _, height, width = x.shape
        packed = self.to_qkv(x)
        q, k, v = rearrange(packed, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        # Normalize the keys over the flattened spatial axis.
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        attended = rearrange(attended, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=height, w=width)
        return self.to_out(attended)
|
|
|
|
class LinAttnBlock(LinearAttention):
    """Single-head ``LinearAttention`` whose constructor takes only ``in_channels``."""

    def __init__(self, in_channels):
        """
        Args:
            in_channels (int): Number of input/output channels; also used as
                the width of the single attention head.
        """
        single_head = dict(dim=in_channels, heads=1, dim_head=in_channels)
        super().__init__(**single_head)
|
|
|
|
def make_attn(in_channels, attn_type="vanilla"):
    """Factory function to create an attention block.

    Args:
        in_channels (int): Number of input/output channels.
        attn_type (str, optional): Type of attention block to create.
            Options: "vanilla", "linear", "none". Defaults to "vanilla".

    Returns:
        nn.Module: ``AttnBlock`` for "vanilla", ``nn.Identity`` for "none",
        ``LinAttnBlock`` for "linear".

    Raises:
        AssertionError: If ``attn_type`` is not one of the supported options.
    """
    # Raised explicitly rather than through an `assert` statement so the check
    # is not stripped under `python -O`; the exception type is kept so
    # existing callers see the same error.
    if attn_type not in ("vanilla", "linear", "none"):
        raise AssertionError(f'attn_type {attn_type} unknown')
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    if attn_type == "none":
        # nn.Identity ignores its constructor arguments.
        return nn.Identity(in_channels)
    return LinAttnBlock(in_channels)
|
|
|
|
| |
|
|