zhigangjiang's picture
no message
88b0dcb
raw
history blame
8.91 kB
"""
@Date: 2021/09/01
@description:
"""
import warnings
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Values follow N(mean, std**2) restricted to the interval [a, b]. This is a
    copy of PyTorch's ``nn.init.trunc_normal_`` for releases that lack it.
    Method: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor to overwrite (modified in place, under ``torch.no_grad``).
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound.
        b: upper truncation bound.

    Returns:
        The same ``tensor`` object.

    Warns:
        UserWarning: when ``mean`` lies more than 2 std outside [a, b].
    """
    def _std_normal_cdf(value):
        # Cumulative distribution function of the standard normal.
        return (1. + math.erf(value / math.sqrt(2.))) / 2.

    lower_limit = a - 2 * std
    upper_limit = b + 2 * std
    if mean < lower_limit or mean > upper_limit:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Inverse-CDF sampling: draw uniformly between the CDF values of the
        # standardized bounds, mapped to [2*cdf - 1] so erfinv can be applied.
        cdf_low = _std_normal_cdf((a - mean) / std)
        cdf_high = _std_normal_cdf((b - mean) / std)
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
        # erfinv(u) * sqrt(2) is a standard normal truncated to the standardized bounds.
        tensor.erfinv_()
        # Rescale to the requested std and shift to the requested mean.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        # Guard against numerical spill just outside [a, b].
        tensor.clamp_(min=a, max=b)
        return tensor
class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then call a wrapped module.

    Args:
        dim: size of the normalized (last) dimension of the input.
        fn: module or callable applied to the normalized input. Extra keyword
            arguments passed to ``forward`` are forwarded to it unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
# Module wrapper around F.gelu, kept for compatibility with PyTorch < 1.4.
class GELU(nn.Module):
    """GELU activation exposed as a module; a thin wrapper over ``F.gelu``."""

    def forward(self, input):
        activated = F.gelu(input)
        return activated
class Attend(nn.Module):
    """Softmax module whose output dtype matches the input dtype.

    Args:
        dim: dimension to normalize over; passed straight to ``F.softmax``.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        axis, out_dtype = self.dim, input.dtype
        return F.softmax(input, dim=axis, dtype=out_dtype)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size (last dimension).
        hidden_dim: width of the hidden layer.
        dropout: dropout probability applied after the activation and after
            the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # Layer order (and therefore the net.<index> parameter names) is fixed.
        layers = [
            nn.Linear(dim, hidden_dim),
            GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class RelativePosition(nn.Module):
    """Relative position bias for attention maps over ``patch_num`` positions.

    A per-head table is indexed by the distance ``j - i`` between query
    position ``i`` and key position ``j``, remapped according to ``rpe``. The
    gathered bias has shape ``(1, heads, patch_num, patch_num)`` and ``forward``
    adds it to its input.

    Modes (``rpe``; P = ``patch_num``):
        'lr_parameter': learnable; signed distance in [-(P-1), P-1],
            ``2P - 1`` table rows.
        'lr_parameter_mirror': learnable; circular, direction-agnostic distance
            ``min(|j - i|, P - |j - i|)``, ``P // 2 + 1`` rows.
        'lr_parameter_half': learnable; signed distance wrapped onto a ring
            (``d > P // 2`` becomes ``d - P``, then ``d < -P // 2 + 1`` becomes
            ``d + P``), ``P`` rows.
        'fix_angle': fixed buffer indexed like 'lr_parameter_mirror', with value
            ``(count - k) / count`` at distance ``k`` (``count = P // 2 + 1``),
            so nearer positions get larger values.
        None: no table is created and ``get_relative_pos_embed`` returns None.

    Raises:
        ValueError: if ``rpe`` is not None and not one of the modes above.
    """

    _RPE_MODES = ('lr_parameter', 'lr_parameter_mirror', 'lr_parameter_half', 'fix_angle')

    def __init__(self, heads, patch_num=None, rpe=None):
        super().__init__()
        # Fail fast: previously an unknown mode built no table and forward later
        # crashed with an unrelated ``Tensor + NoneType`` TypeError.
        if rpe is not None and rpe not in self._RPE_MODES:
            raise ValueError("unknown rpe mode {!r}; expected one of {} or None".format(
                rpe, self._RPE_MODES))
        self.rpe = rpe
        self.heads = heads
        self.patch_num = patch_num
        if rpe == 'lr_parameter':
            # One row per signed distance -(P-1) .. 0 .. P-1.
            self.rpe_table = self._learnable_table(patch_num * 2 - 1, heads)
        elif rpe == 'lr_parameter_mirror':
            # One row per circular distance 0 .. P // 2.
            self.rpe_table = self._learnable_table(patch_num // 2 + 1, heads)
        elif rpe == 'lr_parameter_half':
            # One row per wrapped signed distance (P values).
            self.rpe_table = self._learnable_table(patch_num, heads)
        elif rpe == 'fix_angle':
            count = patch_num // 2 + 1
            # Closer positions get stronger values: count/count, ..., 1/count.
            # Explicit float dtype so this is true division on every torch
            # version (older releases did not true-divide integer tensors with /).
            steps = torch.arange(count, 0, -1, dtype=torch.get_default_dtype())
            self.register_buffer('rpe_table', (steps / count)[..., None].repeat(1, heads))

    @staticmethod
    def _learnable_table(count, heads):
        """Return a ``(count, heads)`` Parameter with Xavier-uniform init."""
        table = nn.Parameter(torch.Tensor(count, heads))
        nn.init.xavier_uniform_(table)
        return table

    def get_relative_pos_embed(self):
        """Return the bias of shape ``(1, heads, patch_num, patch_num)``, or None when ``rpe`` is None."""
        if self.rpe is None:
            return None
        n = self.patch_num
        # Build indices on the table's device so the gather never mixes devices.
        pos = torch.arange(n, device=self.rpe_table.device)
        dist = pos[None, :] - pos[:, None]  # dist[i, j] = j - i
        if self.rpe == 'lr_parameter':
            # -(n-1) .. n-1  ->  0 .. 2n-2
            index = dist + (n - 1)
        elif self.rpe in ('lr_parameter_mirror', 'fix_angle'):
            dist = dist.abs()
            # On a ring, gaps d and n - d are the same distance.
            index = torch.where(dist > n // 2, n - dist, dist)
        else:  # 'lr_parameter_half'
            dist = torch.where(dist > n // 2, dist - n, dist)
            dist = torch.where(dist < -n // 2 + 1, dist + n, dist)
            # Shift to non-negative indices. NOTE: for odd n the smallest wrapped
            # distance lands on -1, i.e. the last table row; the mapping is still
            # one-to-one over all n rows.
            index = dist + (n // 2 - 1)
        # (n, n, heads) -> (1, heads, n, n)
        return self.rpe_table[index].permute(2, 0, 1)[None]

    def forward(self, attn):
        """Add the relative position bias to ``attn``.

        ``attn`` is presumably ``(batch, heads, patch_num, patch_num)`` (as built
        in ``Attention``); the bias broadcasts over its leading dimension.
        """
        return attn + self.get_relative_pos_embed()
class Attention(nn.Module):
    """Multi-head self-attention with an optional relative position bias.

    Args:
        dim: input/output feature size (last dimension of ``x``).
        heads: number of attention heads.
        dim_head: feature size per head; q/k/v use ``heads * dim_head`` features.
        dropout: dropout applied after the output projection (only when a
            projection is built).
        patch_num: sequence length used by the relative position table.
        rpe: relative position embedding mode passed to ``RelativePosition``;
            the bias is disabled when either ``patch_num`` or ``rpe`` is None.
        rpe_pos: where the bias is added. 0 adds it to the scaled logits before
            the softmax. 1 adds it to the attention weights after the softmax.
            Any other value means the bias is never applied.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0., patch_num=None, rpe=None, rpe_pos=1):
        super().__init__()
        use_rpe = patch_num is not None and rpe is not None
        self.relative_pos_embed = RelativePosition(heads, patch_num, rpe) if use_rpe else None
        inner_dim = heads * dim_head
        # A single head whose width already equals ``dim`` needs no output projection.
        needs_projection = heads != 1 or dim_head != dim
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rpe_pos = rpe_pos
        self.attend = Attend(dim=-1)
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        h = self.heads
        # Split the fused projection into q, k, v and separate the heads.
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=h)
                   for t in self.to_qkv(x).chunk(3, dim=-1))
        logits = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        bias = self.relative_pos_embed
        if bias is not None and self.rpe_pos == 0:
            logits = bias(logits)
        weights = self.attend(logits)
        if bias is not None and self.rpe_pos == 1:
            weights = bias(weights)
        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))
class AbsolutePosition(nn.Module):
    """Add an absolute position embedding of shape ``(1, patch_num, dim)`` to the input.

    ``ape`` selects the embedding:
        'lr_parameter': learnable parameter initialized from a truncated
            normal with std 0.02.
        'fix_angle': fixed ``sin(2*pi*k / patch_num)`` at position ``k``,
            repeated across all ``dim`` channels.
    Any other ``ape`` value creates no embedding, and ``forward`` then raises
    ``AttributeError``.

    ``dropout`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, dim, dropout=0., patch_num=None, ape=None):
        super().__init__()
        self.ape = ape
        if ape == 'lr_parameter':
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, patch_num, dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        elif ape == 'fix_angle':
            angle = torch.arange(0, patch_num, dtype=torch.float) / patch_num * (math.pi * 2)
            embed = torch.sin(angle)[..., None].repeat(1, dim)[None]
            # Registered as a buffer so .to()/.cuda()/.double() move and cast it
            # with the module (a plain tensor attribute is left behind on the CPU).
            # Non-persistent so the state_dict is unchanged and existing
            # checkpoints still load strictly.
            self.register_buffer('absolute_pos_embed', embed, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the embedding, broadcast over the leading (batch) dimension."""
        return x + self.absolute_pos_embed
class WinAttention(nn.Module):
    """Self-attention restricted to non-overlapping windows along the sequence axis.

    The input is split by the pattern ``'b (m w) d'``, so ``x`` is
    ``(batch, length, dim)`` with ``length`` a multiple of ``win_size``. The
    sequence is optionally rolled by ``shift``, cut into windows of
    ``win_size`` tokens, attended within each window, reassembled, and rolled
    back.

    Args:
        dim, heads, dim_head, dropout, rpe_pos: forwarded to ``Attention``.
        win_size: tokens per window (also the relative position table length).
        shift: cyclic shift applied before windowing and undone afterwards.
        rpe: any non-None value enables an ``'lr_parameter'`` relative position
            bias inside each window; None disables it.
    """

    def __init__(self, dim, win_size=8, shift=0, heads=8, dim_head=64, dropout=0., rpe=None, rpe_pos=1):
        super().__init__()
        self.win_size = win_size
        self.shift = shift
        window_rpe = 'lr_parameter' if rpe is not None else None
        self.attend = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout,
                                patch_num=win_size, rpe=window_rpe, rpe_pos=rpe_pos)

    def forward(self, x):
        batch = x.shape[0]
        shifted = self.shift != 0
        if shifted:
            x = torch.roll(x, shifts=self.shift, dims=-2)
        # Fold windows into the batch dimension, attend, then unfold.
        windows = rearrange(x, 'b (m w) d -> (b m) w d', w=self.win_size)
        out = rearrange(self.attend(windows), '(b m) w d -> b (m w) d ', b=batch)
        if shifted:
            out = torch.roll(out, shifts=-self.shift, dims=-2)
        return out
class Conv(nn.Module):
    """Kernel-3 1-D convolution along the sequence with circular (wrap-around) padding.

    Input and output are ``(batch, length, dim)``. The sequence is treated as a
    ring, so the first and last tokens are neighbors.

    Args:
        dim: number of channels (last dimension of the input).
        dropout: dropout probability applied after the convolution.
    """

    def __init__(self, dim, dropout=0.):
        super().__init__()
        self.dim = dim
        conv = nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0)
        self.net = nn.Sequential(conv, nn.Dropout(dropout))

    def forward(self, x):
        channels_first = x.transpose(1, 2)  # (batch, dim, length) for Conv1d
        # Prepend the last token and append the first one (circular pad of 1).
        wrapped = F.pad(channels_first, (1, 1), mode='circular')
        return self.net(wrapped).transpose(1, 2)