import torch

from triton_flash_atn import _attention

# Define dimensions
batch_size = 2
num_heads = 4
seq_len = 128
head_dim = 64

# Create random input tensors for Q, K, V
q = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, num_heads, seq_len, head_dim,
                dtype=torch.float16, device='cuda')

# Define whether the attention is causal and the softmax scaling factor
causal = False
sm_scale = 1.0 / (head_dim ** 0.5)

# Apply flash attention; output has shape (batch_size, num_heads, seq_len, head_dim)
attention = _attention.apply
output = attention(q, k, v, causal, sm_scale)
print(output)
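
# Optional sanity check (a sketch, not part of the original example): compare the
# Triton kernel's result against PyTorch's built-in scaled_dot_product_attention
# (available in PyTorch >= 2.0). Since sm_scale here equals the default
# 1/sqrt(head_dim) scaling, the two outputs should agree up to fp16 tolerance.
import torch.nn.functional as F

reference = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
print(torch.allclose(output, reference, atol=1e-2, rtol=0.0))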