""" |
This code is from AllenAI's Longformer:
https://github.com/allenai/longformer/
"""
|
import torch |
|
import torch.nn.functional as F |
|
from .diagonaled_mm_tvm import mask_invalid_locations |
|
|
|
|
|
def _skew(x, direction, padding_value): |
|
    '''Convert diagonals into columns (or columns into diagonals, depending on `direction`).'''
|
x_padded = F.pad(x, direction, value=padding_value) |
|
x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2)) |
|
return x_padded |
|
|
|
|
|
def _skew2(x, padding_value): |
|
    '''Shift every row one step to the right, converting columns into diagonals.'''
|
|
|
B, C, M, L = x.size() |
|
x = F.pad(x, (0, M + 1), value=padding_value) |
|
x = x.view(B, C, -1) |
|
x = x[:, :, :-M] |
|
x = x.view(B, C, M, M + L) |
|
x = x[:, :, :, :-1] |
|
return x |
|
|
|
|
|
def _chunk(x, w): |
|
    '''Convert x into overlapping chunks. Chunk size = 2w, overlap size = w.'''
|
    # non-overlapping chunks of size = 2w
x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2)) |
|
    # use `as_strided` to make the chunks overlap with an overlap size = w
chunk_size = list(x.size()) |
|
chunk_size[1] = chunk_size[1] * 2 - 1 |
|
|
|
chunk_stride = list(x.stride()) |
|
chunk_stride[1] = chunk_stride[1] // 2 |
|
return x.as_strided(size=chunk_size, stride=chunk_stride) |
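

# Illustrative sketch (not in the original Longformer code): a quick check of what `_chunk` produces.
# With seqlen = 8 and w = 2 the tensor is first viewed as 2 non-overlapping chunks of size 2w = 4,
# and `as_strided` then yields 3 overlapping chunks covering rows 0-3, 2-5 and 4-7.
def _example_chunk():
    x = torch.arange(8, dtype=torch.float32).view(1, 8, 1)  # (bsz * num_heads, seqlen, head_dim)
    chunks = _chunk(x, w=2)
    assert chunks.shape == (1, 3, 4, 1)
    # chunks[0, :, :, 0] == [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]]
    return chunks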
|
|
|
|
|
def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float): |
|
    '''Matrix multiplication of query and key tensors with a sliding window attention pattern.
|
This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) |
|
with an overlap of size w''' |
|
bsz, seqlen, num_heads, head_dim = q.size() |
|
assert seqlen % (w * 2) == 0 |
|
assert q.size() == k.size() |
|
|
|
chunks_count = seqlen // w - 1 |
|
    # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim) |
|
k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim) |
|
|
|
chunk_q = _chunk(q, w) |
|
chunk_k = _chunk(k, w) |
|
    # matrix multiplication
    # bcxd: bsz*num_heads x chunks x 2w x head_dim
    # bcyd: bsz*num_heads x chunks x 2w x head_dim
    # bcxy: bsz*num_heads x chunks x 2w x 2w
chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k)) |
|
    # convert diagonals into columns
diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value) |
|
    # allocate space for the overall attention matrix where the chunks are combined. The last dimension
    # has (w * 2 + 1) columns: the first w columns are the lower triangle (attention from a token to the
    # w previous tokens), the next column is the attention from each token to itself, and the last w
    # columns are the upper triangle (attention to the w following tokens).
diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1)) |
|
    # copy parts of diagonal_chunk_attn into the combined matrix of attention scores
    # - the main diagonal and the upper triangle
diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1] |
|
diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1] |
|
    # - the lower triangle
diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, - (w + 1):-1, w + 1:] |
|
diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, 1 - w:] |
|
    # separate bsz and num_heads dimensions again
diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1).transpose(2, 1) |
|
    # mask positions at the start/end of the sequence whose window falls outside the sequence
mask_invalid_locations(diagonal_attn, w, 1, False) |
|
return diagonal_attn |
|
|
|
|
|
def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int): |
|
    '''Same as sliding_chunks_matmul_qk but for prob and value tensors. It expects `prob` to be in the
    output format of sliding_chunks_matmul_qk.'''
|
bsz, seqlen, num_heads, head_dim = v.size() |
|
assert seqlen % (w * 2) == 0 |
|
assert prob.size()[:3] == v.size()[:3] |
|
assert prob.size(3) == 2 * w + 1 |
|
chunks_count = seqlen // w - 1 |
|
|
|
chunk_prob = prob.transpose(1, 2).reshape(bsz * num_heads, seqlen // w, w, 2 * w + 1) |
|
    # group bsz and num_heads dimensions into one
v = v.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim) |
|
    # pad seqlen with w values at the beginning and w values at the end
padded_v = F.pad(v, (0, 0, w, w), value=-1) |
|
    # chunk padded_v into chunks of size 3w, shifted by w (one chunk per w query positions)
chunk_v_size = (bsz * num_heads, chunks_count + 1, 3 * w, head_dim) |
|
chunk_v_stride = padded_v.stride() |
|
chunk_v_stride = chunk_v_stride[0], w * chunk_v_stride[1], chunk_v_stride[1], chunk_v_stride[2] |
|
chunk_v = padded_v.as_strided(size=chunk_v_size, stride=chunk_v_stride) |
|
|
|
skewed_prob = _skew2(chunk_prob, padding_value=0) |
|
|
|
context = torch.einsum('bcwd,bcdh->bcwh', (skewed_prob, chunk_v)) |
|
return context.view(bsz, num_heads, seqlen, head_dim).transpose(1, 2) |
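

# Illustrative sketch (not in the original Longformer code): how the two helpers above can be combined
# into a single sliding-window attention step. The 1/sqrt(head_dim) scaling and the softmax follow
# standard attention practice; the real model additionally applies attention masks and dropout.
def _example_sliding_window_attention():
    bsz, seqlen, num_heads, head_dim, w = 1, 1024, 2, 64, 256
    q = torch.randn(bsz, seqlen, num_heads, head_dim)
    k = torch.randn(bsz, seqlen, num_heads, head_dim)
    v = torch.randn(bsz, seqlen, num_heads, head_dim)
    attn_scores = sliding_chunks_matmul_qk(q, k, w, padding_value=0) / (head_dim ** 0.5)
    attn_probs = F.softmax(attn_scores, dim=-1)            # (bsz, seqlen, num_heads, 2 * w + 1)
    context = sliding_chunks_matmul_pv(attn_probs, v, w)
    return context                                         # (bsz, seqlen, num_heads, head_dim)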
|
|
|
|
|
def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor, |
|
one_sided_window_size: int, pad_token_id: int): |
|
    '''A helper function to pad tokens and mask to work with the sliding_chunks implementation of Longformer self-attention.
|
Input: |
|
input_ids = torch.Tensor(bsz x seqlen): ids of wordpieces |
|
attention_mask = torch.Tensor(bsz x seqlen): attention mask |
|
one_sided_window_size = int: window size on one side of each token |
|
pad_token_id = int: tokenizer.pad_token_id |
|
Returns |
|
(input_ids, attention_mask) padded to length divisible by 2 * one_sided_window_size |
|
''' |
|
w = int(2 * one_sided_window_size) |
|
seqlen = input_ids.size(1) |
|
padding_len = (w - seqlen % w) % w |
|
input_ids = F.pad(input_ids, (0, padding_len), value=pad_token_id) |
|
attention_mask = F.pad(attention_mask, (0, padding_len), value=False) |
|
return input_ids, attention_mask |
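

# Illustrative sketch (not in the original Longformer code): padding a batch so that seqlen becomes a
# multiple of 2 * one_sided_window_size before calling the sliding_chunks functions. The values below
# (window size 256, pad_token_id 1) are placeholders; use the tokenizer's actual pad_token_id.
def _example_pad_to_window_size():
    input_ids = torch.randint(5, 1000, (2, 1000))    # bsz x seqlen, with seqlen not divisible by 2w
    attention_mask = torch.ones_like(input_ids)
    input_ids, attention_mask = pad_to_window_size(
        input_ids, attention_mask, one_sided_window_size=256, pad_token_id=1)
    assert input_ids.size(1) == 1024                 # padded from 1000 up to the next multiple of 512
    return input_ids, attention_mask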
|
|
|
# ========= "sliding_chunks_no_overlap": an alternative implementation of sliding window attention =========
# This implementation uses non-overlapping chunks (or blocks) of size w. Each token attends to its own
# chunk plus the chunk to its left and the chunk to its right, i.e. 3w attended positions per token
# instead of 2w + 1. It avoids the extra skewing operations of "sliding_chunks" at the cost of a
# slightly larger attention matrix.
def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float): |
|
bsz, seqlen, num_heads, head_dim = q.size() |
|
assert seqlen % w == 0 |
|
assert q.size() == k.size() |
|
|
|
chunk_q = q.view(bsz, seqlen // w, w, num_heads, head_dim) |
|
chunk_k = k.view(bsz, seqlen // w, w, num_heads, head_dim) |
|
chunk_k_expanded = torch.stack(( |
|
F.pad(chunk_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0), |
|
chunk_k, |
|
F.pad(chunk_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0), |
|
), dim=-1) |
|
diagonal_attn = torch.einsum('bcxhd,bcyhde->bcxhey', (chunk_q, chunk_k_expanded)) |
|
return diagonal_attn.reshape(bsz, seqlen, num_heads, 3 * w) |
|
|
|
|
|
def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int): |
|
bsz, seqlen, num_heads, head_dim = v.size() |
|
chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w) |
|
chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim) |
|
chunk_v_extended = torch.stack(( |
|
F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0), |
|
chunk_v, |
|
F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0), |
|
), dim=-1) |
|
context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended)) |
|
return context.reshape(bsz, seqlen, num_heads, head_dim) |
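

# Illustrative sketch (not in the original Longformer code): the same attention step using the
# non-overlapping-chunks variant. Scaling and softmax are standard attention practice; masking of
# out-of-range positions is handled elsewhere in the model, so this only shows the shape contract.
def _example_no_overlap_attention():
    bsz, seqlen, num_heads, head_dim, w = 1, 1024, 2, 64, 256
    q = torch.randn(bsz, seqlen, num_heads, head_dim)
    k = torch.randn(bsz, seqlen, num_heads, head_dim)
    v = torch.randn(bsz, seqlen, num_heads, head_dim)
    attn_scores = sliding_chunks_no_overlap_matmul_qk(q, k, w, padding_value=0) / (head_dim ** 0.5)
    attn_probs = F.softmax(attn_scores, dim=-1)            # (bsz, seqlen, num_heads, 3 * w)
    context = sliding_chunks_no_overlap_matmul_pv(attn_probs, v, w)
    return context                                         # (bsz, seqlen, num_heads, head_dim)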
|
|