File size: 868 Bytes
f225bf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
def attn_ref(q, k, v, b, sm_scale, dropout_p=0.0, causal=False, upcast=False):
    """Unfused scaled dot-product attention built from plain PyTorch ops.

    Computes ``dropout(softmax(q @ k^T * sm_scale + b, causal-masked)) @ v``.

    Args:
        q, k, v: Query, key and value tensors. They are indexed as 4-D
            (``k.transpose(2, 3)``, ``shape[2]`` used as sequence length);
            the layout is presumably (batch, heads, seq, head_dim) -- confirm
            against callers.
        b: Optional additive bias on the attention scores, or ``None``. If
            its first two dims differ from ``q``'s, it is expanded to
            ``(q.shape[0], q.shape[1], q.shape[2], k.shape[2])``.
        sm_scale: Multiplier applied to ``q @ k^T`` before the bias is added.
        dropout_p: Dropout probability on the attention weights. When it is
            positive, dropout is always applied in training mode.
        causal: If true, query ``i`` may only attend to keys
            ``j <= i + (k_len - q_len)``. The mask is aligned to the
            bottom-right corner when the lengths differ. If ``q_len > k_len``,
            the first ``q_len - k_len`` rows are fully masked, and their
            softmax (and output) is NaN.
        upcast: If true, ``q``, ``k``, ``v`` (and ``b``) are cast to float32
            first. The output is then float32 too.

    Returns:
        The attention output ``probs @ v``. The softmax itself always runs in
        float32 and is cast back to ``q.dtype`` before the final matmul.
    """
    if upcast:
        q, k, v = (t.float() for t in (q, k, v))
        if b is not None:
            b = b.float()

    q_len, k_len = q.shape[2], k.shape[2]

    needs_expand = b is not None and (
        b.shape[0] != q.shape[0] or b.shape[1] != q.shape[1]
    )
    if needs_expand:
        b = b.expand(q.shape[0], q.shape[1], q_len, k_len)

    scores = torch.matmul(q, k.transpose(2, 3))
    # In-place updates keep ``scores`` in its own dtype even when ``b`` is a
    # wider float type (an out-of-place add would promote it).
    scores *= sm_scale
    if b is not None:
        scores += b

    if causal:
        rows = torch.arange(q_len, device=q.device)[:, None]
        cols = torch.arange(k_len, device=q.device)
        # Bottom-right aligned: key j is visible to query i iff j <= i + (k_len - q_len).
        visible = rows + (k_len - q_len) >= cols
        scores = scores.masked_fill(~visible, float("-inf"))

    probs = scores.float().softmax(dim=-1).to(q.dtype)
    if dropout_p > 0.0:
        probs = torch.dropout(probs, dropout_p, train=True)
    return probs @ v
|