|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from einops import rearrange, repeat |
|
from einops.layers.torch import Rearrange, Reduce |
|
from torch import einsum, nn |
|
|
|
from .layernorm import LayerNorm2d |
|
|
|
|
|
|
|
|
|
def exists(val):
    """Return True when *val* is not None."""
    return val is not None
|
|
|
|
|
def default(val, d):
    """Return *val* when it is set (not None), otherwise the fallback *d*."""
    if val is not None:
        return val
    return d
|
|
|
|
|
def cast_tuple(val, length=1):
    """Return *val* unchanged if it is a tuple, else a tuple of *length* copies."""
    if isinstance(val, tuple):
        return val
    return (val,) * length
|
|
|
|
|
|
|
|
|
|
|
class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: y = fn(LayerNorm(x)) + x.

    Operates on channel-last tensors (LayerNorm over the final dim).
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        normed = self.norm(x)
        # Addition is commutative, so x + branch == branch + x exactly.
        return x + self.fn(normed)
|
|
|
|
|
class Conv_PreNormResidual(nn.Module):
    """Pre-norm residual wrapper for NCHW feature maps: y = fn(LayerNorm2d(x)) + x."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm2d(dim)
        self.fn = fn

    def forward(self, x):
        normed = self.norm(x)
        return x + self.fn(normed)
|
|
|
|
|
class FeedForward(nn.Module):
    """Two-layer MLP on the channel (last) dim: dim -> dim*mult -> dim.

    GELU activation between layers, dropout after each linear.
    """

    def __init__(self, dim, mult=2, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        layers = [
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
class Conv_FeedForward(nn.Module):
    """Per-pixel channel MLP on NCHW maps via 1x1 convolutions.

    Equivalent to FeedForward but applied pointwise over spatial positions.
    """

    def __init__(self, dim, mult=2, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        layers = [
            nn.Conv2d(dim, inner_dim, 1, 1, 0),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1, 1, 0),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
|
|
|
|
class Gated_Conv_FeedForward(nn.Module):
    """Gated depthwise-conv feed-forward (Restormer-style GDFN) on NCHW maps.

    Expands channels 2x via 1x1 conv, mixes locally with a depthwise 3x3,
    then gates one half with GELU of the other before projecting back.

    Note: `dropout` is accepted for signature parity but not used here.
    """

    def __init__(self, dim, mult=1, bias=False, dropout=0.0):
        super().__init__()

        hidden = int(dim * mult)

        # Pointwise expansion to two parallel streams (gate and value).
        self.project_in = nn.Conv2d(dim, 2 * hidden, kernel_size=1, bias=bias)

        # Depthwise 3x3 keeps the two streams separate (grouped per channel).
        self.dwconv = nn.Conv2d(
            2 * hidden,
            2 * hidden,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=2 * hidden,
            bias=bias,
        )

        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        expanded = self.project_in(x)
        gate, value = self.dwconv(expanded).chunk(2, dim=1)
        gated = F.gelu(gate) * value
        return self.project_out(gated)
|
|
|
|
|
|
|
|
|
|
|
class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation channel gating for NCHW maps.

    Squeeze: global average pool to (b, c); excite: bottleneck MLP
    (c -> c*shrinkage_rate -> c) with SiLU, sigmoid-normalized; the
    resulting per-channel gate rescales the input.
    """

    def __init__(self, dim, shrinkage_rate=0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        # Sequential layout (incl. the parameter-free einops layers) is kept
        # as-is so state_dict keys remain stable.
        self.gate = nn.Sequential(
            Reduce("b c h w -> b c", "mean"),
            nn.Linear(dim, hidden_dim, bias=False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias=False),
            nn.Sigmoid(),
            Rearrange("b c -> b c 1 1"),
        )

    def forward(self, x):
        scale = self.gate(x)
        return scale * x
|
|
|
|
|
class MBConvResidual(nn.Module):
    """Residual wrapper around an MBConv body with stochastic depth.

    y = Dropsample(fn(x)) + x; the dropsample zeroes whole samples during
    training with probability *dropout*.
    """

    def __init__(self, fn, dropout=0.0):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        branch = self.dropsample(self.fn(x))
        return branch + x
|
|
|
|
|
class Dropsample(nn.Module):
    """Stochastic depth ("drop sample").

    During training, each sample in the batch is independently zeroed with
    probability *prob*; survivors are rescaled by 1 / (1 - prob) so the
    expected activation is unchanged. Identity at eval time or when prob == 0.
    """

    def __init__(self, prob=0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        # Fast path: nothing to drop.
        if self.prob == 0.0 or (not self.training):
            return x

        # One uniform draw per sample, broadcast over (c, h, w).
        # Fix: the original called torch.FloatTensor((b, 1, 1, 1), device=...),
        # which interprets the tuple as *data* (a length-4 tensor, not a
        # (b, 1, 1, 1) shape) and the legacy FloatTensor constructor does not
        # accept a `device` keyword. torch.empty takes an explicit shape.
        keep_mask = (
            torch.empty((x.shape[0], 1, 1, 1), device=x.device).uniform_()
            > self.prob
        )
        return x * keep_mask / (1 - self.prob)
|
|
|
|
|
def MBConv(
    dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
):
    """Build an MBConv (inverted-residual) stage.

    Pointwise expand -> depthwise 3x3 (stride 2 when *downsample*) ->
    squeeze-excitation -> pointwise project. A stochastic-depth residual is
    added only when input/output shapes match (same dims, no downsampling).
    """
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = nn.Sequential(
        # 1x1 expansion
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.GELU(),
        # depthwise spatial mixing
        nn.Conv2d(
            hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
        ),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
        # 1x1 projection back down
        nn.Conv2d(hidden_dim, dim_out, 1),
    )

    if dim_in == dim_out and not downsample:
        return MBConvResidual(net, dropout=dropout)
    return net
|
|
|
|
|
|
|
class Attention(nn.Module):
    """Window-based multi-head self-attention with optional learned relative
    positional bias (as used in Swin/MaxViT-style block attention).

    Expects input already partitioned into windows with shape
    (batch, x_windows, y_windows, window_height, window_width, dim) and
    returns a tensor of the same shape.
    """

    def __init__(
        self,
        dim,
        dim_head=32,
        dropout=0.0,
        window_size=7,
        with_pe=True,
    ):
        super().__init__()
        assert (
            dim % dim_head
        ) == 0, "dimension should be divisible by dimension per head"

        self.heads = dim // dim_head
        self.scale = dim_head**-0.5
        self.with_pe = with_pe

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        # softmax over keys followed by attention dropout
        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
        )

        if self.with_pe:
            # One learnable bias per head per relative (dy, dx) offset;
            # a w x w window has (2w - 1)^2 distinct offsets.
            self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

            pos = torch.arange(window_size)
            # NOTE(review): torch.meshgrid is called without indexing=; on
            # newer torch this warns but keeps the legacy "ij" behavior.
            grid = torch.stack(torch.meshgrid(pos, pos))
            grid = rearrange(grid, "c i j -> (i j) c")
            # Pairwise (dy, dx) offsets between all window positions, shifted
            # into [0, 2w - 2] and flattened into a single lookup index.
            rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
                grid, "j ... -> 1 j ..."
            )
            rel_pos += window_size - 1
            rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
                dim=-1
            )

            # Not persisted: indices are recomputed from window_size on init.
            self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)

    def forward(self, x):
        # x: (b, x, y, w1, w2, d) — windows carved out by the caller.
        batch, height, width, window_height, window_width, _, device, h = (
            *x.shape,
            x.device,
            self.heads,
        )

        # Fold windows into the batch dim; tokens are the window's pixels.
        x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # Split heads: (batch*windows, heads, tokens, dim_head).
        q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v))

        q = q * self.scale

        # Pairwise token similarities within each window.
        sim = einsum("b h i d, b h j d -> b h i j", q, k)

        if self.with_pe:
            # Add learned relative positional bias per head.
            bias = self.rel_pos_bias(self.rel_pos_indices)
            sim = sim + rearrange(bias, "i j h -> h i j")

        attn = self.attend(sim)

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # Merge heads and restore the (w1, w2) window layout.
        out = rearrange(
            out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
        )

        out = self.to_out(out)
        # Un-fold windows from the batch dim back to (b, x, y, ...).
        return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
