omarmomen's picture
added notebook for generation
340c8dd
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This code is from AllenAI's Longformer:
https://github.com/allenai/longformer/
"""
from typing import Union
from functools import lru_cache
import torch
import os.path
class DiagonaledMM(torch.autograd.Function):
'''Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling
this function from PyTorch
'''
function_dict = {} # save a list of functions, each has a different set of parameters
@staticmethod
def _compile_function(dtype: str, device: str, b0: int = 4, b1: int = 4, b2: int = 16):
'''Compiles a tvm function that computes diagonal_mm
args:
dtype: str in ['float64', 'float32', 'float16']
device: str in ['cpu' or 'cuda']
b0, b1, b2: size of tensor tiles. Very important for good performance
'''
import tvm # import the full tvm library here for compilation. Don't import at the top of the file in case we don't need to compile
from tvm.contrib import nvcc
@tvm.register_func
def tvm_callback_cuda_compile(code):
"""Use nvcc compiler for better perf."""
ptx = nvcc.compile_cuda(code, target="ptx", arch='sm_52') # use old arch for this to work on old GPUs
return ptx
assert dtype in ['float16', 'float32', 'float64']
assert device in ['cpu', 'cuda']
device = None if device == 'cpu' else device
tgt_host="llvm"
b = tvm.var('b') # batch size
n = tvm.var('n') # sequence length
h = tvm.var('h') # number of heads
m = tvm.var('m') # hidden dimension
w = tvm.var('w') # window size
w_upper = tvm.var('w_upper') # window size to the right of the word. Should be `0` or `w`
padding = tvm.var('padding') # padding
transpose_t1 = tvm.var('transpose_t1') # t1 should be transposed
t1d3 = tvm.var('t1d3') # last dimension of t1
t3d3 = tvm.var('t3d3') # last dimension of t3 (the result tensor)
X = tvm.placeholder((b, n, h, t1d3), name='X', dtype=dtype) # first tensor
Y = tvm.placeholder((b, n, h, m), name='Y', dtype=dtype) # second tensor
k = tvm.reduce_axis((0, t1d3), name='k') # dimension to sum over
D = tvm.placeholder((h), name='D', dtype='int') # dilation per head
output_shape = (b, n, h, t3d3) # shape of the result tensor
algorithm = lambda l, i, q, j: tvm.sum(
tvm.if_then_else(
t3d3 == m, # if output dimension == m, then t1 is diagonaled (FIXME: This breaks if t3d3 == m == t1d3)
tvm.if_then_else(
transpose_t1 == 0,
tvm.if_then_else(
tvm.all(
i + D[q] * (k - w) >= 0,
i + D[q] * (k - w) < n,
),
X[l, i, q, k] * Y[l, i + D[q] * (k - w), q, j], # t1 is diagonaled
padding
),
tvm.if_then_else(
tvm.all(
i + D[q] * (k - w_upper) >= 0, # `w_upper` to handle the case `autoregressive=True`
i + D[q] * (k - w_upper) < n,
),
X[l, i + D[q] * (k - w_upper), q, (w_upper + w) - k] * Y[l, i + D[q] * (k - w_upper), q, j], # # t1 is diagonaled and should be transposed
padding
),
),
tvm.if_then_else(
tvm.all(
i + D[q] * (j - w) >= 0,
i + D[q] * (j - w) < n,
),
X[l, i, q, k] * Y[l, i + D[q] * (j - w), q, k], # t1 is not diagonaled, but the output tensor is going to be
padding
)
), axis=k)
Z = tvm.compute(output_shape, algorithm, name='Z') # automatically generate cuda code
s = tvm.create_schedule(Z.op)
print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))
# split long axis into smaller chunks and assing each one to a separate GPU thread/block
ko, ki = s[Z].split(Z.op.reduce_axis[0], factor=b0)
ZF = s.rfactor(Z, ki)
j_outer, j_inner = s[Z].split(s[Z].op.axis[-1], factor=b1)
i_outer, i_inner = s[Z].split(s[Z].op.axis[1], factor=b2)
s[Z].bind(j_outer, tvm.thread_axis("blockIdx.x"))
s[Z].bind(j_inner, tvm.thread_axis("threadIdx.y"))
s[Z].bind(i_outer, tvm.thread_axis("blockIdx.y"))
s[Z].bind(i_inner, tvm.thread_axis("threadIdx.z"))
tx = tvm.thread_axis("threadIdx.x")
s[Z].bind(s[Z].op.reduce_axis[0], tx)
s[ZF].compute_at(s[Z], s[Z].op.reduce_axis[0])
s[Z].set_store_predicate(tx.var.equal(0))
print('Lowering with GPU splits: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))
# compiling the automatically generated cuda code
diagonaled_mm = tvm.build(s, [X, Y, Z, D, w, w_upper, padding, transpose_t1, t3d3], target=device, target_host=tgt_host, name='diagonaled_mm')
return diagonaled_mm
@staticmethod
def _get_lib_filename(dtype: str, device: str):
base_filename = 'longformer/lib/lib_diagonaled_mm'
return '{}_{}_{}.so'.format(base_filename, dtype, device)
@staticmethod
def _save_compiled_function(f, dtype: str, device: str):
if not os.path.exists('longformer/lib/'):
os.makedirs('longformer/lib/')
f.export_library(DiagonaledMM._get_lib_filename(dtype, device))
@staticmethod
def _load_compiled_function(dtype: str, device: str):
from tvm.module import load # this can be the small runtime python library, and doesn't need to be the whole thing
filename = DiagonaledMM._get_lib_filename(dtype, device)
current_dir = os.path.dirname(os.path.abspath(__file__))
potential_dirs = ['../../', '../', './', f'{current_dir}/', f'{current_dir}/../']
for potential_dir in potential_dirs:
filepath = '{}{}'.format(potential_dir, filename)
if os.path.isfile(filepath):
print('Loading tvm binary from: {}'.format(filepath))
return load(filepath)
return None
@staticmethod
def _get_function(dtype: str, device: str):
'''Loads the function from the disk or compile it'''
# A list of arguments that define the function
args = (dtype, device)
if args not in DiagonaledMM.function_dict:
diagonaled_mm = DiagonaledMM._load_compiled_function(dtype, device) # try to load from disk
if not diagonaled_mm:
print('Tvm binary not found. Compiling ...')
diagonaled_mm = DiagonaledMM._compile_function(dtype, device) # compile
DiagonaledMM._save_compiled_function(diagonaled_mm, dtype, device) # save to disk
# convert the tvm function into a pytorch function
from tvm.contrib import dlpack
diagonaled_mm_pytorch = dlpack.to_pytorch_func(diagonaled_mm) # wrap it as a pytorch function
# save the function into a dictionary to be reused
DiagonaledMM.function_dict[args] = diagonaled_mm_pytorch # save it in a dictionary for next time
return DiagonaledMM.function_dict[args]
@staticmethod
def _diagonaled_mm(t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int],
is_t1_diagonaled: bool = False, transpose_t1: bool = False, padding: int = 0,
autoregressive: bool = False):
'''Calls the compiled function after checking the input format. This function is called in three different modes.
t1 x t2 = r ==> t1 and t2 are not diagonaled, but r is. Useful for query x key = attention_scores
t1 x t2 = r ==> t1 is diagonaled, but t2 and r are not. Useful to compuate attantion_scores x value = context
t1 x t2 = r ==> t1 is diagonaled and it should be transposed, but t2 and r are not diagonaled. Useful in some of
the calculations in the backward pass.
'''
dtype = str(t1.dtype).split('.')[1]
device = t1.device.type
assert len(t1.shape) == 4
assert len(t1.shape) == len(t2.shape)
assert t1.shape[:3] == t2.shape[:3]
if isinstance(d, int): # if d is an integer, replace it with a tensor of the same length
# as number of heads, and it is filled with the same dilation value
d = t1.new_full(size=(t1.shape[2],), fill_value=d, dtype=torch.int, requires_grad=False)
assert len(d.shape) == 1
assert d.shape[0] == t1.shape[2] # number of dilation scores should match number of heads
b = t1.shape[0] # batch size
n = t1.shape[1] # sequence length
h = t1.shape[2] # number of heads
m = t2.shape[3] # hidden dimension
w_upper = 0 if autoregressive else w
c = w_upper + w + 1 # number of diagonals
if is_t1_diagonaled:
assert t1.shape[3] == c
r = t1.new_empty(b, n, h, m) # allocate spase for the result tensor
else:
assert not transpose_t1
assert t1.shape[3] == m
r = t1.new_empty(b, n, h, c) # allocate spase for the result tensor
# gets function from memory, from disk or compiles it from scratch
_diagonaled_mm_function = DiagonaledMM._get_function(dtype=dtype, device=device)
# The last argument to this function is a little hacky. It is the size of the last dimension of the result tensor
# We use it as a proxy to tell if t1_is_diagonaled or not (if t1 is diagonaled, result is not, and vice versa).
# The second reason is that the lambda expression in `_compile_function` is easier to express when the shape
# of the output is known
# This functions computes diagonal_mm then saves the result in `r`
if m == c:
# FIXME
print('Error: the hidden dimension {m} shouldn\'t match number of diagonals {c}')
assert False
_diagonaled_mm_function(t1, t2, r, d, w, w_upper, padding, transpose_t1, m if is_t1_diagonaled else c)
return r
@staticmethod
def _prepare_tensors(t):
'''Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
For a tensor t, if t.size(0) == 1, then the value of t.stride()[0] doesn't matter.
TVM expects this value to be the `product(t.size()[1:])` but PyTorch some times sets it to `t.stride()[1]`.
Here's an example to reporduce this issue:
import torch
print(torch.randn(1, 10).stride())
> (10, 1)
print(torch.randn(10, 1).t().contiguous().stride())
> (1, 1) # expected it to be (10, 1) as above
print(torch.randn(10, 2).t().contiguous().stride())
> (10, 1) # but gets the expected stride if the first dimension is > 1
'''
assert t.is_contiguous()
t_stride = list(t.stride())
t_size = list(t.size())
# Fix wrong stride information for the first dimension. This occures when batch_size=1
if t_size[0] == 1 and t_stride[0] == t_stride[1]:
# In this case, the stride of the first dimension should be the product
# of the sizes of all other dimensions
t_stride[0] = t_size[1] * t_size[2] * t_size[3]
t = t.as_strided(size=t_size, stride=t_stride)
return t
min_seq_len = 16 # unexpected output if seq_len < 16
@staticmethod
def forward(ctx, t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor,int], is_t1_diagonaled: bool = False, padding: int = 0, autoregressive: bool = False) -> torch.Tensor:
'''Compuates diagonal_mm of t1 and t2.
args:
t1: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals).
t1 can be a regular tensor (e.g. `query_layer`) or a diagonaled one (e.g. `attention_scores`)
t2: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size). This is always a non-diagonaled
tensor, e.g. `key_layer` or `value_layer`
w: int = window size; number of attentions on each side of the word
d: torch.Tensor or int = dilation of attentions per attention head. If int, the same dilation value will be used for all
heads. If torch.Tensor, it should be 1D of lenth=number of attention heads
is_t1_diagonaled: is t1 a diagonaled or a regular tensor
padding: the padding value to use when accessing invalid locations. This is mainly useful when the padding
needs to be a very large negative value (to compute softmax of attentions). For other usecases,
please use zero padding.
autoregressive: if true, return only the lower triangle
returns: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals)
if t1 is diagonaed, result is non-diagonaled, and vice versa
'''
batch_size, seq_len, num_attention_heads, hidden_size = t1.size()
assert seq_len >= DiagonaledMM.min_seq_len, 'avoid splitting errors by using seq_len >= {}'.format(DiagonaledMM.min_seq_len) # FIXME
ctx.save_for_backward(t1, t2)
ctx.w = w
ctx.d = d
ctx.is_t1_diagonaled = is_t1_diagonaled
ctx.autoregressive = autoregressive
t1 = DiagonaledMM._prepare_tensors(t1)
t2 = DiagonaledMM._prepare_tensors(t2)
# output = t1.mm(t2) # what would have been called if this was a regular matmul
output = DiagonaledMM._diagonaled_mm(t1, t2, w, d, is_t1_diagonaled=is_t1_diagonaled, padding=padding, autoregressive=autoregressive)
return output
@staticmethod
def backward(ctx, grad_output):
t1, t2 = ctx.saved_tensors
w = ctx.w
d = ctx.d
is_t1_diagonaled = ctx.is_t1_diagonaled
autoregressive = ctx.autoregressive
if not grad_output.is_contiguous():
grad_output = grad_output.contiguous() # tvm requires all input tensors to be contiguous
grad_output = DiagonaledMM._prepare_tensors(grad_output)
t1 = DiagonaledMM._prepare_tensors(t1)
t2 = DiagonaledMM._prepare_tensors(t2)
# http://cs231n.github.io/optimization-2/
# https://pytorch.org/docs/master/notes/extending.html
# grad_t1 = grad_output.mm(t2) # what would have been called if this was a regular matmul
grad_t1 = DiagonaledMM._diagonaled_mm(grad_output, t2, w, d, is_t1_diagonaled=not is_t1_diagonaled, autoregressive=autoregressive)
# grad_t2 = grad_output.t().mm(t1) # or `grad_t2 = t1.t().mm(grad_output).t()` because `(AB)^T = B^TA^T`
if is_t1_diagonaled:
grad_t2 = DiagonaledMM._diagonaled_mm(t1, grad_output, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
else:
grad_t2 = DiagonaledMM._diagonaled_mm(grad_output, t1, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
return grad_t1, grad_t2, None, None, None, None, None
def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
diagonals_list = []
for j in range(-d * w, d, d):
diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
diagonal_mask[:-j] = 1
diagonals_list.append(diagonal_mask)
return torch.stack(diagonals_list, dim=-1)
@lru_cache()
def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor,int], autoregressive: bool, device: str):
if isinstance(d, int):
affected_seq_len = w * d
mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
mask = mask[None, :, None, :]
else:
affected_seq_len = w * d.max()
head_masks = []
d_list = d.cpu().numpy().tolist()
for d in d_list:
one_head_mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
head_masks.append(one_head_mask)
mask = torch.stack(head_masks, dim=-2)
mask = mask[None, :, :, :]
ending_mask = None if autoregressive else mask.flip(dims=(1, 3)).bool().to(device)
return affected_seq_len, mask.bool().to(device), ending_mask
def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> torch.Tensor:
affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device)
seq_len = input_tensor.size(1)
beginning_input = input_tensor[:, :affected_seq_len, :, :w+1]
beginning_mask = beginning_mask[:, :seq_len].expand(beginning_input.size())
beginning_input.masked_fill_(beginning_mask, -float('inf'))
if not autoregressive:
ending_input = input_tensor[:, -affected_seq_len:, :, -(w+1):]
ending_mask = ending_mask[:, -seq_len:].expand(ending_input.size())
ending_input.masked_fill_(ending_mask, -float('inf'))
diagonaled_mm = DiagonaledMM.apply
# The non-tvm implementation is the default, we don't need to load the kernel at loading time.
# DiagonaledMM._get_function('float32', 'cuda')