|
import torch |
|
import numpy as np |
|
from torch import nn, einsum |
|
from torch.nn import functional as F |
|
from einops.layers.torch import Rearrange |
|
from einops import rearrange, reduce |
|
from math import ceil |
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise MLP applied over the last axis of its input.

    Layout (the indices determine the ``state_dict`` keys ``net.0.*`` and
    ``net.3.*``): Linear(dim -> hidden_dim), GELU, Dropout,
    Linear(hidden_dim -> dim), Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability applied after the activation and after
            the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        # Construct the two linear layers in the same order as before so the
        # random initialisation consumes the RNG identically.
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(
            expand, nn.GELU(), nn.Dropout(dropout),
            contract, nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP; the last axis of the result has size ``dim``."""
        return self.net(x)
|
|
|
|
|
|
|
|
|
def exists(val):
    """Return True for any value except ``None`` (falsy values still count)."""
    if val is None:
        return False
    return True
|
|
|
def moore_penrose_iter_pinv(x, iters = 6): |
|
device = x.device |
|
|
|
abs_x = torch.abs(x) |
|
col = abs_x.sum(dim = -1) |
|
row = abs_x.sum(dim = -2) |
|
z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row)) |
|
|
|
I = torch.eye(x.shape[-1], device = device) |
|
I = rearrange(I, 'i j -> () i j') |
|
|
|
for _ in range(iters): |
|
xz = x @ z |
|
z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz))))) |
|
|
|
return z |
|
|
|
|
|
|
|
class NystromAttention(nn.Module):
    """Multi-head self-attention approximated with landmark (Nystrom) tokens.

    Queries/keys are summarised into ``num_landmarks`` landmarks (means of
    contiguous sequence segments). Softmax attention is then approximated as
    ``softmax(Q K_l^T) pinv(softmax(Q_l K_l^T)) softmax(Q_l K^T) V``, so the
    main path never materialises an ``n x n`` matrix.

    Args:
        dim: input/output feature size (last axis of ``x``).
        dim_head: feature size of each head.
        heads: number of attention heads.
        num_landmarks: number of landmarks ``m``; the sequence is padded at
            the front to a multiple of ``m``.
        pinv_iterations: iterations passed to ``moore_penrose_iter_pinv``.
        residual: if True, add a per-head (grouped) convolution of ``v``
            along the sequence axis to the attention output.
        residual_conv_kernel: kernel length of that convolution.
        eps: added to the per-landmark unmasked-token count before dividing.
        dropout: dropout probability after the output projection.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        num_landmarks = 256,
        pinv_iterations = 6,
        residual = True,
        residual_conv_kernel = 33,
        eps = 1e-8,
        dropout = 0.
    ):
        super().__init__()
        self.eps = eps
        inner_dim = heads * dim_head

        self.num_landmarks = num_landmarks
        self.pinv_iterations = pinv_iterations

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

        self.residual = residual
        if residual:
            kernel_size = residual_conv_kernel
            padding = residual_conv_kernel // 2
            # groups == heads: each head's v is convolved independently; the
            # (kernel_size, 1) kernel slides along the sequence axis only.
            self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False)

    def forward(self, x, mask = None, return_attn = False):
        """Apply Nystrom attention.

        Args:
            x: tensor unpacked as ``(b, n, dim)``.
            mask: optional ``(b, n)`` mask where True marks positions to keep.
                NOTE(review): it is combined with ``~`` and ``masked_fill_``,
                so it presumably must be a bool tensor.
            return_attn: if True, also return the approximated attention.

        Returns:
            The ``(b, n, dim)`` output. If ``return_attn`` is True, returns
            ``(out, attn)`` where ``attn`` is ``(b, heads, n, n)`` with the
            padding removed, so it lines up with ``out``.
        """
        b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps

        # Pad at the front so the sequence splits evenly into m landmark segments.
        remainder = n % m
        if remainder > 0:
            padding = m - remainder
            x = F.pad(x, (0, 0, padding, 0), value = 0)

            if exists(mask):
                mask = F.pad(mask, (padding, 0), value = False)

        # Project to queries, keys and values, then split heads: (b, h, N, d).
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Zero out masked positions.
        if exists(mask):
            mask = rearrange(mask, 'b n -> b () n')
            q, k, v = map(lambda t: t * mask[..., None], (q, k, v))

        q = q * self.scale

        # Landmarks: sum each contiguous segment of l tokens (divided below).
        l = ceil(n / m)
        landmark_einops_eq = '... (n l) d -> ... n d'
        q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
        k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)

        # Divide by the segment length, or by the count of unmasked tokens per
        # segment when a mask is given (masked mean).
        divisor = l
        if exists(mask):
            mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l)
            divisor = mask_landmarks_sum[..., None] + eps
            mask_landmarks = mask_landmarks_sum > 0

        q_landmarks /= divisor
        k_landmarks /= divisor

        # Three similarity matrices: tokens->landmarks, landmarks->landmarks,
        # landmarks->tokens.
        einops_eq = '... i d, ... j d -> ... i j'
        sim1 = einsum(einops_eq, q, k_landmarks)
        sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
        sim3 = einsum(einops_eq, q_landmarks, k)

        if exists(mask):
            mask_value = -torch.finfo(q.dtype).max
            sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

        attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3))
        attn2_inv = moore_penrose_iter_pinv(attn2, iters)

        # Parenthesised so every intermediate is (N x m), (m x d) or (N x d).
        out = (attn1 @ attn2_inv) @ (attn3 @ v)

        if self.residual:
            out += self.res_conv(v)

        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        out = self.to_out(out)
        # Drop the front padding.
        out = out[:, -n:]

        if return_attn:
            attn = attn1 @ attn2_inv @ attn3
            # Drop padded query rows and key columns so attn matches out.
            attn = attn[..., -n:, -n:]
            return out, attn

        return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NystromBlock(nn.Module):
    """Pre-norm residual block: Nystrom attention, then a feed-forward MLP.

    Args:
        dim: feature size of the input and output.
        dim_ffn: hidden width of the feed-forward layer.
        dropout: dropout used by both the attention output and the MLP.
        dim_head, heads, num_landmarks, pinv_iterations, residual,
        residual_conv_kernel, eps: keyword-only options forwarded to
            ``NystromAttention``. Their defaults are the values this block
            always used, so existing callers are unaffected.
    """

    def __init__(self, dim, dim_ffn, dropout, *, dim_head = 64, heads = 4,
                 num_landmarks = 32, pinv_iterations = 3, residual = True,
                 residual_conv_kernel = 33, eps = 1e-8):
        super().__init__()
        self.Nystrom = NystromAttention(
            dim,
            dim_head = dim_head,
            heads = heads,
            num_landmarks = num_landmarks,
            pinv_iterations = pinv_iterations,
            residual = residual,
            residual_conv_kernel = residual_conv_kernel,
            eps = eps,
            dropout = dropout)
        # NOTE(review): this single LayerNorm is applied before both the
        # attention and the feed-forward sublayers, so they share its affine
        # parameters. Kept as-is to preserve checkpoint compatibility; confirm
        # the sharing is intended.
        self.norm = nn.LayerNorm(dim)

        self.ffn = FeedForward(dim, dim_ffn, dropout)

    def forward(self, x):
        """Return ``y + ffn(norm(y))`` where ``y = x + attn(norm(x))``."""
        res = x
        x = self.norm(x)
        x = self.Nystrom(x)
        x = res + x

        res = x
        x = self.norm(x)
        x = self.ffn(x)
        out = x + res
        return out
