#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This code is from AllenAI's Longformer:
https://github.com/allenai/longformer/
"""
import torch
import torch.nn.functional as F
from .diagonaled_mm_tvm import mask_invalid_locations


def _skew(x, direction, padding_value):
    '''Convert diagonals into columns (or columns into diagonals, depending on `direction`)'''
    x_padded = F.pad(x, direction, value=padding_value)
    x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
    return x_padded


def _skew2(x, padding_value):
    '''Shift every row one step to the right, converting columns into diagonals'''
    # X = B x C x M x L
    B, C, M, L = x.size()
    x = F.pad(x, (0, M + 1), value=padding_value)  # B x C x M x (L+M+1)
    x = x.view(B, C, -1)  # B x C x ML+MM+M
    x = x[:, :, :-M]  # B x C x ML+MM
    x = x.view(B, C, M, M + L)  # B x C x M x (L+M)
    x = x[:, :, :, :-1]
    return x


def _chunk(x, w):
    '''Convert x into overlapping chunks. Chunk size = 2w, overlap size = w'''
    # non-overlapping chunks of size = 2w
    x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))

    # use `as_strided` to make the chunks overlap with an overlap size = w
    chunk_size = list(x.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(x.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return x.as_strided(size=chunk_size, stride=chunk_stride)
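
# A tiny shape walk-through of `_chunk` (illustrative sketch, not part of the original code;
# the sizes below are assumptions chosen for readability):
#
#   x = torch.randn(1, 8, 4)     # (bsz * num_heads, seqlen, head_dim) with seqlen = 8, w = 2
#   _chunk(x, w=2).shape         # -> (1, 3, 4, 4): seqlen // w - 1 = 3 chunks of length 2w = 4,
#                                #    where consecutive chunks overlap by w = 2 tokens (a strided view)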


def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
    '''Matrix multiplication of query x key tensors using a sliding window attention pattern.
    This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer)
    with an overlap of size w'''
    bsz, seqlen, num_heads, head_dim = q.size()
    assert seqlen % (w * 2) == 0
    assert q.size() == k.size()

    chunks_count = seqlen // w - 1

    # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
    q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
    k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)

    chunk_q = _chunk(q, w)
    chunk_k = _chunk(k, w)

    # matrix multiplication
    # bcxd: bsz*num_heads x chunks x 2w x head_dim
    # bcyd: bsz*num_heads x chunks x 2w x head_dim
    # bcxy: bsz*num_heads x chunks x 2w x 2w
    chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k))  # multiply

    # convert diagonals into columns
    diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value)

    # allocate space for the overall attention matrix where the chunks are combined. The last dimension
    # has (w * 2 + 1) columns. The first w columns are the w lower triangles (attention from a word to
    # w previous words). The following column is the attention score from each word to itself, then
    # followed by w columns for the upper triangle.
    diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1))

    # copy parts from diagonal_chunk_attn into the combined matrix of attentions
    # - copying the main diagonal and the upper triangle
    diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1]
    diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1]
    # - copying the lower triangle
    diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, -(w + 1):-1, w + 1:]
    diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, 1 - w:]

    # separate bsz and num_heads dimensions again
    diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1).transpose(2, 1)

    mask_invalid_locations(diagonal_attn, w, 1, False)
    return diagonal_attn


def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
    '''Same as sliding_chunks_matmul_qk but for prob and value tensors. It expects the output
    format produced by sliding_chunks_matmul_qk'''
    bsz, seqlen, num_heads, head_dim = v.size()
    assert seqlen % (w * 2) == 0
    assert prob.size()[:3] == v.size()[:3]
    assert prob.size(3) == 2 * w + 1
    chunks_count = seqlen // w - 1
    # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size 2w
    chunk_prob = prob.transpose(1, 2).reshape(bsz * num_heads, seqlen // w, w, 2 * w + 1)

    # group bsz and num_heads dimensions into one
    v = v.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)

    # pad seqlen with w at the beginning of the sequence and another w at the end
    padded_v = F.pad(v, (0, 0, w, w), value=-1)

    # chunk padded_v into chunks of size 3w and an overlap of size w
    chunk_v_size = (bsz * num_heads, chunks_count + 1, 3 * w, head_dim)
    chunk_v_stride = padded_v.stride()
    chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2]
    chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride)

    skewed_prob = _skew2(chunk_prob, padding_value=0)

    context = torch.einsum('bcwd,bcdh->bcwh', (skewed_prob, chunk_v))

    return context.view(bsz, num_heads, seqlen, head_dim).transpose(1, 2)


def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor,
                       one_sided_window_size: int, pad_token_id: int):
    '''A helper function to pad tokens and mask to work with the sliding_chunks implementation
    of Longformer self-attention.
    Input:
        input_ids = torch.Tensor(bsz x seqlen): ids of wordpieces
        attention_mask = torch.Tensor(bsz x seqlen): attention mask
        one_sided_window_size = int: window size on one side of each token
        pad_token_id = int: tokenizer.pad_token_id
    Returns:
        (input_ids, attention_mask) padded to a length divisible by 2 * one_sided_window_size
    '''
    w = int(2 * one_sided_window_size)
    seqlen = input_ids.size(1)
    padding_len = (w - seqlen % w) % w
    input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id)
    attention_mask = F.pad(attention_mask, (0, padding_len), value=False)  # no attention on the padding tokens
    return input_ids, attention_mask
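
# Example usage of the sliding_chunks attention path (illustrative sketch only; the toy
# sizes and the padding_value below are assumptions, not taken from the original repo):
#
#   bsz, seqlen, num_heads, head_dim, w = 2, 8, 3, 4, 2
#   q = torch.randn(bsz, seqlen, num_heads, head_dim)
#   k = torch.randn(bsz, seqlen, num_heads, head_dim)
#   v = torch.randn(bsz, seqlen, num_heads, head_dim)
#   attn = sliding_chunks_matmul_qk(q, k, w, padding_value=0.)
#   attn.shape                           # (bsz, seqlen, num_heads, 2 * w + 1)
#   prob = torch.softmax(attn, dim=-1)
#   out = sliding_chunks_matmul_pv(prob, v, w)
#   out.shape                            # (bsz, seqlen, num_heads, head_dim)
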

# ========= "sliding_chunks_no_overlap": alternative implementation of the sliding window attention =========
# This implementation uses non-overlapping chunks (or blocks) of size `w`, and each token attends to
# 3w locations in total.
# To make this implementation comparable to "sliding_chunks", set w such that
#     w_of_sliding_chunks_no_overlap = w_of_sliding_chunks * 2 / 3
# For example,
#     w_of_sliding_chunks = 256 (this is one-sided; total attention size = 512)
#     w_of_sliding_chunks_no_overlap = 170 (total attention size = 510)
# Performance:
# - Speed: 30% faster than "sliding_chunks"
# - Memory: 95% of the memory usage of "sliding_chunks"
# The windows are asymmetric: the number of attended positions on each side of a token ranges from
# w to 2w, while "sliding_chunks" uses a symmetric window around each token.


def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
    # note: `padding_value` is not used in this variant; the shifted boundary chunks are zero-padded
    bsz, seqlen, num_heads, head_dim = q.size()
    assert seqlen % w == 0
    assert q.size() == k.size()
    # chunk seqlen into non-overlapping chunks of size w
    chunk_q = q.view(bsz, seqlen // w, w, num_heads, head_dim)
    chunk_k = k.view(bsz, seqlen // w, w, num_heads, head_dim)
    # stack the previous, current, and next chunk of keys along a new last dimension;
    # the shifted copies are zero-padded at the sequence boundaries
    chunk_k_expanded = torch.stack((
        F.pad(chunk_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
        chunk_k,
        F.pad(chunk_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
    ), dim=-1)
    diagonal_attn = torch.einsum('bcxhd,bcyhde->bcxhey', (chunk_q, chunk_k_expanded))  # multiply
    return diagonal_attn.reshape(bsz, seqlen, num_heads, 3 * w)


def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
    bsz, seqlen, num_heads, head_dim = v.size()
    chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w)
    chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim)
    # stack the previous, current, and next chunk of values along a new last dimension;
    # the shifted copies are zero-padded at the sequence boundaries
    chunk_v_extended = torch.stack((
        F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
        chunk_v,
        F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
    ), dim=-1)
    context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended))
    return context.reshape(bsz, seqlen, num_heads, head_dim)
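
# Example usage of the no-overlap variant (illustrative sketch only; the toy sizes are
# assumptions, not taken from the original repo):
#
#   bsz, seqlen, num_heads, head_dim, w = 2, 9, 3, 4, 3
#   q = torch.randn(bsz, seqlen, num_heads, head_dim)
#   k = torch.randn(bsz, seqlen, num_heads, head_dim)
#   v = torch.randn(bsz, seqlen, num_heads, head_dim)
#   attn = sliding_chunks_no_overlap_matmul_qk(q, k, w, padding_value=0.)
#   attn.shape                           # (bsz, seqlen, num_heads, 3 * w)
#   prob = torch.softmax(attn, dim=-1)
#   out = sliding_chunks_no_overlap_matmul_pv(prob, v, w)
#   out.shape                            # (bsz, seqlen, num_heads, head_dim)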