import torch


def attn_ref(q, k, v, b, sm_scale, dropout_p=0.0, causal=False, upcast=False):
    """Reference (eager PyTorch) attention: softmax(q @ k^T * sm_scale + b) @ v.

    Tensors are laid out as (batch, heads, seq_len, head_dim). `b` is an
    optional additive attention bias; `upcast` runs the computation in
    float32 for a higher-precision reference.
    """
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
        if b is not None:
            b = b.float()
    if b is not None:
        # Broadcast a bias with singleton batch/head dims up to the full
        # (batch, heads, q_len, k_len) shape.
        if (b.shape[0] != q.shape[0]) or (b.shape[1] != q.shape[1]):
            b = b.expand(q.shape[0], q.shape[1], q.shape[2], k.shape[2])
    # Query (row) and key (column) position indices for the causal mask.
    ms = torch.arange(q.shape[2], device=q.device).unsqueeze(-1)
    ns = torch.arange(k.shape[2], device=q.device)
    p = torch.matmul(q, k.transpose(2, 3))
    p *= sm_scale
    if b is not None:
        p += b
    if causal:
        # Bottom-right-aligned causal mask: query i attends to key j only
        # when i + (k_len - q_len) >= j, so the last query sees all keys.
        p = torch.where(ms + k.shape[2] - q.shape[2] >= ns, p, float("-inf"))
    # Softmax in float32 for numerical stability, then cast back.
    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
    if dropout_p > 0.0:
        p = torch.dropout(p, dropout_p, train=True)
    ref_out = torch.matmul(p, v)
    return ref_out
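
# Usage sketch (not part of the original file; shapes and seed are assumptions
# for illustration): exercise attn_ref with random inputs in the standard
# (batch, heads, seq_len, head_dim) layout, with a bias that broadcasts over
# batch and heads.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, heads, n_q, n_k, head_dim = 2, 4, 128, 128, 64
    q = torch.randn(batch, heads, n_q, head_dim)
    k = torch.randn(batch, heads, n_k, head_dim)
    v = torch.randn(batch, heads, n_k, head_dim)
    # Singleton batch/head dims trigger the expand() path inside attn_ref.
    b = torch.randn(1, 1, n_q, n_k)
    out = attn_ref(q, k, v, b, sm_scale=head_dim ** -0.5, causal=True)
    print(out.shape)  # torch.Size([2, 4, 128, 64])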