import collections.abc
from functools import partial
from itertools import repeat
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def _ntuple(n):
    # Broadcast a scalar into an n-tuple; non-string iterables pass through as a tuple.
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)
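
# Illustrative behavior of the helpers above (values are made up):
#   exists(None)      -> False
#   default(None, 8)  -> 8
#   to_2tuple(3)      -> (3, 3)
#   to_2tuple((3, 5)) -> (3, 5)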


class ResidualBlock(nn.Module):
    """Two-conv residual block with a configurable normalization layer."""

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        super().__init__()

        # "same" padding for odd kernel sizes, so the residual shapes line up
        padding = kernel_size // 2

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            padding=padding,
            padding_mode="zeros",
        )
        self.relu = nn.ReLU(inplace=True)

        # GroupNorm needs at least one group, even when planes < 8
        num_groups = max(planes // 8, 1)

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if stride != 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if stride != 1:
                self.norm3 = nn.Sequential()
        else:
            raise NotImplementedError(f"Unknown norm_fn: {norm_fn}")

        if stride == 1:
            self.downsample = None
        else:
            # 1x1 strided conv projects the identity branch to the new shape
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
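

# A minimal shape sketch (sizes are made up for illustration):
#   block = ResidualBlock(in_planes=64, planes=128, norm_fn="group", stride=2)
#   block(torch.randn(2, 64, 32, 32)).shape  ->  torch.Size([2, 128, 16, 16])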


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # bias and drop may be given once for both layers or as a per-layer pair
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        # 1x1 convolutions act as per-position linear layers on NCHW tensors
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        # Optional normalization between the two projections (identity by default)
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
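

# Illustrative usage (made-up sizes):
#   mlp = Mlp(in_features=256, hidden_features=1024)
#   mlp(torch.randn(2, 196, 256)).shape  ->  torch.Size([2, 196, 256])
# With use_conv=True the same module operates on NCHW feature maps via 1x1 convs.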


class AttnBlock(nn.Module):
    """Self-attention block: pre-norm attention and MLP, each with a residual connection."""

    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.attn = attn_class(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, mask=None):
        # Normalize before attention
        x = self.norm1(x)

        # nn.MultiheadAttention returns (attn_output, attn_output_weights);
        # the optional mask is forwarded in PyTorch's attn_mask format
        attn_output, _ = self.attn(x, x, x, attn_mask=mask)

        # Residual connections around attention and the MLP
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
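

# A minimal sketch (illustrative sizes):
#   blk = AttnBlock(hidden_size=256, num_heads=8)
#   blk(torch.randn(2, 196, 256)).shape  ->  torch.Size([2, 196, 256])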


class CrossAttnBlock(nn.Module):
    """Cross-attention block: queries attend to a separate context sequence."""

    def __init__(
        self,
        hidden_size,
        context_dim,
        num_heads=1,
        mlp_ratio=4.0,
        eps=1e-5,
        **block_kwargs
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(hidden_size, eps=eps)
        self.norm_context = nn.LayerNorm(context_dim, eps=eps)
        self.norm2 = nn.LayerNorm(hidden_size, eps=eps)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            kdim=context_dim,
            vdim=context_dim,
            num_heads=num_heads,
            batch_first=True,
            **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, context, mask=None):
        # Normalize queries and context independently
        x = self.norm1(x)
        context = self.norm_context(context)

        # nn.MultiheadAttention returns (attn_output, attn_output_weights)
        attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)

        # Residual connections around attention and the MLP
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
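

# A minimal smoke test with made-up sizes; not part of the original module API.
if __name__ == "__main__":
    x = torch.randn(2, 196, 256)   # (batch, query tokens, hidden_size)
    ctx = torch.randn(2, 50, 128)  # (batch, context tokens, context_dim)

    self_blk = AttnBlock(hidden_size=256, num_heads=8)
    cross_blk = CrossAttnBlock(hidden_size=256, context_dim=128, num_heads=8)

    print(self_blk(x).shape)        # torch.Size([2, 196, 256])
    print(cross_blk(x, ctx).shape)  # torch.Size([2, 196, 256])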