|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch |
|
from torch import Tensor |
|
from typing import Optional |
|
|
|
class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Queries are projected from ``x``; keys and values are projected from
    ``context`` when it is given, otherwise from ``x``. When the installed
    PyTorch provides ``F.scaled_dot_product_attention`` the fused kernel is
    used; otherwise an explicit matmul/softmax fallback with the same mask,
    causal and dropout semantics runs.

    Args:
        dim: Model width of inputs and outputs; must be divisible by
            ``num_heads``.
        num_heads: Number of heads; each head has ``dim // num_heads`` channels.
        dropout_prob: Dropout probability on the attention weights, applied
            only while the module is in training mode.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        dropout_prob: float = 0,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )

        # Feature-detect the fused kernel. Parsing torch.__version__[0] reads
        # only the first character and would misclassify e.g. "10.0".
        self.use_sdp = hasattr(F, "scaled_dot_product_attention")

        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

        self.dropout_prob = dropout_prob
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1/sqrt(head_dim): the same default scale the fused kernel applies.
        self.scale = self.head_dim**-0.5

    def forward(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor] = None,
        context: Optional[Tensor] = None,
        is_causal: bool = False,
    ) -> Tensor:
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: Query tokens of shape (batch, q_len, dim).
            attn_mask: Optional mask broadcastable to
                (batch, heads, q_len, kv_len). A bool mask marks positions that
                may be attended (True = keep); a float mask is added to the
                attention scores.
            context: Optional key/value tokens of shape (batch, kv_len, dim).
            is_causal: If True, query position i only attends to key
                positions 0..i.

        Returns:
            Tensor of shape (batch, q_len, dim).
        """
        kv_source = x if context is None else context
        query = self.reshape(self.query(x))
        key = self.reshape(self.key(kv_source))
        value = self.reshape(self.value(kv_source))

        if self.use_sdp:
            x = F.scaled_dot_product_attention(
                query,
                key,
                value,
                attn_mask=attn_mask,
                dropout_p=self.dropout_prob if self.training else 0,
                is_causal=is_causal,
            )
        else:
            x = self._attention_fallback(query, key, value, attn_mask, is_causal)

        # (batch, heads, q_len, head_dim) -> (batch, q_len, heads * head_dim)
        return self.out(x.transpose(2, 1).flatten(2))

    def _attention_fallback(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attn_mask: Optional[Tensor],
        is_causal: bool,
    ) -> Tensor:
        """Explicit attention matching ``F.scaled_dot_product_attention``."""
        scores = (query @ key.transpose(-2, -1)) * self.scale

        if is_causal:
            # Lower-triangular (q_len, kv_len) keep-mask: query i sees keys 0..i.
            q_len, kv_len = scores.shape[-2:]
            keep = torch.ones(
                q_len, kv_len, dtype=torch.bool, device=scores.device
            ).tril()
            scores = scores.masked_fill(~keep, float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean masks select allowed positions; adding them would
                # just add 1.0 to the kept scores.
                scores = scores.masked_fill(~attn_mask, float("-inf"))
            else:
                # Float masks are additive biases on the raw scores.
                scores = scores + attn_mask

        weights = F.dropout(
            scores.softmax(dim=-1), p=self.dropout_prob, training=self.training
        )
        return weights @ value

    def reshape(self, x: Tensor) -> Tensor:
        """Split (batch, seq, dim) into (batch, heads, seq, head_dim)."""
        batch_size, seq_len, _ = x.shape
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return x.transpose(2, 1)
|
|
|
|
|
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Args:
        dim: Width of both the input and the output features.
        dim_expand_factor: Multiplier giving the hidden width
            (``dim * dim_expand_factor``).
    """

    def __init__(self, dim: int, dim_expand_factor: int = 4):
        super().__init__()
        hidden_dim = dim * dim_expand_factor
        # Keep this registration order: it fixes parameter names/order in the
        # state_dict and the order in which the layers draw initial weights.
        self.hidden_layer = nn.Linear(dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Expand, apply GELU, and project back to ``dim`` features."""
        return self.output_layer(F.gelu(self.hidden_layer(x)))
|
|
|
|
|
class LayerScale(nn.Module):
    """Multiply the last dimension of the input by a learnable scale vector.

    Args:
        dim: Number of scale factors (size of the input's last dimension).
        init_values: Starting value of every scale factor.
        inplace: If True, scale the input tensor in place via ``Tensor.mul_``
            and return it; otherwise return a new tensor.
    """

    def __init__(
        self,
        dim: int,
        init_values: float = 1e-5,
        inplace: bool = False,
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim) * init_values)
        self.inplace = inplace

    def forward(self, x: Tensor) -> Tensor:
        if not self.inplace:
            return x * self.weight
        return x.mul_(self.weight)
|
|
|
|
|
class VisionEncoderBlock(nn.Module):
    """Pre-norm transformer block with LayerScale-d residual branches.

    Each of the two branches (attention, then MLP) normalizes the residual
    stream, applies its sub-layer, rescales the result with a LayerScale and
    adds it back to the stream.

    Args:
        dim: Token width.
        num_heads: Number of attention heads.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        ln_eps = 1e-6
        # Registration order (norm1, attn, ls1, norm2, mlp, ls2) determines
        # state_dict key order and the order parameters are initialized in.
        self.norm1 = nn.LayerNorm(dim, eps=ln_eps)
        self.attn = Attention(dim, num_heads)
        self.ls1 = LayerScale(dim)

        self.norm2 = nn.LayerNorm(dim, eps=ln_eps)
        self.mlp = MLP(dim)
        self.ls2 = LayerScale(dim)

    def forward(self, x: Tensor) -> Tensor:
        attn_update = self.ls1(self.attn(self.norm1(x)))
        x = x + attn_update
        mlp_update = self.ls2(self.mlp(self.norm2(x)))
        return x + mlp_update
|
|
|
|
|
class VisionEncoder(nn.Module):
    """Patch-based transformer image encoder with a learnable class token.

    Images are cut into non-overlapping ``patch_size`` patches by a strided
    convolution, a learned positional embedding is added to the patch tokens
    (bicubically resampled when the patch grid differs from the native one),
    a class token is prepended, and the sequence runs through ``num_layers``
    encoder blocks and a final LayerNorm.

    Args:
        dim: Token width.
        patch_size: Side length, in pixels, of each square patch.
        num_layers: Number of encoder blocks.
        num_heads: Attention heads per block.
        img_size: Side length of the square image the positional embedding is
            sized for (native grid is ``img_size // patch_size`` per side).
            Defaults to 224, the previously hard-coded value.
    """

    def __init__(
        self,
        dim: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        img_size: int = 224,
    ):
        super().__init__()

        self.n_patch = img_size // patch_size
        self.seq_len = self.n_patch**2
        self.patch_size = patch_size

        # Creation order below is kept so parameter init consumes the RNG in
        # the same order as before.
        self.patch_embed = nn.Conv2d(3, dim, patch_size, patch_size)
        # One embedding per patch token; the class token gets none.
        self.pos_embed = nn.Parameter(torch.randn(1, self.seq_len, dim) * 0.02)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # Added to the target grid size before computing the interpolation
        # scale factors, so that the floored output size lands on the intended
        # integer grid despite floating-point rounding.
        self.interpolate_offset = 0.1
        self.interpolate_antialias = False

        self.blocks = nn.Sequential(
            *[VisionEncoderBlock(dim, num_heads) for _ in range(num_layers)]
        )

        self.norm = nn.LayerNorm(dim, eps=1e-6)

    def interpolate_pos_encoding(self, x, h, w):
        """Return a positional embedding matching the patch grid of an h x w image.

        Args:
            x: Patch tokens of shape (batch, num_patches, dim), before the
                class token is prepended; only its dtype is used.
            h: Image height in pixels.
            w: Image width in pixels.

        Returns:
            ``self.pos_embed`` itself when the image yields the native
            ``n_patch`` x ``n_patch`` grid; otherwise the embedding
            bicubically resampled to ``(h // patch_size) x (w // patch_size)``
            positions, shape (1, num_patches, dim), cast to ``x.dtype``.
        """
        previous_dtype = x.dtype
        h0 = h // self.patch_size
        w0 = w // self.patch_size

        # Decide on the patch grid, not the pixel size: e.g. 224x230 with
        # 14 px patches is still a 16x16 grid and must reuse the embedding
        # unchanged rather than pass it through a non-identity resample.
        if h0 == self.n_patch and w0 == self.n_patch:
            return self.pos_embed

        pos_embed = self.pos_embed.float()
        dim = pos_embed.shape[-1]
        sy = (h0 + self.interpolate_offset) / self.n_patch
        sx = (w0 + self.interpolate_offset) / self.n_patch

        # (1, seq_len, dim) -> (1, dim, n_patch, n_patch) for spatial resizing.
        pos_embed = F.interpolate(
            pos_embed.reshape(1, self.n_patch, self.n_patch, dim).permute(0, 3, 1, 2),
            scale_factor=(sy, sx),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

        # (1, dim, h0, w0) -> (1, h0 * w0, dim), row-major like the patch tokens.
        return pos_embed.to(previous_dtype).flatten(start_dim=2).transpose(2, 1)

    def forward(self, x: Tensor) -> Tensor:
        """Encode images (batch, 3, H, W) into tokens (batch, 1 + num_patches, dim).

        Token 0 of the output is the class token.
        """
        h, w = x.shape[2:]
        x = self.patch_embed(x).flatten(start_dim=2).transpose(2, 1)
        x = x + self.interpolate_pos_encoding(x, h, w)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.blocks(x)
        return self.norm(x)