#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This code is from AllenAI's Longformer:
https://github.com/allenai/longformer/
"""
from typing import Union
from functools import lru_cache
import torch
import os.path
class DiagonaledMM(torch.autograd.Function):
'''Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling
this function from PyTorch
'''
    function_dict = {} # cache of compiled functions, each with a different set of parameters
@staticmethod
def _compile_function(dtype: str, device: str, b0: int = 4, b1: int = 4, b2: int = 16):
'''Compiles a tvm function that computes diagonal_mm
args:
dtype: str in ['float64', 'float32', 'float16']
        device: str in ['cpu', 'cuda']
b0, b1, b2: size of tensor tiles. Very important for good performance
'''
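        # Illustrative direct call (normally reached through `_get_function`); the tile sizes shown are
        # just the defaults from the signature above:
        #   fn = DiagonaledMM._compile_function('float32', 'cuda', b0=4, b1=4, b2=16)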
import tvm # import the full tvm library here for compilation. Don't import at the top of the file in case we don't need to compile
from tvm.contrib import nvcc
@tvm.register_func
def tvm_callback_cuda_compile(code):
"""Use nvcc compiler for better perf."""
ptx = nvcc.compile_cuda(code, target="ptx", arch='sm_52') # use old arch for this to work on old GPUs
return ptx
assert dtype in ['float16', 'float32', 'float64']
assert device in ['cpu', 'cuda']
device = None if device == 'cpu' else device
tgt_host="llvm"
b = tvm.var('b') # batch size
n = tvm.var('n') # sequence length
h = tvm.var('h') # number of heads
m = tvm.var('m') # hidden dimension
w = tvm.var('w') # window size
w_upper = tvm.var('w_upper') # window size to the right of the word. Should be `0` or `w`
padding = tvm.var('padding') # padding
transpose_t1 = tvm.var('transpose_t1') # t1 should be transposed
t1d3 = tvm.var('t1d3') # last dimension of t1
t3d3 = tvm.var('t3d3') # last dimension of t3 (the result tensor)
X = tvm.placeholder((b, n, h, t1d3), name='X', dtype=dtype) # first tensor
Y = tvm.placeholder((b, n, h, m), name='Y', dtype=dtype) # second tensor
k = tvm.reduce_axis((0, t1d3), name='k') # dimension to sum over
D = tvm.placeholder((h), name='D', dtype='int') # dilation per head
output_shape = (b, n, h, t3d3) # shape of the result tensor
algorithm = lambda l, i, q, j: tvm.sum(
tvm.if_then_else(
t3d3 == m, # if output dimension == m, then t1 is diagonaled (FIXME: This breaks if t3d3 == m == t1d3)
tvm.if_then_else(
transpose_t1 == 0,
tvm.if_then_else(
tvm.all(
i + D[q] * (k - w) >= 0,
i + D[q] * (k - w) < n,
),
X[l, i, q, k] * Y[l, i + D[q] * (k - w), q, j], # t1 is diagonaled
padding
),
tvm.if_then_else(
tvm.all(
i + D[q] * (k - w_upper) >= 0, # `w_upper` to handle the case `autoregressive=True`
i + D[q] * (k - w_upper) < n,
),
                        X[l, i + D[q] * (k - w_upper), q, (w_upper + w) - k] * Y[l, i + D[q] * (k - w_upper), q, j], # t1 is diagonaled and should be transposed
padding
),
),
tvm.if_then_else(
tvm.all(
i + D[q] * (j - w) >= 0,
i + D[q] * (j - w) < n,
),
X[l, i, q, k] * Y[l, i + D[q] * (j - w), q, k], # t1 is not diagonaled, but the output tensor is going to be
padding
)
), axis=k)
Z = tvm.compute(output_shape, algorithm, name='Z') # automatically generate cuda code
s = tvm.create_schedule(Z.op)
print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))
        # split long axes into smaller chunks and assign each one to a separate GPU thread/block
ko, ki = s[Z].split(Z.op.reduce_axis[0], factor=b0)
ZF = s.rfactor(Z, ki)
j_outer, j_inner = s[Z].split(s[Z].op.axis[-1], factor=b1)
i_outer, i_inner = s[Z].split(s[Z].op.axis[1], factor=b2)
s[Z].bind(j_outer, tvm.thread_axis("blockIdx.x"))
s[Z].bind(j_inner, tvm.thread_axis("threadIdx.y"))
s[Z].bind(i_outer, tvm.thread_axis("blockIdx.y"))
s[Z].bind(i_inner, tvm.thread_axis("threadIdx.z"))
tx = tvm.thread_axis("threadIdx.x")
s[Z].bind(s[Z].op.reduce_axis[0], tx)
s[ZF].compute_at(s[Z], s[Z].op.reduce_axis[0])
s[Z].set_store_predicate(tx.var.equal(0))
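        # Resulting thread layout: the output's last dimension is split across (blockIdx.x, threadIdx.y),
        # the sequence dimension across (blockIdx.y, threadIdx.z), the reduction is rfactor-ed and done
        # cooperatively along threadIdx.x, and only threads with threadIdx.x == 0 store the final result.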
print('Lowering with GPU splits: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))
# compiling the automatically generated cuda code
diagonaled_mm = tvm.build(s, [X, Y, Z, D, w, w_upper, padding, transpose_t1, t3d3], target=device, target_host=tgt_host, name='diagonaled_mm')
return diagonaled_mm
@staticmethod
def _get_lib_filename(dtype: str, device: str):
base_filename = 'longformer/lib/lib_diagonaled_mm'
return '{}_{}_{}.so'.format(base_filename, dtype, device)
@staticmethod
def _save_compiled_function(f, dtype: str, device: str):
if not os.path.exists('longformer/lib/'):
os.makedirs('longformer/lib/')
f.export_library(DiagonaledMM._get_lib_filename(dtype, device))
@staticmethod
def _load_compiled_function(dtype: str, device: str):
from tvm.module import load # this can be the small runtime python library, and doesn't need to be the whole thing
filename = DiagonaledMM._get_lib_filename(dtype, device)
current_dir = os.path.dirname(os.path.abspath(__file__))
potential_dirs = ['../../', '../', './', f'{current_dir}/', f'{current_dir}/../']
for potential_dir in potential_dirs:
filepath = '{}{}'.format(potential_dir, filename)
if os.path.isfile(filepath):
print('Loading tvm binary from: {}'.format(filepath))
return load(filepath)
return None
@staticmethod
def _get_function(dtype: str, device: str):
        '''Loads the function from disk, or compiles it from scratch if not found'''
# A list of arguments that define the function
args = (dtype, device)
if args not in DiagonaledMM.function_dict:
diagonaled_mm = DiagonaledMM._load_compiled_function(dtype, device) # try to load from disk
if not diagonaled_mm:
print('Tvm binary not found. Compiling ...')
diagonaled_mm = DiagonaledMM._compile_function(dtype, device) # compile
DiagonaledMM._save_compiled_function(diagonaled_mm, dtype, device) # save to disk
# convert the tvm function into a pytorch function
from tvm.contrib import dlpack
diagonaled_mm_pytorch = dlpack.to_pytorch_func(diagonaled_mm) # wrap it as a pytorch function
# save the function into a dictionary to be reused
DiagonaledMM.function_dict[args] = diagonaled_mm_pytorch # save it in a dictionary for next time
return DiagonaledMM.function_dict[args]
@staticmethod
def _diagonaled_mm(t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int],
is_t1_diagonaled: bool = False, transpose_t1: bool = False, padding: int = 0,
autoregressive: bool = False):
'''Calls the compiled function after checking the input format. This function is called in three different modes.
t1 x t2 = r ==> t1 and t2 are not diagonaled, but r is. Useful for query x key = attention_scores
        t1 x t2 = r ==> t1 is diagonaled, but t2 and r are not. Useful to compute attention_scores x value = context
t1 x t2 = r ==> t1 is diagonaled and it should be transposed, but t2 and r are not diagonaled. Useful in some of
the calculations in the backward pass.
'''
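        # Shape sketch of the three modes, where c = w_upper + w + 1 is the number of diagonals
        # (2 * w + 1 when not autoregressive):
        #   (b, n, h, m) x (b, n, h, m) -> (b, n, h, c)   e.g. query x key -> attention_scores
        #   (b, n, h, c) x (b, n, h, m) -> (b, n, h, m)   e.g. attention_probs x value -> context
        #   (b, n, h, c) x (b, n, h, m) -> (b, n, h, m)   as above, but t1 is transposed (backward pass)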
dtype = str(t1.dtype).split('.')[1]
device = t1.device.type
assert len(t1.shape) == 4
assert len(t1.shape) == len(t2.shape)
assert t1.shape[:3] == t2.shape[:3]
        if isinstance(d, int):
            # if d is an integer, replace it with a tensor of the same length as the number of heads,
            # filled with the same dilation value
d = t1.new_full(size=(t1.shape[2],), fill_value=d, dtype=torch.int, requires_grad=False)
assert len(d.shape) == 1
        assert d.shape[0] == t1.shape[2] # number of dilation values should match number of heads
b = t1.shape[0] # batch size
n = t1.shape[1] # sequence length
h = t1.shape[2] # number of heads
m = t2.shape[3] # hidden dimension
w_upper = 0 if autoregressive else w
c = w_upper + w + 1 # number of diagonals
if is_t1_diagonaled:
assert t1.shape[3] == c
            r = t1.new_empty(b, n, h, m) # allocate space for the result tensor
else:
assert not transpose_t1
assert t1.shape[3] == m
            r = t1.new_empty(b, n, h, c) # allocate space for the result tensor
# gets function from memory, from disk or compiles it from scratch
_diagonaled_mm_function = DiagonaledMM._get_function(dtype=dtype, device=device)
# The last argument to this function is a little hacky. It is the size of the last dimension of the result tensor
# We use it as a proxy to tell if t1_is_diagonaled or not (if t1 is diagonaled, result is not, and vice versa).
# The second reason is that the lambda expression in `_compile_function` is easier to express when the shape
# of the output is known
# This functions computes diagonal_mm then saves the result in `r`
if m == c:
# FIXME
            print(f"Error: the hidden dimension {m} shouldn't match the number of diagonals {c}")
assert False
_diagonaled_mm_function(t1, t2, r, d, w, w_upper, padding, transpose_t1, m if is_t1_diagonaled else c)
return r
@staticmethod
def _prepare_tensors(t):
'''Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
For a tensor t, if t.size(0) == 1, then the value of t.stride()[0] doesn't matter.
        TVM expects this value to be the `product(t.size()[1:])` but PyTorch sometimes sets it to `t.stride()[1]`.
        Here's an example to reproduce this issue:
import torch
print(torch.randn(1, 10).stride())
> (10, 1)
print(torch.randn(10, 1).t().contiguous().stride())
> (1, 1) # expected it to be (10, 1) as above
print(torch.randn(10, 2).t().contiguous().stride())
> (10, 1) # but gets the expected stride if the first dimension is > 1
'''
assert t.is_contiguous()
t_stride = list(t.stride())
t_size = list(t.size())
        # Fix wrong stride information for the first dimension. This occurs when batch_size=1
if t_size[0] == 1 and t_stride[0] == t_stride[1]:
# In this case, the stride of the first dimension should be the product
# of the sizes of all other dimensions
t_stride[0] = t_size[1] * t_size[2] * t_size[3]
t = t.as_strided(size=t_size, stride=t_stride)
return t
min_seq_len = 16 # unexpected output if seq_len < 16
@staticmethod
def forward(ctx, t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int], is_t1_diagonaled: bool = False, padding: int = 0, autoregressive: bool = False) -> torch.Tensor:
        '''Computes diagonal_mm of t1 and t2.
args:
t1: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals).
t1 can be a regular tensor (e.g. `query_layer`) or a diagonaled one (e.g. `attention_scores`)
t2: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size). This is always a non-diagonaled
tensor, e.g. `key_layer` or `value_layer`
w: int = window size; number of attentions on each side of the word
d: torch.Tensor or int = dilation of attentions per attention head. If int, the same dilation value will be used for all
                heads. If torch.Tensor, it should be 1D of length = number of attention heads
is_t1_diagonaled: is t1 a diagonaled or a regular tensor
padding: the padding value to use when accessing invalid locations. This is mainly useful when the padding
needs to be a very large negative value (to compute softmax of attentions). For other usecases,
please use zero padding.
autoregressive: if true, return only the lower triangle
returns: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals)
            if t1 is diagonaled, the result is non-diagonaled, and vice versa
'''
batch_size, seq_len, num_attention_heads, hidden_size = t1.size()
assert seq_len >= DiagonaledMM.min_seq_len, 'avoid splitting errors by using seq_len >= {}'.format(DiagonaledMM.min_seq_len) # FIXME
ctx.save_for_backward(t1, t2)
ctx.w = w
ctx.d = d
ctx.is_t1_diagonaled = is_t1_diagonaled
ctx.autoregressive = autoregressive
t1 = DiagonaledMM._prepare_tensors(t1)
t2 = DiagonaledMM._prepare_tensors(t2)
# output = t1.mm(t2) # what would have been called if this was a regular matmul
output = DiagonaledMM._diagonaled_mm(t1, t2, w, d, is_t1_diagonaled=is_t1_diagonaled, padding=padding, autoregressive=autoregressive)
return output
@staticmethod
def backward(ctx, grad_output):
t1, t2 = ctx.saved_tensors
w = ctx.w
d = ctx.d
is_t1_diagonaled = ctx.is_t1_diagonaled
autoregressive = ctx.autoregressive
if not grad_output.is_contiguous():
grad_output = grad_output.contiguous() # tvm requires all input tensors to be contiguous
grad_output = DiagonaledMM._prepare_tensors(grad_output)
t1 = DiagonaledMM._prepare_tensors(t1)
t2 = DiagonaledMM._prepare_tensors(t2)
# http://cs231n.github.io/optimization-2/
# https://pytorch.org/docs/master/notes/extending.html
# grad_t1 = grad_output.mm(t2) # what would have been called if this was a regular matmul
grad_t1 = DiagonaledMM._diagonaled_mm(grad_output, t2, w, d, is_t1_diagonaled=not is_t1_diagonaled, autoregressive=autoregressive)
# grad_t2 = grad_output.t().mm(t1) # or `grad_t2 = t1.t().mm(grad_output).t()` because `(AB)^T = B^TA^T`
if is_t1_diagonaled:
grad_t2 = DiagonaledMM._diagonaled_mm(t1, grad_output, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
else:
grad_t2 = DiagonaledMM._diagonaled_mm(grad_output, t1, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
return grad_t1, grad_t2, None, None, None, None, None
def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
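    '''Build a (seq_len, w + 1) uint8 mask. For each of the w + 1 left-side diagonals (offsets
    -w*d .. 0), it marks the positions near the beginning of the sequence whose attention target
    would fall before position 0, i.e. out of range.'''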
diagonals_list = []
for j in range(-d * w, d, d):
diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
diagonal_mask[:-j] = 1
diagonals_list.append(diagonal_mask)
return torch.stack(diagonals_list, dim=-1)
@lru_cache()
def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):
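    '''Build (and cache) boolean masks of out-of-range attention positions. Returns a tuple
    (affected_seq_len, beginning_mask, ending_mask): `affected_seq_len = w * max(d)` is how many
    positions at each end of the sequence are affected, `beginning_mask` covers the start of the
    sequence, and `ending_mask` is its flipped counterpart for the end (None if autoregressive).'''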
if isinstance(d, int):
affected_seq_len = w * d
mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
mask = mask[None, :, None, :]
else:
affected_seq_len = w * d.max()
head_masks = []
d_list = d.cpu().numpy().tolist()
for d in d_list:
one_head_mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
head_masks.append(one_head_mask)
mask = torch.stack(head_masks, dim=-2)
mask = mask[None, :, :, :]
ending_mask = None if autoregressive else mask.flip(dims=(1, 3)).bool().to(device)
return affected_seq_len, mask.bool().to(device), ending_mask
def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> None:
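    '''Fill, in place, the entries of `input_tensor` (batch_size, seq_len, num_heads, 2*w + 1) that
    correspond to attention positions outside the sequence with -inf. Only the first (and, unless
    autoregressive, the last) `w * max(d)` positions of the sequence are affected.'''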
affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device)
seq_len = input_tensor.size(1)
beginning_input = input_tensor[:, :affected_seq_len, :, :w+1]
beginning_mask = beginning_mask[:, :seq_len].expand(beginning_input.size())
beginning_input.masked_fill_(beginning_mask, -float('inf'))
if not autoregressive:
ending_input = input_tensor[:, -affected_seq_len:, :, -(w+1):]
ending_mask = ending_mask[:, -seq_len:].expand(ending_input.size())
ending_input.masked_fill_(ending_mask, -float('inf'))
diagonaled_mm = DiagonaledMM.apply
# The non-tvm implementation is the default, so there is no need to load the tvm kernel at import time.
# DiagonaledMM._get_function('float32', 'cuda')
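# Minimal usage sketch (illustrative only): it assumes a CUDA device and a working TVM build or a
# pre-compiled kernel binary; the window size, dilation, and tensor shapes below are arbitrary.
if __name__ == '__main__':
    w, d = 4, 1 # window size and dilation
    b, n, h, m = 2, 64, 2, 8 # batch, seq_len (>= min_seq_len), heads, head dim (note m != 2 * w + 1)
    query = torch.randn(b, n, h, m, device='cuda')
    key = torch.randn(b, n, h, m, device='cuda')
    value = torch.randn(b, n, h, m, device='cuda')
    # query x key -> diagonaled attention scores of shape (b, n, h, 2 * w + 1)
    scores = diagonaled_mm(query, key, w, d, False, 0, False)
    mask_invalid_locations(scores, w, d, False) # in place: -inf at out-of-range positions
    probs = scores.softmax(dim=-1)
    # diagonaled probs x value -> context of shape (b, n, h, m)
    context = diagonaled_mm(probs, value, w, d, True, 0, False)
    print(context.shape)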