"""
Adapted from 2020 Ross Wightman
https://github.com/rwightman/pytorch-image-models
"""

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from timm.models.layers import DropPath
|
|
class FeedForward(nn.Module):
    """MLP block: linear, GELU, dropout, linear, dropout."""

    def __init__(self, dim, hidden_dim, dropout, out_dim=None):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        if out_dim is None:
            out_dim = dim
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        return self

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
|
|
class Attention(nn.Module):
    """Multi-head self-attention that returns both the output tokens and the attention map."""

    def __init__(self, dim, heads, dropout):
        super().__init__()
        self.heads = heads
        head_dim = dim // heads
        self.scale = head_dim ** -0.5
        self.attn = None

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        return self

    def forward(self, x, mask=None):
        # `mask` is accepted for interface compatibility but is not applied here.
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.heads, C // self.heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, attn
|
|
class AttentionQK(nn.Module):
    """Computes only the softmax-normalized query-key attention map (no value projection)."""

    def __init__(self, dim, heads=1, dropout=0.):
        super().__init__()
        self.heads = heads
        head_dim = dim // heads
        self.scale = head_dim ** -0.5
        self.attn = None

        self.qk = nn.Linear(dim, dim * 2)
        self.attn_drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        return self

    def forward(self, x):
        B, N, C = x.shape
        qk = (
            self.qk(x)
            .reshape(B, N, 2, self.heads, C // self.heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k = qk[0], qk[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        return attn
|
|
class Block(nn.Module):
    """Pre-norm transformer block: self-attention and MLP sub-blocks,
    each with a residual connection and optional stochastic depth."""

    def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, mask=None, return_attention=False):
        y, attn = self.attn(self.norm1(x), mask)
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
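

# Minimal shape-check sketch (not part of the original module): builds a Block with
# assumed, illustrative hyper-parameters and verifies that a (batch, tokens, dim)
# tensor passes through with its shape preserved, and that the attention map has
# shape (batch, heads, tokens, tokens). Run this file directly to execute it.
if __name__ == "__main__":
    dim, heads, mlp_dim = 192, 3, 768  # assumed ViT-Tiny-like sizes, chosen for illustration
    block = Block(dim, heads, mlp_dim, dropout=0.1, drop_path=0.1)
    x = torch.randn(2, 17, dim)  # e.g. 16 patch tokens + 1 class token (illustrative)
    out = block(x)
    print(out.shape)   # torch.Size([2, 17, 192])
    attn = block(x, return_attention=True)
    print(attn.shape)  # torch.Size([2, 3, 17, 17])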
|
|