import torch
import torch.nn as nn

class SimPool(nn.Module):
    """Attention-based pooling: a GAP-initialized query attends over the patch tokens."""

    def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, gamma=None, use_beta=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.norm_patches = nn.LayerNorm(dim, eps=1e-6)

        # Linear projections for queries and keys; values are used unprojected
        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)

        # Optional learnable offset, only meaningful together with gamma scaling
        if gamma is not None and use_beta:
            self.beta = nn.Parameter(torch.tensor([0.0]))
        # Small constant keeping the power operations numerically stable;
        # registered as a buffer so it moves with the module across devices
        self.register_buffer("eps", torch.tensor([1e-6]))

        self.gamma = gamma  # plain scalar exponent, or None to disable gamma scaling
        self.use_beta = use_beta

|
    def prepare_input(self, x):
        if len(x.shape) == 3:
            # Transformer-style input: x is (B, N, d) with N patch tokens of depth d
            B, N, d = x.shape
            gap_cls = x.mean(-2)  # global average pooling over tokens: (B, d)
            gap_cls = gap_cls.unsqueeze(1)  # (B, 1, d)
            return gap_cls, x
        if len(x.shape) == 4:
            # CNN-style input: x is (B, d, H, W)
            B, d, H, W = x.shape
            gap_cls = x.mean([-2, -1])  # global average pooling over the spatial grid: (B, d)
            x = x.reshape(B, d, H * W).permute(0, 2, 1)  # flatten to a token sequence: (B, N, d)
            gap_cls = gap_cls.unsqueeze(1)  # (B, 1, d)
            return gap_cls, x
        else:
            raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")

    def forward(self, x):
        # Prepare the input and use global average pooling as the initial query
        gap_cls, x = self.prepare_input(x)

        # Queries come from the GAP vector; keys and values are the normalized patches
        q, k, v = gap_cls, self.norm_patches(x), self.norm_patches(x)

        # Extract dimensions
        Bq, Nq, dq = q.shape
        Bk, Nk, dk = k.shape
        Bv, Nv, dv = v.shape

        # Batch size and channel dimension must agree across q, k, v
        assert Bq == Bk == Bv
        assert dq == dk == dv

        # Project queries and keys, then split into heads: (B, num_heads, N, d // num_heads)
        qq = self.wq(q).reshape(Bq, Nq, self.num_heads, dq // self.num_heads).permute(0, 2, 1, 3)
        kk = self.wk(k).reshape(Bk, Nk, self.num_heads, dk // self.num_heads).permute(0, 2, 1, 3)

        # Values are only reshaped into heads, not projected
        vv = v.reshape(Bv, Nv, self.num_heads, dv // self.num_heads).permute(0, 2, 1, 3)

        # Scaled dot-product attention of the single pooled query over all patches
        attn = (qq @ kk.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        if self.gamma is not None:
            # Generalized (power) mean: shift values to be positive, raise to gamma,
            # aggregate with the attention weights, then invert the power
            x = torch.pow(attn @ torch.pow(vv - vv.min() + self.eps, self.gamma), 1 / self.gamma)
            if self.use_beta:
                # Learnable offset on top of the gamma-scaled output
                x = x + self.beta
        else:
            # Plain attention-weighted average of the values: (Bq, Nq, dq)
            x = (attn @ vv).transpose(1, 2).reshape(Bq, Nq, dq)

        # Return the pooled representation, dropping the singleton token dimension
        return x.squeeze()
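

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module itself): shows
# SimPool pooling both a ViT-style token sequence and a CNN-style feature map.
# The dimensions, gamma value, and use_beta flag below are arbitrary example
# choices, not prescribed settings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pool = SimPool(dim=384, num_heads=1, gamma=2.0, use_beta=True)

    tokens = torch.randn(8, 196, 384)           # (B, N, d) transformer tokens
    feature_map = torch.randn(8, 384, 14, 14)    # (B, d, H, W) CNN feature map

    print(pool(tokens).shape)       # pooled vectors: (8, 384)
    print(pool(feature_map).shape)  # pooled vectors: (8, 384)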