|
|
|
|
|
class Block_Attention(nn.Module):
    """Window ("block") multi-head self-attention operating directly on NCHW
    feature maps.

    q/k/v come from a pointwise conv followed by a depthwise 3x3; attention
    is computed independently inside each non-overlapping
    window_size x window_size block.

    Note: `with_pe` is stored but no positional bias is applied in this class.
    """

    def __init__(
        self,
        dim,
        dim_head=32,
        bias=False,
        dropout=0.0,
        window_size=7,
        with_pe=True,
    ):
        super().__init__()
        assert (
            dim % dim_head
        ) == 0, "dimension should be divisible by dimension per head"

        self.heads = dim // dim_head
        self.ps = window_size
        self.scale = dim_head**-0.5
        self.with_pe = with_pe

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3,
            dim * 3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
        )

        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))

        self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        # Project to q/k/v pointwise, then mix each channel locally.
        q, k, v = self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)

        # Carve the map into ps x ps windows folded into the batch dim.
        def to_windows(t):
            return rearrange(
                t,
                "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
                h=self.heads,
                w1=self.ps,
                w2=self.ps,
            )

        q, k, v = to_windows(q), to_windows(k), to_windows(v)

        # Scaled dot-product attention within each window.
        sim = einsum("b h i d, b h j d -> b h i j", q * self.scale, k)
        attn = self.attend(sim)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # Stitch the windows back into an NCHW map.
        out = rearrange(
            out,
            "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
            x=h // self.ps,
            y=w // self.ps,
            head=self.heads,
            w1=self.ps,
            w2=self.ps,
        )
        return self.to_out(out)
