from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class BLCModuleCompatibleBCHW(nn.Module):
    """Adapter that lets a sequence-style (B, L, C) module also accept
    (B, C, H, W) feature maps by flattening the spatial dims around the call.
    """

    def forward_blc(self, x):
        raise NotImplementedError()

    def forward(self, x):
        is2d = x.ndim == 4
        if is2d:
            _, _, H, W = x.shape
            x = rearrange(x, "B C H W -> B (H W) C")

        x = self.forward_blc(x)

        if is2d:
            x = rearrange(x, "B (H W) C -> B C H W", H=H, W=W)

        return x


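# Illustrative subclass (ours, not part of the original module): a token-wise
# linear projection written once for (B, L, C) inputs that, via the adapter
# above, transparently accepts (B, C, H, W) feature maps as well.
class _LinearBLC(BLCModuleCompatibleBCHW):
    def __init__(self, C_in, C_out):
        super().__init__()
        self.proj = nn.Linear(C_in, C_out)

    def forward_blc(self, x):
        return self.proj(x)

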
class FeatureEncoder(nn.Module):
    """Encoder + feature extractor.

    Subclasses implement `_encode` and register `self.hook` as a forward hook
    on the submodules whose outputs should be collected as features.
    """

    def __init__(self, safe=True):
        super().__init__()
        self.safe = safe
        self._features = []

    def hook(self, module, input, output):
        self._features.append(output)

    def clear_features(self):
        self._features.clear()

    def _encode(self, x):
        raise NotImplementedError()

    def forward(self, *args, ret_feats=False, **kwargs):
        self.clear_features()

        x = self._encode(*args, **kwargs)

        if ret_feats:
            if self.safe:
                # Clone so the returned features are not invalidated by the
                # next call's clear_features().
                features = [t.clone() for t in self._features]
                self.clear_features()
            else:
                features = self._features
            return x, features
        else:
            self.clear_features()
            return x


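# Illustrative subclass (ours, not from the original code): hooks are attached
# to chosen submodules so their outputs land in `_features` during `_encode`.
class _TinyEncoder(FeatureEncoder):
    def __init__(self):
        super().__init__(safe=True)
        self.layer1 = nn.Conv2d(3, 16, 3, padding=1)
        self.layer2 = nn.Conv2d(16, 32, 3, padding=1)
        for m in (self.layer1, self.layer2):
            m.register_forward_hook(self.hook)

    def _encode(self, x):
        return self.layer2(self.layer1(x))

# e.g. out, feats = _TinyEncoder()(torch.randn(1, 3, 8, 8), ret_feats=True)
# returns the final output together with len(feats) == 2 intermediate maps.

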
class Project2d(nn.Module):
    """2d projection by 1x1 conv.

    Args:
        p: projection matrix of shape [C_in, C_out]
    """

    def __init__(self, p):
        super().__init__()
        # Reshape the matrix into a 1x1 conv kernel: [C_out, C_in, 1, 1].
        p = rearrange(p, "Cin Cout -> Cout Cin 1 1")
        self.p = nn.Parameter(p.detach().clone())

    def forward(self, x):
        return F.conv2d(x, self.p)


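def _example_project2d():
    """Illustrative usage (ours, not part of the original module)."""
    p = torch.randn(64, 16)          # [C_in, C_out]
    proj = Project2d(p)              # trainable 1x1 kernel, initialized to p
    x = torch.randn(2, 64, 8, 8)
    return proj(x)                   # -> (2, 16, 8, 8)

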
def dispatcher(dispatch_fn):
    def decorated(key, *args):
        # Already a callable (e.g. a module class): return it unchanged.
        if callable(key):
            return key

        if key is None:
            key = "none"

        return dispatch_fn(key, *args)

    return decorated


@dispatcher
def activ_dispatch(activ):
    return {
        "none": nn.Identity,
        "relu": nn.ReLU,
        "lrelu": partial(nn.LeakyReLU, negative_slope=0.2),
        "gelu": nn.GELU,
    }[activ.lower()]


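# A minimal sketch (ours) of the three dispatch paths: string keys are looked
# up case-insensitively, None falls back to "none", and callables pass through.
def _example_activ_dispatch():
    assert activ_dispatch("ReLU") is nn.ReLU
    assert activ_dispatch(None) is nn.Identity
    assert activ_dispatch(nn.SiLU) is nn.SiLU

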
def get_norm_fn(norm, C):
    """2d normalization layers."""
    if norm is None or norm == "none":
        return nn.Identity()

    # Look up a constructor first so only the requested layer is instantiated.
    return {
        "bn": nn.BatchNorm2d,
        "syncbn": nn.SyncBatchNorm,
        "ln": LayerNorm2d,
        "gn": partial(nn.GroupNorm, 32),  # C must be divisible by 32
    }[norm](C)


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim of (B, C, H, W) tensors."""

    def __init__(self, num_channels, eps=1e-5, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x):
        # Move channels last, normalize over C, then restore BCHW layout.
        return F.layer_norm(
            x.permute(0, 2, 3, 1),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        ).permute(0, 3, 1, 2)


class Gate(nn.Module):
    """Tanh gate: scales the input by a learnable scalar tanh(gate)."""

    def __init__(self, init=0.0):
        super().__init__()
        self.gate = nn.Parameter(torch.as_tensor(init))

    def forward(self, x):
        return torch.tanh(self.gate) * x


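# With the default init=0.0 a fresh Gate outputs zeros (tanh(0) * x == 0), so
# a gated branch contributes nothing at initialization; a gated residual block
# therefore starts as the identity. This ReZero-style reading is ours, not
# stated in the original code.
def _example_gate_zero_init():
    g = Gate()
    assert torch.allclose(g(torch.randn(4)), torch.zeros(4))

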
class ConvBlock(nn.Module):
    """Pre-activation conv block:
    norm -> activ -> [upsample] -> [dropout] -> conv -> [downsample] -> [gate].
    """

    def __init__(
        self,
        C_in,
        C_out,
        kernel_size=3,
        stride=1,
        padding=1,
        norm="none",
        activ="relu",
        bias=True,
        upsample=False,
        downsample=False,
        pad_type="zeros",
        dropout=0.0,
        gate=False,
    ):
        super().__init__()
        if kernel_size == 1:
            assert padding == 0
        self.C_in = C_in
        self.C_out = C_out

        activ = activ_dispatch(activ)
        self.upsample = upsample
        self.downsample = downsample

        self.norm = get_norm_fn(norm, C_in)
        self.activ = activ()
        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0.0 else nn.Identity()
        self.conv = nn.Conv2d(
            C_in, C_out, kernel_size, stride, padding,
            bias=bias, padding_mode=pad_type,
        )

        self.gate = Gate() if gate else None

    def forward(self, x):
        x = self.norm(x)
        x = self.activ(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2)
        x = self.dropout(x)
        x = self.conv(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)

        if self.gate is not None:
            x = self.gate(x)

        return x


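def _example_conv_block():
    """Illustrative usage (ours): the default 3x3/stride-1/pad-1 configuration
    keeps spatial size, and downsample=True halves it with average pooling."""
    down = ConvBlock(64, 128, norm="bn", activ="lrelu", downsample=True)
    x = torch.randn(2, 64, 32, 32)
    return down(x)  # -> (2, 128, 16, 16)

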
class ResConv(nn.Module):
    """Pre-activation residual block with a single or double conv block."""

    def __init__(
        self,
        C_in,
        C_out,
        kernel_size=3,
        stride=1,
        padding=1,
        norm="none",
        activ="relu",
        upsample=False,
        pad_type="zeros",
        dropout=0.0,
        gate=True,
        double=False,
        # second conv block, used only when double=True
        norm2=None,
        activ2=None,
    ):
        super().__init__()

        self.C_in = C_in
        self.C_out = C_out
        self.upsample = upsample
        self.double = double
        self.conv = ConvBlock(
            C_in, C_out, kernel_size, stride, padding, norm, activ,
            pad_type=pad_type, dropout=dropout, gate=gate,
        )
        if double:
            norm2 = norm2 or norm
            activ2 = activ2 or activ
            self.conv2 = ConvBlock(
                C_out, C_out, kernel_size, stride, padding, norm2, activ2,
                pad_type=pad_type, dropout=dropout, gate=gate,
            )

    def forward(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2)
        # Identity skip: requires C_in == C_out and a shape-preserving conv.
        x = x + self.conv(x)

        if self.double:
            x = x + self.conv2(x)

        return x


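def _example_res_conv():
    """Illustrative usage (ours): with the default zero-init gate the block is
    an exact identity at initialization, so deep stacks start out stable."""
    block = ResConv(64, 64, norm="gn", activ="gelu", double=True)
    x = torch.randn(2, 64, 16, 16)
    return block(x)  # -> (2, 64, 16, 16); equals x before any training step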