File size: 8,365 Bytes
340c8dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This code is from AllenAI's Longformer:
https://github.com/allenai/longformer/
"""
import torch
import torch.nn.functional as F
from .diagonaled_mm_tvm import mask_invalid_locations
def _skew(x, direction, padding_value):
'''Convert diagonals into columns (or columns into diagonals depending on `direction`'''
x_padded = F.pad(x, direction, value=padding_value)
x_padded = x_padded.view(*x_padded.size()[:-2], x_padded.size(-1), x_padded.size(-2))
return x_padded
def _skew2(x, padding_value):
'''shift every row 1 step to right converting columns into diagonals'''
# X = B x C x M x L
B, C, M, L = x.size()
x = F.pad(x, (0, M + 1), value=padding_value) # B x C x M x (L+M+1)
x = x.view(B, C, -1) # B x C x ML+MM+M
x = x[:, :, :-M] # B x C x ML+MM
x = x.view(B, C, M, M + L) # B x C, M x L+M
x = x[:, :, :, :-1]
return x
def _chunk(x, w):
'''convert into overlapping chunkings. Chunk size = 2w, overlap size = w'''
# non-overlapping chunks of size = 2w
x = x.view(x.size(0), x.size(1) // (w * 2), w * 2, x.size(2))
# use `as_strided` to make the chunks overlap with an overlap size = w
chunk_size = list(x.size())
chunk_size[1] = chunk_size[1] * 2 - 1
chunk_stride = list(x.stride())
chunk_stride[1] = chunk_stride[1] // 2
return x.as_strided(size=chunk_size, stride=chunk_stride)
def sliding_chunks_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
    '''Matrix multiplication of query x key tensors using a sliding window attention pattern.

    This implementation splits the input into overlapping chunks of size 2w (e.g. 512 for
    pretrained Longformer) with an overlap of size w.

    Args:
        q: (bsz, seqlen, num_heads, head_dim) query tensor; seqlen must be a multiple of 2w.
        k: key tensor with the same shape as `q`.
        w: one-sided window size.
        padding_value: fill value used by `_skew` when turning chunk diagonals into columns.

    Returns:
        (bsz, seqlen, num_heads, 2w + 1) tensor. For each token, columns [0, w) hold the
        scores against the w previous tokens, column w holds the score against itself, and
        columns (w, 2w] hold the scores against the w following tokens. The result is passed
        through `mask_invalid_locations` before it is returned.
    '''
    bsz, seqlen, num_heads, head_dim = q.size()
    assert seqlen % (w * 2) == 0
    assert q.size() == k.size()

    # number of overlapping 2w-sized chunks produced by `_chunk`
    chunks_count = seqlen // w - 1

    # group bsz and num_heads dimensions into one, then chunk seqlen into chunks of size w * 2
    q = q.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)
    k = k.transpose(1, 2).reshape(bsz * num_heads, seqlen, head_dim)

    chunk_q = _chunk(q, w)
    chunk_k = _chunk(k, w)

    # matrix multiplication
    # bcxd: bsz*num_heads x chunks x 2w x head_dim
    # bcyd: bsz*num_heads x chunks x 2w x head_dim
    # bcxy: bsz*num_heads x chunks x 2w x 2w
    chunk_attn = torch.einsum('bcxd,bcyd->bcxy', (chunk_q, chunk_k))  # multiply

    # convert diagonals into columns: bsz*num_heads x chunks x 2w x (2w + 1)
    diagonal_chunk_attn = _skew(chunk_attn, direction=(0, 0, 0, 1), padding_value=padding_value)

    # Allocate space for the overall attention matrix where the chunks are combined. The last
    # dimension has (w * 2 + 1) columns. The first w columns are the w lower triangles (attention
    # from a word to the w previous words). The following column is the attention score from each
    # word to itself, followed by w columns for the upper triangle.
    # NOTE(review): `new_empty` leaves entries uninitialized. The entries of chunk 0 that are not
    # written below (row 0 cols [0, w), rows [1, w) col 0) correspond to keys before the start
    # of the sequence. They are presumably overwritten by `mask_invalid_locations` -- confirm.
    diagonal_attn = diagonal_chunk_attn.new_empty((bsz * num_heads, chunks_count + 1, w, w * 2 + 1))

    # copy parts from diagonal_chunk_attn into the combined matrix of attentions
    # - copying the main diagonal and the upper triangle
    diagonal_attn[:, :-1, :, w:] = diagonal_chunk_attn[:, :, :w, :w + 1]
    diagonal_attn[:, -1, :, w:] = diagonal_chunk_attn[:, -1, w:, :w + 1]
    # - copying the lower triangle
    diagonal_attn[:, 1:, :, :w] = diagonal_chunk_attn[:, :, - (w + 1):-1, w + 1:]
    # - copying the lower triangle of the first chunk. Slicing from `w + 2` is equivalent to
    #   `1 - w` for w >= 2. It keeps the source empty when w == 1, where `1 - w == 0` would
    #   select all 2w + 1 columns and fail to broadcast into the empty target.
    diagonal_attn[:, 0, 1:w, 1:w] = diagonal_chunk_attn[:, 0, :w - 1, w + 2:]

    # separate bsz and num_heads dimensions again
    diagonal_attn = diagonal_attn.view(bsz, num_heads, seqlen, 2 * w + 1).transpose(2, 1)

    mask_invalid_locations(diagonal_attn, w, 1, False)
    return diagonal_attn
def sliding_chunks_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
    '''Multiply attention probabilities by values using the sliding-window layout.

    Counterpart of `sliding_chunks_matmul_qk`. `prob` is expected in the same
    (bsz, seqlen, num_heads, 2w + 1) layout that function returns. `v` is
    (bsz, seqlen, num_heads, head_dim) with seqlen a multiple of 2w.

    Returns:
        (bsz, seqlen, num_heads, head_dim) context tensor.
    '''
    bsz, seqlen, num_heads, head_dim = v.size()
    assert seqlen % (w * 2) == 0
    assert prob.size()[:3] == v.size()[:3]
    assert prob.size(3) == 2 * w + 1

    n_blocks = seqlen // w  # one block of w query rows per w tokens
    flat_batch = bsz * num_heads

    # merge bsz and num_heads, then split seqlen into blocks of w query rows
    prob_blocks = prob.transpose(1, 2).reshape(flat_batch, n_blocks, w, 2 * w + 1)
    # merge bsz and num_heads
    v_flat = v.transpose(1, 2).reshape(flat_batch, seqlen, head_dim)

    # Pad w rows at both ends of the sequence axis so every block sees a full 3w window.
    # NOTE(review): padded rows hold -1; presumably they only meet zero probabilities
    # (masked locations) -- confirm.
    v_padded = F.pad(v_flat, (0, 0, w, w), value=-1)

    # Windows of 3w rows whose starts are w rows apart (i.e. overlapping by 2w rows)
    stride_batch, stride_seq, stride_dim = v_padded.stride()
    v_windows = v_padded.as_strided(
        size=(flat_batch, n_blocks, 3 * w, head_dim),
        stride=(stride_batch, w * stride_seq, stride_seq, stride_dim),
    )

    # turn the (2w + 1)-wide diagonal layout into w x 3w rows aligned with v_windows
    prob_aligned = _skew2(prob_blocks, padding_value=0)
    out = torch.einsum('bcwd,bcdh->bcwh', (prob_aligned, v_windows))
    return out.view(bsz, num_heads, seqlen, head_dim).transpose(1, 2)
def pad_to_window_size(input_ids: torch.Tensor, attention_mask: torch.Tensor,
                       one_sided_window_size: int, pad_token_id: int):
    '''Right-pad tokens and mask to work with the sliding_chunks Longformer self-attention.

    Args:
        input_ids: (bsz x seqlen) tensor of wordpiece ids.
        attention_mask: (bsz x seqlen) attention mask.
        one_sided_window_size: window size on one side of each token.
        pad_token_id: id used for the padding tokens (e.g. `tokenizer.pad_token_id`).

    Returns:
        (input_ids, attention_mask), both right-padded along dim 1 to a length divisible by
        2 * one_sided_window_size. Padded mask positions are filled with False (0).
    '''
    block = int(2 * one_sided_window_size)
    shortfall = (block - input_ids.size(1) % block) % block
    pad_spec = (0, shortfall)
    padded_ids = F.pad(input_ids, pad_spec, value=pad_token_id)
    padded_mask = F.pad(attention_mask, pad_spec, value=False)  # no attention on the padding tokens
    return padded_ids, padded_mask
# ========= "sliding_chunks_no_overlap": alternative implementation of the sliding window attention =========
# This implementation uses non-overlapping chunks (or blocks) of size `w`; each token attends to 3 * w positions
# (its own chunk plus the previous and the next chunk).
# To make this implementation comparable to "sliding_chunks", set w such that
# w_of_sliding_chunks_no_overlap = w_of_sliding_chunks * 2 / 3
# For example,
# w_of_sliding_chunks = 256 (this is one sided. Total attention size = 512)
# w_of_sliding_chunks_no_overlap = 170 (Total attention size = 510)
# Performance:
# - Speed: 30% faster than "sliding_chunks"
# - Memory: 95% of the memory usage of "sliding_chunks"
# The windows are asymmetric: the number of attended positions on each side of a token ranges from w to 2w,
# while "sliding_chunks" has a symmetric window around each token.
def sliding_chunks_no_overlap_matmul_qk(q: torch.Tensor, k: torch.Tensor, w: int, padding_value: float):
    '''Query x key scores where each w-sized chunk attends to itself and its two neighbouring chunks.

    Args:
        q: (bsz, seqlen, num_heads, head_dim) query tensor; seqlen must be a multiple of w.
        k: key tensor with the same shape as `q`.
        w: chunk size.
        padding_value: accepted for signature parity with `sliding_chunks_matmul_qk` but not
            used. Neighbour chunks outside the sequence are zero-padded, so their scores are 0.

    Returns:
        (bsz, seqlen, num_heads, 3w) tensor. For a token in chunk c, column j is the score
        against key position (c - 1) * w + j.
    '''
    bsz, seqlen, num_heads, head_dim = q.size()
    assert seqlen % w == 0
    assert q.size() == k.size()

    n_chunks = seqlen // w
    # chunk seqlen into non-overlapping chunks of size w
    blocks_q = q.view(bsz, n_chunks, w, num_heads, head_dim)
    blocks_k = k.view(bsz, n_chunks, w, num_heads, head_dim)

    # shift the chunk axis (dim 1) by one in each direction, zero-filling the exposed end
    prev_k = F.pad(blocks_k[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0)
    next_k = F.pad(blocks_k[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0)
    # (bsz, chunks, w, num_heads, head_dim, 3): previous / current / next chunk keys
    neighbourhood = torch.stack([prev_k, blocks_k, next_k], dim=-1)

    scores = torch.einsum('bcxhd,bcyhde->bcxhey', blocks_q, neighbourhood)
    return scores.reshape(bsz, seqlen, num_heads, 3 * w)
def sliding_chunks_no_overlap_matmul_pv(prob: torch.Tensor, v: torch.Tensor, w: int):
    '''Same as sliding_chunks_no_overlap_matmul_qk but for prob and value tensors.

    Args:
        prob: (bsz, seqlen, num_heads, 3w) probabilities in the layout returned by
            `sliding_chunks_no_overlap_matmul_qk`. Column j of a token in chunk c refers to
            value position (c - 1) * w + j.
        v: (bsz, seqlen, num_heads, head_dim) values; seqlen must be a multiple of w.
        w: chunk size.

    Returns:
        (bsz, seqlen, num_heads, head_dim) context tensor. Neighbour chunks outside the
        sequence are zero-padded and contribute nothing.
    '''
    bsz, seqlen, num_heads, head_dim = v.size()
    # Validate the layout up front, as the sibling functions do. Without this, a permuted
    # `prob` with the same number of elements would pass `.view` and give wrong results.
    assert seqlen % w == 0
    assert prob.size()[:3] == v.size()[:3]
    assert prob.size(3) == 3 * w

    # (bsz, chunks, w, num_heads, 3, w): window columns split into (neighbour, offset)
    chunk_prob = prob.view(bsz, seqlen // w, w, num_heads, 3, w)
    chunk_v = v.view(bsz, seqlen // w, w, num_heads, head_dim)
    # (bsz, chunks, w, num_heads, head_dim, 3): previous / current / next chunk values,
    # zero-filled beyond the sequence ends
    chunk_v_extended = torch.stack((
        F.pad(chunk_v[:, :-1], (0, 0, 0, 0, 0, 0, 1, 0), value=0.0),
        chunk_v,
        F.pad(chunk_v[:, 1:], (0, 0, 0, 0, 0, 0, 0, 1), value=0.0),
    ), dim=-1)
    context = torch.einsum('bcwhpd,bcdhep->bcwhe', (chunk_prob, chunk_v_extended))
    return context.reshape(bsz, seqlen, num_heads, head_dim)
|