""" @Date: 2021/09/01 @description: """ import warnings import math import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1. + math.erf(x / math.sqrt(2.))) / 2. if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2) with torch.no_grad(): # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) return tensor class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x, **kwargs): return self.fn(self.norm(x), **kwargs) # compatibility pytorch < 1.4 class GELU(nn.Module): def forward(self, input): return F.gelu(input) class Attend(nn.Module): def __init__(self, dim=None): super().__init__() self.dim = dim def forward(self, input): return F.softmax(input, dim=self.dim, dtype=input.dtype) class FeedForward(nn.Module): def __init__(self, dim, hidden_dim, dropout=0.): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_dim), GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x) class RelativePosition(nn.Module): def __init__(self, heads, patch_num=None, rpe=None): super().__init__() self.rpe = rpe self.heads = heads self.patch_num = patch_num if rpe == 'lr_parameter': # -255 ~ 0 ~ 255 all count : patch * 2 - 1 count = patch_num * 2 - 1 self.rpe_table = nn.Parameter(torch.Tensor(count, heads)) nn.init.xavier_uniform_(self.rpe_table) elif rpe == 'lr_parameter_mirror': # 0 ~ 127 128 ~ 1 all count : patch_num // 2 + 1 count = patch_num // 2 + 1 self.rpe_table = nn.Parameter(torch.Tensor(count, heads)) nn.init.xavier_uniform_(self.rpe_table) elif rpe == 'lr_parameter_half': # -127 ~ 0 ~ 128 all count : patch count = patch_num self.rpe_table = nn.Parameter(torch.Tensor(count, heads)) nn.init.xavier_uniform_(self.rpe_table) elif rpe == 'fix_angle': # 0 ~ 127 128 ~ 1 all count : patch_num // 2 + 1 count = patch_num // 2 + 1 # we think that closer proximity should have stronger relationships rpe_table = (torch.arange(count, 0, -1) / count)[..., None].repeat(1, heads) self.register_buffer('rpe_table', rpe_table) def get_relative_pos_embed(self): range_vec = torch.arange(self.patch_num) distance_mat = range_vec[None, :] - range_vec[:, None] if self.rpe == 'lr_parameter': # -255 ~ 0 ~ 255 -> 0 ~ 255 ~ 255 + 255 distance_mat += self.patch_num - 1 # remove negative return self.rpe_table[distance_mat].permute(2, 0, 1)[None] elif self.rpe == 'lr_parameter_mirror' or self.rpe == 'fix_angle': distance_mat[distance_mat < 0] = -distance_mat[distance_mat < 0] # mirror distance_mat[distance_mat > self.patch_num // 2] = self.patch_num - distance_mat[ distance_mat > self.patch_num // 2] # remove repeat return self.rpe_table[distance_mat].permute(2, 0, 1)[None] elif self.rpe == 'lr_parameter_half': distance_mat[distance_mat > self.patch_num // 2] = distance_mat[ distance_mat > self.patch_num // 2] - self.patch_num # remove repeat > 128 exp: 129 -> -127 distance_mat[distance_mat < -self.patch_num // 2 + 1] = distance_mat[ distance_mat < -self.patch_num // 2 + 1] + self.patch_num # remove repeat < -127 exp: -128 -> 128 # -127 ~ 0 ~ 128 -> 0 ~ 0 ~ 127 + 127 + 128 distance_mat += self.patch_num//2 - 1 # remove negative return self.rpe_table[distance_mat].permute(2, 0, 1)[None] def forward(self, attn): return attn + self.get_relative_pos_embed() class Attention(nn.Module): def __init__(self, dim, heads=8, dim_head=64, dropout=0., patch_num=None, rpe=None, rpe_pos=1): """ :param dim: :param heads: :param dim_head: :param dropout: :param patch_num: :param rpe: relative position embedding """ super().__init__() self.relative_pos_embed = None if patch_num is None or rpe is None else RelativePosition(heads, patch_num, rpe) inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head ** -0.5 self.rpe_pos = rpe_pos self.attend = Attend(dim=-1) self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward(self, x): b, n, _, h = *x.shape, self.heads qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale if self.rpe_pos == 0: if self.relative_pos_embed is not None: dots = self.relative_pos_embed(dots) attn = self.attend(dots) if self.rpe_pos == 1: if self.relative_pos_embed is not None: attn = self.relative_pos_embed(attn) out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) class AbsolutePosition(nn.Module): def __init__(self, dim, dropout=0., patch_num=None, ape=None): super().__init__() self.ape = ape if ape == 'lr_parameter': self.absolute_pos_embed = nn.Parameter(torch.zeros(1, patch_num, dim)) trunc_normal_(self.absolute_pos_embed, std=.02) elif ape == 'fix_angle': angle = torch.arange(0, patch_num, dtype=torch.float) / patch_num * (math.pi * 2) self.absolute_pos_embed = torch.sin(angle)[..., None].repeat(1, dim)[None] def forward(self, x): return x + self.absolute_pos_embed class WinAttention(nn.Module): def __init__(self, dim, win_size=8, shift=0, heads=8, dim_head=64, dropout=0., rpe=None, rpe_pos=1): super().__init__() self.win_size = win_size self.shift = shift self.attend = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, patch_num=win_size, rpe=None if rpe is None else 'lr_parameter', rpe_pos=rpe_pos) def forward(self, x): b = x.shape[0] if self.shift != 0: x = torch.roll(x, shifts=self.shift, dims=-2) x = rearrange(x, 'b (m w) d -> (b m) w d', w=self.win_size) # split windows out = self.attend(x) out = rearrange(out, '(b m) w d -> b (m w) d ', b=b) # recover windows if self.shift != 0: out = torch.roll(out, shifts=-self.shift, dims=-2) return out class Conv(nn.Module): def __init__(self, dim, dropout=0.): super().__init__() self.dim = dim self.net = nn.Sequential( nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0), nn.Dropout(dropout) ) def forward(self, x): x = x.transpose(1, 2) x = torch.cat([x[..., -1:], x, x[..., :1]], dim=-1) x = self.net(x) return x.transpose(1, 2)