# Nyström-approximated self-attention with a gated "Approximator" block stack built on top.

import torch
from torch import nn, einsum
from torch.nn import functional as F
from einops import rearrange, reduce
from math import ceil


class FeedForward(nn.Module):
    # standard two-layer MLP with GELU and dropout
    def __init__(self, dim, hidden_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


# helper functions

def exists(val):
    return val is not None


def moore_penrose_iter_pinv(x, iters = 6):
    # iterative approximation of the Moore-Penrose pseudoinverse
    device = x.device

    abs_x = torch.abs(x)
    col = abs_x.sum(dim = -1)
    row = abs_x.sum(dim = -2)
    z = rearrange(x, '... i j -> ... j i') / (torch.max(col) * torch.max(row))

    I = torch.eye(x.shape[-1], device = device)
    I = rearrange(I, 'i j -> () i j')

    for _ in range(iters):
        xz = x @ z
        z = 0.25 * z @ (13 * I - (xz @ (15 * I - (xz @ (7 * I - xz)))))

    return z


# main attention class

class NystromAttention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        num_landmarks = 256,
        pinv_iterations = 6,
        residual = True,
        residual_conv_kernel = 33,
        eps = 1e-8,
        dropout = 0.
    ):
        super().__init__()
        self.eps = eps
        inner_dim = heads * dim_head

        self.num_landmarks = num_landmarks
        self.pinv_iterations = pinv_iterations

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

        self.residual = residual
        if residual:
            kernel_size = residual_conv_kernel
            padding = residual_conv_kernel // 2
            self.res_conv = nn.Conv2d(heads, heads, (kernel_size, 1), padding = (padding, 0), groups = heads, bias = False)

    def forward(self, x, mask = None, return_attn = False):
        b, n, _, h, m, iters, eps = *x.shape, self.heads, self.num_landmarks, self.pinv_iterations, self.eps

        # pad so that sequence can be evenly divided into m landmarks

        remainder = n % m
        if remainder > 0:
            padding = m - (n % m)
            x = F.pad(x, (0, 0, padding, 0), value = 0)

            if exists(mask):
                mask = F.pad(mask, (padding, 0), value = False)

        # derive queries, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # set masked positions to 0 in queries, keys, values

        if exists(mask):
            mask = rearrange(mask, 'b n -> b () n')
            q, k, v = map(lambda t: t * mask[..., None], (q, k, v))

        q = q * self.scale

        # generate landmarks by sum reduction, and then calculate mean using the mask

        l = ceil(n / m)
        landmark_einops_eq = '... (n l) d -> ... n d'
        q_landmarks = reduce(q, landmark_einops_eq, 'sum', l = l)
        k_landmarks = reduce(k, landmark_einops_eq, 'sum', l = l)

        # calculate landmark mask, and also get sum of non-masked elements in preparation for masked mean

        divisor = l
        if exists(mask):
            mask_landmarks_sum = reduce(mask, '... (n l) -> ... n', 'sum', l = l)
            divisor = mask_landmarks_sum[..., None] + eps
            mask_landmarks = mask_landmarks_sum > 0

        # masked mean (if mask exists)

        q_landmarks /= divisor
        k_landmarks /= divisor

        # similarities

        einops_eq = '... i d, ... j d -> ... i j'
        sim1 = einsum(einops_eq, q, k_landmarks)
        sim2 = einsum(einops_eq, q_landmarks, k_landmarks)
        sim3 = einsum(einops_eq, q_landmarks, k)

        # masking

        if exists(mask):
            mask_value = -torch.finfo(q.dtype).max
            sim1.masked_fill_(~(mask[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim2.masked_fill_(~(mask_landmarks[..., None] * mask_landmarks[..., None, :]), mask_value)
            sim3.masked_fill_(~(mask_landmarks[..., None] * mask[..., None, :]), mask_value)

        # eq (15) in the paper and aggregate values

        attn1, attn2, attn3 = map(lambda t: t.softmax(dim = -1), (sim1, sim2, sim3))
        attn2_inv = moore_penrose_iter_pinv(attn2, iters)

        out = (attn1 @ attn2_inv) @ (attn3 @ v)

        # add depth-wise conv residual of values

        if self.residual:
            out += self.res_conv(v)

        # merge and combine heads

        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        out = self.to_out(out)

        # remove the padding added at the front of the sequence
        out = out[:, -n:]

        if return_attn:
            attn = attn1 @ attn2_inv @ attn3
            return out, attn

        return out


class NystromBlock(nn.Module):
    # pre-norm Nyström attention followed by a pre-norm feed-forward, each with a residual connection
    def __init__(self, dim, dim_ffn, dropout):
        super().__init__()
        self.Nystrom = NystromAttention(
            dim,
            dim_head = 64,
            heads = 4,
            num_landmarks = 32,
            pinv_iterations = 3,
            residual = True,
            residual_conv_kernel = 33,
            eps = 1e-8,
            dropout = dropout)
        self.norm = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, dim_ffn, dropout)

    def forward(self, x):
        res = x
        x = self.norm(x)
        x = self.Nystrom(x)
        x = res + x

        res = x
        x = self.norm(x)
        x = self.ffn(x)
        out = x + res
        return out


class ApproximatorGatingUnit(nn.Module):
    # gates two parallel NystromBlock branches by element-wise multiplication
    def __init__(self, d_model, d_ffn, dropout):
        super().__init__()
        # self.proj = nn.Linear(d_model, d_model)
        self.Approx_1 = NystromBlock(d_model, d_ffn, dropout)
        self.Approx_2 = NystromBlock(d_model, d_ffn, dropout)

    def forward(self, x):
        u, v = x, x
        u = self.Approx_1(u)
        v = self.Approx_2(v)
        out = u * v
        return out


class ApproximatorBlock(nn.Module):
    # pre-norm gating unit followed by a pre-norm feed-forward, each with a residual connection
    def __init__(self, d_model, d_ffn, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.agu = ApproximatorGatingUnit(d_model, d_ffn, dropout)
        self.ffn = FeedForward(d_model, d_ffn, dropout)

    def forward(self, x):
        residual = x
        x = self.norm(x)
        x = self.agu(x)
        x = x + residual

        residual = x
        x = self.norm(x)
        x = self.ffn(x)
        out = x + residual
        return out


class Approximator(nn.Module):
    # stack of ApproximatorBlocks applied sequentially
    def __init__(self, d_model, d_ffn, num_layers, dropout):
        super().__init__()
        self.model = nn.Sequential(
            *[ApproximatorBlock(d_model, d_ffn, dropout) for _ in range(num_layers)],
        )

    def forward(self, x):
        x = self.model(x)
        return x
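

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): instantiates a small
# Approximator and runs a forward pass on random embeddings to check shapes.
# The dimensions below are illustrative assumptions, not values from the source.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = Approximator(d_model = 128, d_ffn = 256, num_layers = 2, dropout = 0.1)

    x = torch.randn(2, 100, 128)   # (batch, sequence length, d_model)
    out = model(x)

    print(out.shape)               # torch.Size([2, 100, 128])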