import torch
import numpy as np
from torch import nn, einsum
from torch.nn import functional as F
from einops.layers.torch import Rearrange
from einops import rearrange, reduce
from math import ceil
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability applied after each Linear's activation
            stage and after the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            # a nonlinearity is required here; two stacked Linears would
            # otherwise collapse into a single affine map
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x`` (size ``dim``)."""
        return self.net(x)
# helper functions
def exists(val):
    """Return True when ``val`` is anything other than ``None``.

    Falsy values such as ``0`` or ``""`` still count as existing.
    """
    missing = val is None
    return not missing
def moore_penrose_iter_pinv(x, iters = 6):
device = x.device
abs_x = torch.abs(x)
col = abs_x.sum(dim = -1)
row = abs_x.sum(dim = -2)
z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row))
I = torch.eye(x.shape[-1], device = device)
I = rearrange(I, 'i j -> () i j')
for _ in range(iters):
xz = x @ z
z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz)))))
return z
# main attention class
class NystromAttention(nn.Module):
    """Multi-head self-attention approximated with landmarks (Nyström method).

    The sequence is split into ``num_landmarks`` contiguous segments whose
    (masked) mean queries and keys act as landmarks. Three small softmax
    matrices (queries vs landmark keys, landmark vs landmark, landmark queries
    vs keys) are combined with an iterative pseudo-inverse of the middle one.

    Args:
        dim: feature size of the input and output.
        dim_head: feature size per head.
        heads: number of attention heads.
        num_landmarks: number of landmarks ``m``. The input is left-padded to
            a multiple of ``m``.
        pinv_iterations: iterations passed to ``moore_penrose_iter_pinv``.
        residual: if True, add a depth-wise convolution of the values to the
            attention output.
        residual_conv_kernel: kernel length of that convolution along the
            sequence axis. Use an odd value so the length is preserved.
        eps: added to per-landmark unmasked counts before dividing.
        dropout: dropout probability after the output projection.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        num_landmarks = 256,
        pinv_iterations = 6,
        residual = True,
        residual_conv_kernel = 33,
        eps = 1e-8,
        dropout = 0.,
    ):
        super().__init__()
        self.eps = eps
        inner_dim = heads * dim_head

        self.num_landmarks = num_landmarks
        self.pinv_iterations = pinv_iterations

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )

        self.residual = residual
        if residual:
            kernel_size = residual_conv_kernel
            padding = residual_conv_kernel // 2
            # one filter per head (groups = heads), sliding over the sequence axis only
            self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False)

    def forward(self, x, mask = None, return_attn = False):
        """Attend over ``x``.

        Args:
            x: tensor of shape ``(b, n, dim)``.
            mask: optional boolean tensor of shape ``(b, n)``. ``True`` marks
                positions that take part in attention.
            return_attn: if True, also return the approximate attention matrix.

        Returns:
            Tensor of shape ``(b, n, dim)``. With ``return_attn``, it is paired
            with a ``(b, heads, n', n')`` attention approximation, where ``n'``
            is the padded length (it is not cropped back to ``n``).
        """
        b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps

        # left-pad with zeros so the length divides evenly into m landmarks;
        # the padded positions are marked False in the mask (if given)
        remainder = n % m
        if remainder > 0:
            padding = m - remainder
            x = F.pad(x, (0, 0, padding, 0), value = 0)

            if exists(mask):
                mask = F.pad(mask, (padding, 0), value = False)

        # derive queries, keys, values and split heads: (b, n', h*d) -> (b, h, n', d)
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # set masked positions to 0 in queries, keys, values
        if exists(mask):
            mask = rearrange(mask, 'b n -> b () n')
            q, k, v = map(lambda t: t * mask[..., None], (q, k, v))

        q = q * self.scale

        # landmarks: sum each of the m contiguous segments of length l ...
        l = ceil(n / m)
        landmark_einops_eq = '... (n l) d -> ... n d'
        q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
        k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)

        # ... then divide by the segment length, or by the number of unmasked
        # positions per segment when a mask is given (masked mean)
        divisor = l
        if exists(mask):
            mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l)
            divisor = mask_landmarks_sum[..., None] + eps
            mask_landmarks = mask_landmarks_sum > 0

        q_landmarks /= divisor
        k_landmarks /= divisor

        # similarities: queries x landmark keys, landmarks x landmarks,
        # landmark queries x keys
        einops_eq = '... i d, ... j d -> ... i j'
        sim1 = einsum(einops_eq, q, k_landmarks)
        sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
        sim3 = einsum(einops_eq, q_landmarks, k)

        # masking: fill pairs involving a masked position / empty landmark with
        # the most negative finite value so softmax gives them ~0 weight
        if exists(mask):
            mask_value = -torch.finfo(q.dtype).max
            sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

        # eq (15) in the paper and aggregate values
        attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3))
        attn2_inv = moore_penrose_iter_pinv(attn2, iters)

        out = (attn1 @ attn2_inv) @ (attn3 @ v)

        # add depth-wise conv residual of values
        if self.residual:
            out += self.res_conv(v)

        # merge and combine heads, then drop the front padding
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        out = self.to_out(out)
        out = out[:, -n:]

        if return_attn:
            # computed over the padded length n'
            attn = attn1 @ attn2_inv @ attn3
            return out, attn

        return out
