import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat

# helpers

def exists(val):
    return val is not None

def max_neg_value(t):
    # largest negative value representable in the tensor's dtype, used to mask out attention logits
    return -torch.finfo(t.dtype).max

# classes

class Attention(nn.Module):
    def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, attn_dropout = 0., resid_dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head ** -0.5
        self.causal = causal

        self.attn_drop = nn.Dropout(attn_dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(resid_dropout)
        )

    def forward(self, x):
        h, device = self.heads, x.device

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        q = q * self.scale

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        mask_value = max_neg_value(dots)

        if self.causal:
            # causal mask - each position attends only to itself and earlier positions
            i, j = dots.shape[-2:]
            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
            dots.masked_fill_(mask, mask_value)

        attn = torch.softmax(dots, dim = -1)

        # attention dropout
        attn = self.attn_drop(attn)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

# sparse axial causal attention

class SparseAxialCausalAttention(nn.Module):
    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, attn_dropout = 0., resid_dropout = 0.):
        super().__init__()
        assert axis in {0, 1}, 'axis must be either 0 (along height) or 1 (along width)'
        self.axis = axis

        inner_dim = dim_head * heads
        self.seq_len = seq_len
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.image_size = image_size

        self.attn_drop = nn.Dropout(attn_dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(resid_dropout)
        )

    def forward(self, x):
        b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
        img_seq_len = img_size ** 2
        text_len = seq_len + 1 - img_seq_len

        # padding - pad the input up to seq_len + 1 so the text and image blocks have fixed sizes

        padding = seq_len - n + 1
        mask = torch.ones(b, text_len, device = device).bool()

        x = F.pad(x, (0, 0, 0, padding), value = 0)
        mask = mask[:, :text_len]

        # derive queries / keys / values

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        q = q * self.scale

        # split into text tokens and image tokens
        ((q_text, q_img), (k_text, k_img), (v_text, v_img)) = map(lambda t: (t[:, :-img_seq_len], t[:, -img_seq_len:]), (q, k, v))

        # text attention

        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = torch.softmax(dots_text, dim = -1)

        # attention dropout
        attn_text = self.attn_drop(attn_text)

        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

        # image attention

        split_axis_einops = 'b (h w) c -> b h w c' if axis == 0 else 'b (h w) c -> b w h c'
        merge_axis_einops = 'b x n d -> b (x n) d' if axis == 0 else 'b x n d -> b (n x) d'

        # split out axis

        q_img, k_img, v_img = map(lambda t: rearrange(t, split_axis_einops, h = img_size), (q_img, k_img, v_img))

        # similarity

        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)

        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)

        # mask so image has full attention to text, but causal along axis

        bh, x, i, j = dots.shape
        causal_mask = torch.ones(i, img_size, device = device).triu_(img_size - i + 1).bool()
        causal_mask = repeat(causal_mask, 'i j -> b x i j', b = bh, x = x)

        mask = repeat(mask, 'b j -> (b h) x i j', h = h, x = x, i = i)
        mask = torch.cat((~mask, causal_mask), dim = -1)

        dots.masked_fill_(mask, mask_value)

        # attention

        attn = torch.softmax(dots, dim = -1)

        # attention dropout
        attn = self.attn_drop(attn)

        # aggregate

        attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]

        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)

        out_image = out_image_to_image + out_image_to_text

        # merge back axis

        out_image = rearrange(out_image, merge_axis_einops, x = img_size)

        # combine attended values for both text and image

        out = torch.cat((out_text, out_image), dim = 1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        out = self.to_out(out)
        return out[:, :n]
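
# quick usage sketch (illustrative, not part of the original module): a shape check
# for both attention variants on dummy inputs. the batch size, `dim`, `image_size`
# and text length below are assumptions chosen only to keep the check small;
# `seq_len` must equal text length + image_size ** 2 - 1 for the axial attention.

if __name__ == '__main__':
    dim = 512
    image_size = 8                        # 8 x 8 = 64 image tokens
    seq_len = 64 + image_size ** 2 - 1    # 64 text tokens -> seq_len = 127

    x = torch.randn(2, seq_len, dim)

    full_attn = Attention(dim = dim, seq_len = seq_len)
    print(full_attn(x).shape)             # torch.Size([2, 127, 512])

    row_attn = SparseAxialCausalAttention(dim = dim, seq_len = seq_len, image_size = image_size, axis = 0)
    col_attn = SparseAxialCausalAttention(dim = dim, seq_len = seq_len, image_size = image_size, axis = 1)
    print(row_attn(x).shape)              # torch.Size([2, 127, 512])
    print(col_attn(x).shape)              # torch.Size([2, 127, 512])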