|
|
|
|
|
|
|
|
|
class ApproximatorGatingUnit(nn.Module):
    """Gate two independent Nystrom branches against each other.

    The same input goes through two separately parameterised
    ``NystromBlock``s, and their outputs are multiplied element-wise.
    """

    def __init__(self, d_model, d_ffn, dropout):
        super().__init__()
        self.Approx_1 = NystromBlock(d_model, d_ffn, dropout)
        self.Approx_2 = NystromBlock(d_model, d_ffn, dropout)

    def forward(self, x):
        # The left operand is evaluated first, so Approx_1 still runs before
        # Approx_2 (relevant for dropout RNG order).
        gated = self.Approx_1(x) * self.Approx_2(x)
        return gated
|
|
|
|
|
class ApproximatorBlock(nn.Module):
    """Pre-norm residual block: gating unit followed by a feed-forward MLP."""

    def __init__(self, d_model, d_ffn, dropout):
        super().__init__()
        # Registration order (norm, agu, ffn) fixes parameter/state_dict order.
        self.norm = nn.LayerNorm(d_model)
        self.agu = ApproximatorGatingUnit(d_model, d_ffn, dropout)
        self.ffn = FeedForward(d_model, d_ffn, dropout)

    def forward(self, x):
        # NOTE(review): the same LayerNorm is reused before both sublayers.
        hidden = self.agu(self.norm(x)) + x
        return self.ffn(self.norm(hidden)) + hidden
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Approximator(nn.Module):
    """Stack of ``num_layers`` ApproximatorBlocks applied one after another."""

    def __init__(self, d_model, d_ffn, num_layers, dropout):
        super().__init__()
        blocks = [ApproximatorBlock(d_model, d_ffn, dropout) for _ in range(num_layers)]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|