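# NOTE(review): the next three lines look like web-page scrape residue
# (an uploader caption, a commit message, and what may be a short commit
# hash). They are not Python, so they are preserved here only as comments.
# mart9992's picture
# m
# aede1d5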
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: ``lin1`` -> activation -> ``lin2``.

    The input's last dimension is projected from ``embedding_dim`` up to
    ``mlp_dim``, passed through the activation, then projected back to
    ``embedding_dim``.

    Args:
        embedding_dim: size of the input and output feature dimension.
        mlp_dim: size of the hidden feature dimension.
        act: activation *class* (not instance); it is instantiated once here.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # Creation order (lin1, lin2, act) fixes parameter init order and
        # state_dict key order, so it is kept as-is.
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """Layer normalization over the channel dimension (dim 1) only.

    For each position, the mean and biased variance are taken across dim 1.
    The input is normalized with ``eps`` added to the variance inside the
    square root. A learnable per-channel scale (``weight``) and shift
    (``bias``) are then applied. The affine parameters are broadcast with
    shape ``(C, 1, 1)``, so the input is expected to have channels at dim 1
    followed by exactly two trailing dims, presumably ``(N, C, H, W)``.

    Args:
        num_channels: number of channels ``C``; length of ``weight``/``bias``.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Start as the identity affine map: scale 1, shift 0.
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        centered = x - x.mean(dim=1, keepdim=True)
        variance = centered.pow(2).mean(dim=1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        scale = self.weight.view(-1, 1, 1)
        shift = self.bias.view(-1, 1, 1)
        return scale * normalized + shift
def val2list(x: Any, repeat_time: int = 1) -> list:
    """Return *x* as a new list.

    A list or tuple is shallow-copied into a fresh list, and ``repeat_time``
    is ignored. Any other value is wrapped and repeated ``repeat_time`` times.

    Args:
        x: a list/tuple, or any single value.
        repeat_time: how many times to repeat a non-sequence ``x``.

    Returns:
        A new list. For a non-sequence ``x`` every slot references the same
        object; nothing is copied.
    """
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x] * repeat_time
def val2tuple(x: Any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
    """Convert *x* to a tuple padded to at least ``min_len`` elements.

    *x* is first normalized with ``val2list``: a list/tuple is copied and any
    other value becomes a 1-element list. If the result is shorter than
    ``min_len``, copies of the element at ``idx_repeat`` are inserted at
    slice position ``idx_repeat``. With the default ``-1`` the copies land
    just before the last element, e.g. ``(1, 2)`` with ``min_len=3`` gives
    ``(1, 2, 2)``. An empty sequence is returned as ``()`` without padding.

    Args:
        x: a list/tuple, or any single value.
        min_len: minimum length of the result; longer inputs are left alone.
        idx_repeat: index of the element to repeat and where to insert copies.

    Raises:
        IndexError: if padding is needed and ``idx_repeat`` is out of range.
    """
    x = val2list(x)
    # An empty sequence has no element to repeat, so it is left unpadded.
    if len(x) > 0:
        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
    return tuple(x)
def list_sum(x: list) -> Any:
    """Return the ``+``-combination of all elements of *x*.

    Elements are combined right-to-left as ``x[0] + (x[1] + (... + x[-1]))``.
    This is the same evaluation order as the previous recursive version, so
    non-associative ``+`` (e.g. floating-point rounding) gives identical
    results. Any element type supporting ``+`` works: numbers, tensors,
    lists, strings.

    Args:
        x: a non-empty sequence.

    Returns:
        The combined value. For a single-element input, that element itself.

    Raises:
        IndexError: if *x* is empty, because there is no identity element to
            return.
    """
    if len(x) == 0:
        raise IndexError("list_sum() requires a non-empty sequence")
    # Iterative right fold. It avoids the recursion-depth limit and the O(n^2)
    # copying caused by repeatedly slicing x[1:].
    total = x[-1]
    for item in reversed(x[:-1]):
        total = item + total
    return total
def resize(
    x: torch.Tensor,
    size: Any = None,
    scale_factor: Any = None,
    mode: str = "bicubic",
    align_corners: Optional[bool] = False,
) -> torch.Tensor:
    """Resize *x* with ``torch.nn.functional.interpolate``.

    Args:
        x: input tensor. Its rank must suit ``mode``: 3-D for ``"linear"``,
            4-D for ``"bilinear"``/``"bicubic"``, 5-D for ``"trilinear"``.
        size: target output size, forwarded as ``size``.
        scale_factor: spatial multiplier, forwarded as ``scale_factor``.
        mode: interpolation mode. ``"linear"``, ``"bilinear"``, ``"bicubic"``
            and ``"trilinear"`` honour ``align_corners``. ``"nearest"``,
            ``"nearest-exact"`` and ``"area"`` do not accept it, so it is
            not forwarded for them.
        align_corners: forwarded only for the interpolating modes above.

    Returns:
        The resized tensor.

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    if mode in ("linear", "bilinear", "bicubic", "trilinear"):
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )
    if mode in ("nearest", "nearest-exact", "area"):
        # F.interpolate rejects any non-None align_corners for these modes.
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
    raise NotImplementedError(f"resize(mode={mode}) not implemented.")
class UpSampleLayer(nn.Module):
    """Module wrapper that resizes its input via ``resize``.

    Exactly one of a fixed output ``size`` or a scale ``factor`` is used.
    When ``size`` is given, it is normalized with ``val2list(size, 2)`` (a
    scalar becomes ``[size, size]``, a sequence is copied into a list) and
    ``factor`` is stored as ``None``. Otherwise ``size`` is ``None`` and
    ``factor`` is kept.

    Args:
        mode: interpolation mode forwarded to ``resize``.
        size: optional target size (scalar or sequence).
        factor: scale factor, used only when ``size`` is None.
        align_corners: forwarded to ``resize``.
    """

    def __init__(
        self,
        mode="bicubic",
        size=None,
        factor=2,
        align_corners=False,
    ):
        super().__init__()
        self.mode = mode
        if size is None:
            self.size = None
            self.factor = factor
        else:
            self.size = val2list(size, 2)
            # A fixed size takes precedence, so the scale factor is disabled.
            self.factor = None
        self.align_corners = align_corners

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return resize(
            x,
            size=self.size,
            scale_factor=self.factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )
class OpSequential(nn.Module):
    """Apply a sequence of modules in order, skipping ``None`` placeholders.

    ``None`` entries in ``op_list`` are dropped at construction time. The
    remaining modules are registered in an ``nn.ModuleList`` named
    ``op_list``. An empty result makes ``forward`` return its input
    unchanged.

    Args:
        op_list: iterable of modules, possibly containing ``None``.
    """

    def __init__(self, op_list):
        super().__init__()
        kept = [op for op in op_list if op is not None]
        self.op_list = nn.ModuleList(kept)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x
        for layer in self.op_list:
            out = layer(out)
        return out