import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, List, NamedTuple, Optional

import torch
import torch.nn as nn

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


model_urls = {
    "vit_b_16": "https://download.pytorch.org/models/vit_b_16-c867db91.pth",
    "vit_b_32": "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
    "vit_l_16": "https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
    "vit_l_32": "https://download.pytorch.org/models/vit_l_32-c7638314.pth",
}

class MLPBlock(nn.Sequential):
    """Transformer MLP block."""

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, mlp_dim)
        self.act = nn.GELU()
        self.dropout_1 = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(mlp_dim, in_dim)
        self.dropout_2 = nn.Dropout(dropout)

        nn.init.xavier_uniform_(self.linear_1.weight)
        nn.init.xavier_uniform_(self.linear_2.weight)
        nn.init.normal_(self.linear_1.bias, std=1e-6)
        nn.init.normal_(self.linear_2.bias, std=1e-6)

class EncoderBlock(nn.Module):
    """Transformer encoder block."""

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        # Attention is constructed with batch_first=True, so inputs are batch-first.
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Learned positional embedding, initialized with std=0.02 as in BERT/ViT.
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))
        self.dropout = nn.Dropout(dropout)
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        for i in range(num_layers):
            layers[f"encoder_layer_{i}"] = EncoderBlock(
                num_heads,
                hidden_dim,
                mlp_dim,
                dropout,
                attention_dropout,
                norm_layer,
            )
        self.layers = nn.Sequential(layers)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        input = input + self.pos_embedding
        return self.ln(self.layers(self.dropout(input)))

class FeatureTransformer(nn.Module):
    """
    Feature Transformer: a ViT-style transformer encoder applied to a sequence
    of feature vectors, with a learnable class token and a classification head.
    """
    def __init__(
        self,
        seq_length: int = 16,
        num_layers: int = 2,
        num_heads: int = 4,
        hidden_dim: int = 768,
        mlp_dim: int = 768,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer

        # Add a learnable class token; the sequence grows by one position.
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        # Classification head, optionally preceded by a pre-logits representation layer.
        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)

        self.heads = nn.Sequential(heads_layers)

        if hasattr(self.heads, "pre_logits") and isinstance(self.heads.pre_logits, nn.Linear):
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        if isinstance(self.heads.head, nn.Linear):
            nn.init.zeros_(self.heads.head.weight)
            nn.init.zeros_(self.heads.head.bias)

    def forward(self, x: torch.Tensor):
        # Expand the class token to the full batch and prepend it to the sequence.
        batch_class_token = self.class_token.expand(x.shape[0], -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Use the encoder output at the class-token position for classification.
        x = x[:, 0]
        x = self.heads(x)

        return x
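

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes the
# default constructor arguments above: the model takes a batch of feature
# sequences shaped (batch_size, seq_length, hidden_dim) = (N, 16, 768),
# prepends the class token internally, and returns one logit per sample
# because num_classes defaults to 1.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = FeatureTransformer()        # seq_length=16, hidden_dim=768 by default
    features = torch.randn(2, 16, 768)  # (batch_size, seq_length, hidden_dim)
    logits = model(features)
    print(logits.shape)                 # torch.Size([2, 1])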