# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any
from typing import Iterable
from typing import Optional
from typing import Type

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPBlock(nn.Module):
    """Two-layer perceptron: ``lin2(act(lin1(x)))``.

    Args:
        embedding_dim: Size of the input and output feature dimension.
        mlp_dim: Size of the hidden feature dimension.
        act: Activation module class; it is instantiated with no arguments.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to ``mlp_dim``, apply the activation, project back."""
        return self.lin2(self.act(self.lin1(x)))


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
class LayerNorm2d(nn.Module):
    """Layer normalization computed over dimension 1 of the input.

    The mean and (biased) variance are taken over dimension 1, and a learned
    per-channel scale and shift are applied afterwards.

    Args:
        num_channels: Length of the learned ``weight`` and ``bias`` vectors.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over dimension 1, then scale and shift.

        NOTE(review): the ``[:, None, None]`` broadcast implies channels are
        followed by exactly two trailing dimensions, e.g. ``(N, C, H, W)``.
        """
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        # Reshape (C,) -> (C, 1, 1) so the parameters broadcast per channel.
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


def val2list(x: Any, repeat_time: int = 1) -> list:
    """Return ``x`` as a list.

    Args:
        x: A list or tuple (converted to a new list), or any other value.
        repeat_time: Number of times a non-list/tuple ``x`` is repeated.
            Ignored when ``x`` is a list or tuple.

    Returns:
        ``list(x)`` for a list/tuple, otherwise ``x`` repeated
        ``repeat_time`` times.
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x for _ in range(repeat_time)]


def val2tuple(x: Any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Return ``x`` as a tuple padded to at least ``min_len`` elements.

    Args:
        x: A list or tuple, or a single value (treated as one element).
        min_len: Minimum length of the result.
        idx_repeat: Index of the element to duplicate; the copies are
            inserted at that same position.

    Returns:
        The padded tuple. An empty input stays empty because there is no
        element to repeat.
    """
    x = val2list(x)

    # repeat elements if necessary
    if len(x) > 0:
        # range() of a non-positive count is empty, so inputs that are
        # already long enough are left untouched.
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]

    return tuple(x)


def list_sum(x: list) -> Any:
    """Return the sum of the elements of ``x`` using ``+``.

    The elements are combined right-to-left, i.e.
    ``x[0] + (x[1] + (... + x[-1]))``. The fold is iterative, so long inputs
    neither hit the recursion limit nor copy the list at every step.

    Args:
        x: Non-empty sequence of values supporting ``+``.

    Returns:
        The accumulated sum; ``x[0]`` itself for a single-element input.

    Raises:
        IndexError: If ``x`` is empty.
    """
    if len(x) == 0:
        raise IndexError("list_sum() requires a non-empty sequence")
    total = x[-1]
    # Walk backwards so each element stays the left operand of ``+``.
    for i in range(len(x) - 2, -1, -1):
        total = x[i] + total
    return total


def resize(
    x: torch.Tensor,
    size: Optional[Any] = None,
    scale_factor: Optional[Any] = None,
    mode: str = "bicubic",
    align_corners: Optional[bool] = False,
) -> torch.Tensor:
    """Resize ``x`` with ``torch.nn.functional.interpolate``.

    Args:
        x: Input tensor.
        size: Target output size, forwarded to ``F.interpolate``.
        scale_factor: Scale multiplier, forwarded to ``F.interpolate``.
        mode: One of ``"bilinear"``, ``"bicubic"``, ``"nearest"``, ``"area"``.
        align_corners: Forwarded only for ``"bilinear"`` and ``"bicubic"``;
            it is not passed for ``"nearest"`` and ``"area"``.

    Returns:
        The interpolated tensor.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported modes.
    """
    if mode in ("bilinear", "bicubic"):
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )
    if mode in ("nearest", "area"):
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
    raise NotImplementedError(f"resize(mode={mode}) not implemented.")


class UpSampleLayer(nn.Module):
    """Module wrapper around :func:`resize`.

    Args:
        mode: Interpolation mode passed to ``resize``.
        size: Target size. A single value is repeated twice; a list/tuple is
            used as given. When set, ``factor`` is ignored.
        factor: Scale factor, used only when ``size`` is ``None``.
        align_corners: Passed through to ``resize``.
    """

    def __init__(
        self,
        mode: str = "bicubic",
        size: Optional[Any] = None,
        factor: Optional[Any] = 2,
        align_corners: Optional[bool] = False,
    ) -> None:
        super().__init__()
        self.mode = mode
        self.size = val2list(size, 2) if size is not None else None
        # An explicit size takes precedence, so the factor is dropped.
        self.factor = None if self.size is not None else factor
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` using the configured size or scale factor."""
        return resize(x, self.size, self.factor, self.mode, self.align_corners)


class OpSequential(nn.Module):
    """Apply a sequence of modules in order, skipping ``None`` entries.

    Args:
        op_list: Modules to chain; ``None`` items are dropped.
    """

    def __init__(self, op_list: Iterable[Optional[nn.Module]]) -> None:
        super().__init__()
        self.op_list = nn.ModuleList([op for op in op_list if op is not None])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Feed ``x`` through every stored module, in order."""
        for op in self.op_list:
            x = op(x)
        return x