# alps/unitable/src/model/components.py

from typing import Optional, Tuple
import torch
from torch import nn, Tensor
from torchvision.ops.misc import Conv2dNormActivation

__all__ = [
"ImgCnnBackbone",
"ImgLinearBackbone",
"ImgConvStemBackbone",
"PositionEmbedding",
"Encoder",
"Decoder",
"TokenEmbedding",
]


class ImgCnnBackbone(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        output_channels: int,
        d_model: int,
        drop_layer: Optional[Tuple[int, ...]] = None,
    ) -> None:
        super().__init__()
        # Drop the classification head (and optionally maxpooling) to keep a
        # higher-resolution feature map.
        layers = list(backbone.children())
        nlayer = len(layers)
        # sorted() preserves the original layer order; `or ()` guards the None default.
        keep_layer = sorted(set(range(nlayer)) - set(drop_layer or ()))
        backbone = [layers[i] for i in keep_layer]
        self.backbone = nn.Sequential(*backbone)
        self.proj = nn.Linear(output_channels, d_model)
        self.channels = output_channels

    def forward(self, x: Tensor) -> Tensor:
        x = self.backbone(x)
        # (B, C, H, W) -> (B, H * W, C): flatten the spatial grid into a token sequence.
        x = x.flatten(start_dim=-2).transpose(1, 2)
        assert x.shape[-1] == self.channels, "Image channels size mismatch."
        x = self.proj(x)
        return x


class ImgLinearBackbone(nn.Module):
    """ViT-style patch embedding: a strided convolution that linearly projects
    non-overlapping patches to d_model."""

    def __init__(
        self,
        d_model: int,
        patch_size: int,
        in_chan: int = 3,
    ) -> None:
        super().__init__()
        self.conv_proj = nn.Conv2d(
            in_chan, out_channels=d_model, kernel_size=patch_size, stride=patch_size
        )
        self.d_model = d_model

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv_proj(x)
        # (B, d_model, H/P, W/P) -> (B, num_patches, d_model)
        x = x.flatten(start_dim=-2).transpose(1, 2)
        return x
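
# Illustrative example (assumed values): with d_model=512 and patch_size=16,
# a (B, 3, 448, 448) image maps to (B, 28 * 28, 512), i.e. one token per
# non-overlapping 16 x 16 patch.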


class ImgConvStemBackbone(nn.Module):
    """Convolutional stem: a stack of stride-2 convolutions that doubles the
    channel count at each step up to output_channels, followed by a 1x1
    projection to d_model."""

    def __init__(
        self,
        d_model: int,
        downsample_factor: int,
        output_channels: int,
        kernel_size: int,
    ) -> None:
        super().__init__()

        # downsample_factor must be a power of two; otherwise the channel
        # doubling below would never hit output_channels exactly.
        assert downsample_factor >= 2 and downsample_factor & (downsample_factor - 1) == 0
        assert output_channels % (downsample_factor // 2) == 0
        input_channels = output_channels // (downsample_factor // 2)

        layers = [
            Conv2dNormActivation(
                3, input_channels, kernel_size=kernel_size, stride=2, padding=1
            )
        ]

        # Each extra stride-2 conv doubles the channels; the loop stops once
        # output_channels is reached, which also fixes the total downsampling.
        while input_channels != output_channels:
            layers.append(
                Conv2dNormActivation(
                    input_channels,
                    input_channels * 2,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=1,
                )
            )
            input_channels = input_channels * 2

        layers.append(nn.Conv2d(output_channels, d_model, kernel_size=1))
        self.conv_stem = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv_stem(x)
        # (B, d_model, H', W') -> (B, H' * W', d_model)
        x = x.flatten(start_dim=-2).transpose(1, 2)
        return x
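
# Illustrative example (assumed values): downsample_factor=16 with
# output_channels=64 yields four stride-2 convs (3 -> 8 -> 16 -> 32 -> 64
# channels) before the final 1x1 projection to d_model.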


class Encoder(nn.Module):
    """Thin wrapper around nn.TransformerEncoder (batch-first)."""

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dropout: float,
        activation: str,
        norm_first: bool,
        nlayer: int,
        ff_ratio: int = 4,
    ) -> None:
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model,
            nhead=nhead,
            dim_feedforward=ff_ratio * d_model,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=norm_first,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=nlayer)

    def forward(self, x: Tensor) -> Tensor:
        x = self.encoder(x)
        return x


class Decoder(nn.Module):
    """Thin wrapper around nn.TransformerDecoder (batch-first)."""

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dropout: float,
        activation: str,
        norm_first: bool,
        nlayer: int,
        ff_ratio: int = 4,
    ) -> None:
        super().__init__()
        decoder_layer = nn.TransformerDecoderLayer(
            d_model,
            nhead,
            dim_feedforward=ff_ratio * d_model,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=norm_first,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, nlayer)

    def forward(
        self, x: Tensor, memory: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor
    ) -> Tensor:
        x = self.decoder(
            x, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_padding_mask
        )
        return x
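
# Note: Decoder.forward expects a causal tgt_mask. A standard way to build one
# (illustrative, not defined in this file):
#   tgt_mask = torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)
# where True marks positions a query is not allowed to attend to.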


class PositionEmbedding(nn.Module):
    """Learned absolute position embedding, added to the input tokens."""

    def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
        super().__init__()
        self.embedding = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        # x is batch-first: (B, seq_len, d_model); seq_len must not exceed max_seq_len.
        out = self.embedding(torch.arange(x.shape[1], device=x.device))
        return self.dropout(out + x)


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        padding_idx: int,
    ) -> None:
        super().__init__()
        assert vocab_size > 0
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)

    def forward(self, x: Tensor) -> Tensor:
        return self.embedding(x)


class PrintLayer(nn.Module):
    """Identity layer that logs finiteness and value range; only for debugging
    when the loss is NaN."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        print(
            "torch.isfinite(x).all(): {}, min. {:.5f}, max. {:.5f}".format(
                torch.isfinite(x).all(), x.min(), x.max()
            )
        )
        return x


if __name__ == "__main__":
    from torchvision import models

    # Smoke test: conv-stem backbone on a dummy batch.
    x = torch.rand(1, 3, 392, 392)
    model = ImgConvStemBackbone(
        d_model=512, downsample_factor=16, output_channels=64, kernel_size=5
    )
    y = model(x)
    print(model)
    print(y.shape)
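
    # Illustrative sketch (assumed values): the same input through the linear
    # patch-embedding backbone; patch_size=14 is chosen here because it divides 392.
    model = ImgLinearBackbone(d_model=512, patch_size=14)
    y = model(x)
    print(y.shape)  # (1, 784, 512), since (392 / 14) ** 2 = 784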

    # Smoke test: ResNet-34 backbone with maxpool, avgpool, and fc dropped
    # (children indices 3, 8, and 9 of torchvision's resnet34).
    model = ImgCnnBackbone(
        backbone=models.resnet34(),
        output_channels=512,
        d_model=512,
        drop_layer=(3, 8, 9),
    )
    # print(model)
    y = model(x)
    print(y.shape)
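
    # Minimal end-to-end sketch (all hyperparameters below are illustrative
    # assumptions, not a configuration taken from this repo).
    encoder = Encoder(
        d_model=512, nhead=8, dropout=0.1, activation="gelu", norm_first=True, nlayer=2
    )
    decoder = Decoder(
        d_model=512, nhead=8, dropout=0.1, activation="gelu", norm_first=True, nlayer=2
    )
    token_embed = TokenEmbedding(vocab_size=100, d_model=512, padding_idx=0)
    pos_embed = PositionEmbedding(max_seq_len=1024, d_model=512, dropout=0.1)

    memory = encoder(pos_embed(y))  # y: image tokens from the CNN backbone above
    tgt = torch.randint(1, 100, (1, 11))  # dummy target token ids (no padding)
    seq_len = tgt.shape[1]
    # Boolean causal mask: True above the diagonal blocks attention to future tokens.
    tgt_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    out = decoder(
        pos_embed(token_embed(tgt)), memory, tgt_mask=tgt_mask, tgt_padding_mask=(tgt == 0)
    )
    print(out.shape)  # (1, 11, 512)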