# Source: DaS / segmenter_model/blocks.py
# Author: vobecant ("Initial commit", dd78229)
"""
Adapted from code by Ross Wightman (2020):
https://github.com/rwightman/pytorch-image-models
"""
import torch
import torch.nn as nn
from einops import rearrange
from pathlib import Path
import torch.nn.functional as F
from timm.models.layers import DropPath
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Size of the last (feature) axis of the input.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability, applied after the activation and
            again after the output projection.
        out_dim: Output feature size; falls back to ``dim`` when None.
    """

    def __init__(self, dim, hidden_dim, dropout, out_dim=None):
        super().__init__()
        target_dim = dim if out_dim is None else out_dim
        # Submodule names/order are kept stable so checkpoints and
        # seeded initialisation line up.
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, target_dim)
        self.drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        """Return this module itself (no wrapper to peel off)."""
        return self

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: Feature size of each token (the last axis of the input).
        heads: Number of attention heads; each head gets ``dim // heads``
            channels.
        dropout: Dropout probability for both the attention weights and
            the output projection.
    """

    def __init__(self, dim, heads, dropout):
        super().__init__()
        self.heads = heads
        # Standard 1/sqrt(d_head) scaling of the dot-product logits.
        self.scale = (dim // heads) ** -0.5
        # Never assigned elsewhere in this class; kept for attribute compatibility.
        self.attn = None
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        """Return this module itself (no wrapper to peel off)."""
        return self

    def forward(self, x, mask=None):
        """Attend over the tokens of ``x``.

        Args:
            x: Tensor unpacked as ``(batch, tokens, channels)``.
            mask: Accepted for interface compatibility but NOT applied.

        Returns:
            ``(output, attn)``: ``output`` has the same shape as ``x``, and
            ``attn`` holds the post-dropout attention weights with shape
            ``(batch, heads, tokens, tokens)``.
        """
        batch, tokens, channels = x.shape
        per_head = channels // self.heads
        # (B, N, 3C) -> (3, B, heads, N, per_head), then split into q/k/v.
        packed = self.qkv(x).reshape(batch, tokens, 3, self.heads, per_head)
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)

        logits = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(logits.softmax(dim=-1))

        # Merge heads back into the channel axis before projecting.
        merged = (weights @ value).transpose(1, 2).reshape(batch, tokens, channels)
        output = self.proj_drop(self.proj(merged))
        return output, weights
class AttentionQK(nn.Module):
    """Compute softmax-normalised query/key attention maps only (no values).

    Args:
        dim: Feature size of each token (the last axis of the input).
        heads: Number of attention heads; each head gets ``dim // heads``
            channels.
        dropout: Used to build ``attn_drop``. NOTE(review): ``forward`` never
            applies ``attn_drop``, so this currently has no effect. Confirm
            whether that is intentional.
    """

    def __init__(self, dim, heads=1, dropout=0.):
        super().__init__()
        self.heads = heads
        # Standard 1/sqrt(d_head) scaling of the dot-product logits.
        self.scale = (dim // heads) ** -0.5
        # Never assigned elsewhere in this class; kept for attribute compatibility.
        self.attn = None
        self.qk = nn.Linear(dim, dim * 2)
        self.attn_drop = nn.Dropout(dropout)

    @property
    def unwrapped(self):
        """Return this module itself (no wrapper to peel off)."""
        return self

    def forward(self, x):
        """Return attention weights of shape ``(batch, heads, tokens, tokens)``.

        ``x`` is unpacked as ``(batch, tokens, channels)``. Each row of the
        result is a softmax over the key axis.
        """
        batch, tokens, channels = x.shape
        # (B, N, 2C) -> (2, B, heads, N, C // heads), then split into q/k.
        packed = self.qk(x).reshape(batch, tokens, 2, self.heads, channels // self.heads)
        query, key = packed.permute(2, 0, 3, 1, 4).unbind(0)
        logits = (query @ key.transpose(-2, -1)) * self.scale
        return logits.softmax(dim=-1)
class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm -> Attention, LayerNorm -> MLP.

    Each sub-layer's output goes through ``drop_path`` and is added back to
    its input as a residual.

    Args:
        dim: Token feature size.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability passed to Attention and FeedForward.
        drop_path: Rate for timm's ``DropPath``; a value of 0 (or less)
            uses ``nn.Identity`` instead.
    """

    def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x, mask=None, return_attention=False):
        """Run the block on ``x``.

        Returns the updated tokens. When ``return_attention`` is True, it
        instead returns only the attention weights from the first sub-layer,
        before any residual update.
        """
        attended, attn_map = self.attn(self.norm1(x), mask)
        if return_attention:
            return attn_map

        x = x + self.drop_path(attended)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)