|
|
|
|
|
class Channel_Attention(nn.Module):
    """Transposed (channel-wise) attention within ps x ps windows.

    Attention weights are computed between channel slices rather than spatial
    tokens (XCiT/Restormer style): q/k are L2-normalized and scaled by a
    learnable per-head temperature.

    Note: `dropout` is accepted for signature parity but unused here.
    """

    def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
        super(Channel_Attention, self).__init__()
        self.heads = heads

        # Learnable softmax temperature, one scalar per head.
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.ps = window_size

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3,
            dim * 3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        projected = self.qkv_dwconv(self.qkv(x))

        # Split q/k/v and reshape so attention runs across channels, with one
        # attention map per ps x ps window.
        q, k, v = (
            rearrange(
                t,
                "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
                ph=self.ps,
                pw=self.ps,
                head=self.heads,
            )
            for t in projected.chunk(3, dim=1)
        )

        # Cosine-similarity attention: normalize then scale by temperature.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        out = attn.softmax(dim=-1) @ v

        # Restore the NCHW layout and project back to `dim` channels.
        out = rearrange(
            out,
            "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
            h=h // self.ps,
            w=w // self.ps,
            ph=self.ps,
            pw=self.ps,
            head=self.heads,
        )
        return self.project_out(out)
|
|
|
|
|
class Channel_Attention_grid(nn.Module):
    """Transposed (channel-wise) attention over a ps x ps *grid* partition.

    Same cosine-similarity channel attention as Channel_Attention, but the
    roles of window position and grid cell are swapped in the rearrange, so
    each attention map covers one dilated grid location instead of one dense
    window.

    Note: `dropout` is accepted for signature parity but unused here.
    """

    def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
        super(Channel_Attention_grid, self).__init__()
        self.heads = heads

        # Learnable softmax temperature, one scalar per head.
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.ps = window_size

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3,
            dim * 3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        projected = self.qkv_dwconv(self.qkv(x))

        # Split q/k/v; tokens are gathered per grid cell (ph pw) rather than
        # per window, giving the dilated/grid variant of channel attention.
        q, k, v = (
            rearrange(
                t,
                "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
                ph=self.ps,
                pw=self.ps,
                head=self.heads,
            )
            for t in projected.chunk(3, dim=1)
        )

        # Cosine-similarity attention: normalize then scale by temperature.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        out = attn.softmax(dim=-1) @ v

        # Restore the NCHW layout and project back to `dim` channels.
        out = rearrange(
            out,
            "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
            h=h // self.ps,
            w=w // self.ps,
            ph=self.ps,
            pw=self.ps,
            head=self.heads,
        )
        return self.project_out(out)
|
|
|
|
|
class OSA_Block(nn.Module):
    """Omni Self-Attention block (NCHW in, NCHW out).

    Pipeline: MBConv local aggregation, then two attention passes — a dense
    window pass and a dilated grid pass — each consisting of spatial window
    attention, a gated-conv FFN, channel attention, and another gated-conv
    FFN, all pre-normed with residual connections.

    NOTE(review): `bias` and `ffn_bias` are accepted but not forwarded to any
    sub-module — presumably kept for interface compatibility; confirm.
    """

    def __init__(
        self,
        channel_num=64,
        bias=True,
        ffn_bias=True,
        window_size=8,
        with_pe=False,
        dropout=0.0,
    ):
        super(OSA_Block, self).__init__()

        w = window_size

        stages = [
            # Local feature aggregation (no expansion, SE gating).
            MBConv(
                channel_num,
                channel_num,
                downsample=False,
                expansion_rate=1,
                shrinkage_rate=0.25,
            ),
            # --- dense window pass: block-like spatial attention ---
            Rearrange("b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w),
            PreNormResidual(
                channel_num,
                Attention(
                    dim=channel_num,
                    dim_head=channel_num // 4,
                    dropout=dropout,
                    window_size=window_size,
                    with_pe=with_pe,
                ),
            ),
            Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
            # channel attention over dense windows
            Conv_PreNormResidual(
                channel_num,
                Channel_Attention(
                    dim=channel_num, heads=4, dropout=dropout, window_size=window_size
                ),
            ),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
            # --- dilated grid pass: grid-like spatial attention ---
            Rearrange("b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w),
            PreNormResidual(
                channel_num,
                Attention(
                    dim=channel_num,
                    dim_head=channel_num // 4,
                    dropout=dropout,
                    window_size=window_size,
                    with_pe=with_pe,
                ),
            ),
            Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
            # channel attention over the grid partition
            Conv_PreNormResidual(
                channel_num,
                Channel_Attention_grid(
                    dim=channel_num, heads=4, dropout=dropout, window_size=window_size
                ),
            ),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
        ]
        # Same module order as before, so state_dict keys are unchanged.
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        return self.layer(x)
|
|