class NystromBlock(nn.Module):
    """Pre-norm residual block: Nyström attention, then a feed-forward MLP.

    Args:
        dim: model feature size.
        dim_ffn: hidden width of the feed-forward sublayer.
        dropout: dropout probability passed to both sublayers.
    """

    def __init__(self, dim, dim_ffn, dropout):
        super().__init__()
        self.Nystrom = NystromAttention(
            dim = dim,
            dim_head = 64,
            heads = 4,
            num_landmarks = 32,
            pinv_iterations = 3,
            residual = True,
            residual_conv_kernel = 33,
            eps = 1e-8,
            dropout = dropout)
        # NOTE(review): one LayerNorm (same affine parameters) is applied before
        # both sublayers; a separate norm per sublayer is more usual. Kept
        # shared so existing checkpoints (state_dict keys) still load.
        self.norm = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, dim_ffn, dropout)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual sublayers."""
        # attention sublayer
        res = x
        x = self.norm(x)
        x = self.Nystrom(x)
        x = res + x

        # feed-forward sublayer
        res = x
        x = self.norm(x)
        x = self.ffn(x)
        out = x + res
        return out
class ApproximatorGatingUnit(nn.Module):
    """Gate two independent Nyström branches by elementwise multiplication.

    Both branches receive the same input; the output is
    ``Approx_1(x) * Approx_2(x)``.

    Args:
        d_model: model feature size.
        d_ffn: hidden width of each branch's feed-forward sublayer.
        dropout: dropout probability passed to both branches.
    """

    def __init__(self, d_model, d_ffn, dropout):
        super().__init__()
        self.Approx_1 = NystromBlock(d_model, d_ffn, dropout)
        self.Approx_2 = NystromBlock(d_model, d_ffn, dropout)

    def forward(self, x):
        """Return the elementwise product of the two branch outputs."""
        u = self.Approx_1(x)
        v = self.Approx_2(x)
        out = u * v
        return out
class ApproximatorBlock(nn.Module):
    """Pre-norm residual block: gating unit, then a feed-forward MLP.

    Args:
        d_model: model feature size.
        d_ffn: hidden width of the feed-forward sublayers.
        dropout: dropout probability passed to the sublayers.
    """

    def __init__(self, d_model, d_ffn, dropout):
        super().__init__()
        # NOTE(review): this LayerNorm is shared by both sublayers below;
        # kept shared so existing state_dict keys still load.
        self.norm = nn.LayerNorm(d_model)
        self.agu = ApproximatorGatingUnit(d_model, d_ffn, dropout)
        self.ffn = FeedForward(d_model, d_ffn, dropout)

    def forward(self, x):
        """Return ``x`` after the gating and MLP residual sublayers."""
        # gating sublayer
        residual = x
        x = self.norm(x)
        x = self.agu(x)
        x = x + residual

        # feed-forward sublayer
        residual = x
        x = self.norm(x)
        x = self.ffn(x)
        out = x + residual
        return out
class Approximator(nn.Module):
    """A stack of ``num_layers`` ApproximatorBlocks applied in sequence.

    Args:
        d_model: model feature size.
        d_ffn: hidden width of the feed-forward sublayers.
        num_layers: number of stacked blocks.
        dropout: dropout probability passed to every block.
    """

    def __init__(self, d_model, d_ffn, num_layers, dropout):
        super().__init__()
        self.model = nn.Sequential(
            *[ApproximatorBlock(d_model, d_ffn, dropout) for _ in range(num_layers)],
        )

    def forward(self, x):
        """Run ``x`` through every block in order."""
        return self.model(x)