# simplified_phi2/attention.py
from dataclasses import dataclass, field
from einops import rearrange, repeat
import math
import torch
from torch.amp.autocast_mode import autocast
import torch.nn as nn
import torch.utils.checkpoint  # used when checkpointing=True in MHA
from transformers.activations import ACT2FN
from typing import cast
# use flash_attn's fused kernels if the package is installed, otherwise fall back to the pure-PyTorch implementations below
try:
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
print("flash_attn not found, using default implementations")
    pad_input = unpad_input = None
    FlashRotaryEmbedding = FlashCrossAttention = FlashSelfAttention = FusedDense = None
class RotaryEmbedding(nn.Module):
"""Rotary positional embedding (RoPE). See https://www.youtube.com/watch?v=C6rV8BsrrCc"""
def __init__(
self,
d_rotary: int,
rotary_base: float = 10000.0,
initial_cos_sin_cache_len: int = 2048,
device: torch.device | None = None,
) -> None:
super().__init__()
self.d_rotary = d_rotary
self.rotary_base = rotary_base
self.device = device
self.dtype = torch.float32
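        # pre-compute cos/sin for positions [0, initial_cos_sin_cache_len); forward() grows
        # the cache on demand if a longer sequence (or a larger seqlen_offset) is seen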
self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len)
def _update_cos_sin_cache(
self,
seqlen: int,
        device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
# only call this function when seqlen is larger than _max_seqlen
self._max_seqlen = seqlen
# m * theta_i = m * base^(-2i/d) = m * (1 / base^(2i/d)), where i in [1, d/2]
m = torch.arange(
seqlen,
device=device,
dtype=torch.float32,
)
theta_i = 1.0 / (
self.rotary_base ** (
torch.arange(
start=0,
end=self.d_rotary,
step=2,
device=device,
dtype=torch.float32,
) / self.d_rotary
)
)
# torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
# TODO: does this matter if I'm disabling torch.autocast?
m_theta_i = torch.outer(m, theta_i)
self._cos_cached = torch.cos(m_theta_i).to(dtype)
self._sin_cached = torch.sin(m_theta_i).to(dtype)
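        # both caches have shape (seqlen, d_rotary/2): one row per position, one column per channel pair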
        # TODO: scale_base handling is labelled as not yet implemented in Phi2; kept here, commented out, for reference
"""
if scale_base is not None:
scale = (
torch.arange(
start=0,
end=self.d_rotary,
step=2,
device=self.device,
dtype=torch.float32,
) + 0.4 * self.d_rotary
) / (1.4 * self.d_rotary)
power = (
torch.arange(seqlen, dtype=scale.dtype, device=scale.device) - seqlen // 2
) / scale_base
scale = scale.to(device=power.device) ** rearrange(power, "s -> s 1")
self._cos_cached = (torch.cos(m_theta_i) * scale).to(dtype)
self._sin_cached = (torch.sin(m_theta_i) * scale).to(dtype)
"""
def _apply_rotary_emb_qkv(
self,
x: torch.FloatTensor, # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head)
        cos: torch.FloatTensor,  # dim: (_max_seqlen, d_rotary/2)
        sin: torch.FloatTensor,  # dim: (_max_seqlen, d_rotary/2)
) -> torch.FloatTensor:
seqlen = x.shape[1]
x_to_rotate = x[..., :self.d_rotary]
x_to_keep_unrotated = x[..., self.d_rotary:]
x1, x2 = x_to_rotate.chunk(2, dim=-1) # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_rotary/2)
broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d"
c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange)
x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] # make sure rotary embedding is in float32
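        # apply the 2-D rotation [[c, -s], [s, c]] to each (x1, x2) channel pair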
x_rotated = cast(
torch.FloatTensor,
torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype)
)
        return torch.cat([x_rotated, x_to_keep_unrotated], dim=-1)
def forward(
self,
x: torch.FloatTensor, # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head)
seqlen_offset: int = 0, # each sequence is shifted by this amount - used in inference with KV cache
) -> torch.FloatTensor:
if (
not self._max_seqlen
or self._max_seqlen < x.shape[1] + seqlen_offset
or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype
or (self.training and self._cos_cached.is_inference())
):
self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset, device=x.device, dtype=x.dtype)
return self._apply_rotary_emb_qkv(
x,
cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]),
cast(torch.FloatTensor, self._sin_cached[seqlen_offset:]),
)
class SelfAttention(nn.Module):
def __init__(
self,
qk_scale: float | None = None, # will use 1/sqrt(d) if set to None
attention_dropout: float = 0.0,
) -> None:
super().__init__()
self.qk_scale = qk_scale
self.dropout = nn.Dropout(attention_dropout)
# autocast is manually disabled to avoid `torch.einsum` using float16, which might lead to overflow
@autocast("cpu", enabled=False)
@autocast("cuda", enabled=False)
def forward(
self,
qkv: torch.FloatTensor, # dim: (batch_size, seqlen, 3, n_heads, d_head)
causal: bool = True,
key_padding_mask: torch.BoolTensor | None = None,
) -> torch.FloatTensor:
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
q, k, v = qkv.unbind(dim=2)
q = q.to(torch.float32)
k = k.to(torch.float32)
qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale)
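        # scores dim: (batch_size, n_heads, seqlen, seqlen)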
        if key_padding_mask is not None:
padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
padding_mask.masked_fill_(key_padding_mask, 0.0)
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
if causal:
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
scores = scores + causal_mask.to(dtype=scores.dtype)
attention = torch.softmax(scores, dim=-1).to(v.dtype)
attention = self.dropout(attention)
output = torch.einsum("bhts,bshd->bthd", attention, v) # dim: (batch_size, seqlen, n_heads, d_head)
return cast(torch.FloatTensor, output)
class CrossAttention(nn.Module):
def __init__(
self,
qk_scale: float | None = None, # will use 1/sqrt(d) if set to None
attention_dropout: float = 0.0,
) -> None:
super().__init__()
self.qk_scale = qk_scale
self.dropout = nn.Dropout(attention_dropout)
# autocast is manually disabled to avoid `torch.einsum` using float16, which might lead to overflow
@autocast("cpu", enabled=False)
@autocast("cuda", enabled=False)
def forward(
self,
q: torch.FloatTensor, # dim: (batch_size, seqlen_q, n_heads, d_head)
kv: torch.FloatTensor, # dim: (batch_size, seqlen_kv, 2, n_heads, d_head)
causal: bool = True,
key_padding_mask: torch.BoolTensor | None = None,
) -> torch.FloatTensor:
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
if kv.shape[3] != q.shape[2]: # repeat kv n_heads dim to match q n_heads
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
k, v = kv.unbind(dim=2)
q = cast(torch.FloatTensor, q.to(torch.float32))
k = k.to(torch.float32)
qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1])
scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale)
        if key_padding_mask is not None:
padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
padding_mask.masked_fill_(key_padding_mask, 0.0)
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
if causal:
rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
causal_mask = cols > rows + seqlen_k - seqlen_q
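            # the offset seqlen_k - seqlen_q lines the queries up with the *end* of the key
            # sequence, so cached keys from earlier decoding steps remain visible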
scores = scores.masked_fill(causal_mask, -10000.0)
attention = torch.softmax(scores, dim=-1).to(v.dtype)
attention = self.dropout(attention)
output = torch.einsum("bhts,bshd->bthd", attention, v) # dim: (batch_size, seqlen_q, n_heads, d_head)
return cast(torch.FloatTensor, output)
class MLP(nn.Module):
def __init__(
self,
d_embedding: int,
act_fn: str = "gelu_new",
) -> None:
super().__init__()
n_inner = 4 * d_embedding
self.fc1 = nn.Linear(d_embedding, n_inner)
self.act = ACT2FN[act_fn]
self.fc2 = nn.Linear(n_inner, d_embedding)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
@dataclass
class KVCache:
"""Options for model to calculate and store context during inference."""
    max_seqlen: int  # maximum sequence length the cache is allocated for
    max_batch_size: int  # maximum batch size the cache is allocated for
    seqlen_offset: int  # number of tokens already cached for each sequence
    batch_size_offset: int  # index of the first batch row to write into
    kv_block_map: dict[int, torch.Tensor] = field(default_factory=dict)  # block index -> cached KV tensor of shape (max_batch_size, max_seqlen, 2, n_heads, d_head)
    lengths_per_sample: torch.Tensor | None = None  # per-sample sequence lengths, if they differ
class MHA(nn.Module):
"""Multi-head attention block."""
def __init__(
self,
d_embedding: int,
n_attn_heads: int,
block_n: int,
initial_cos_sin_cache_len: int, # length of cache for rotary embedding
attn_pdrop: float,
use_flash_rotary: bool, # use flash rotary embedding if possible
use_flash_attn: bool, # use flash attention if possible
use_fused_dense: bool, # use fused dense layer if possible
checkpointing: bool, # torch.utils.checkpoint
) -> None:
super().__init__()
# rotary embedding
rotary_cls = (
FlashRotaryEmbedding
if use_flash_rotary and FlashRotaryEmbedding is not None
else RotaryEmbedding
)
self.rotary_emb = rotary_cls(
# d_rotary=math.ceil((d_embedding // n_attn_heads) / 2), # d_rotary is half of d_head
            d_rotary=32,  # Phi-2's config fixes rotary_dim at 32, i.e. RoPE is applied to only part of each head
initial_cos_sin_cache_len=initial_cos_sin_cache_len,
)
# self attention
self_attn_cls = (
FlashSelfAttention
if use_flash_attn and FlashSelfAttention is not None
else SelfAttention
)
self.inner_self_attn = self_attn_cls(attention_dropout=attn_pdrop)
# cross attention
cross_attn_cls = (
FlashCrossAttention
if use_flash_attn and FlashCrossAttention is not None
else CrossAttention
)
self.inner_cross_attn = cross_attn_cls(attention_dropout=attn_pdrop)
# MLP
self.n_attn_heads = n_attn_heads
self.d_head = d_embedding // n_attn_heads
linear_cls = (
FusedDense
if use_fused_dense and FusedDense is not None
else nn.Linear
)
self.Wqkv = linear_cls(
d_embedding,
self.d_head * (3 * self.n_attn_heads), # calculating q, k, v for all heads in block simultaneously
)
self.fc_out = linear_cls(d_embedding, d_embedding)
# settings
self.using_flash_attn = self_attn_cls is FlashSelfAttention
self.block_n = block_n
self.checkpointing = checkpointing
def _forward_self_attn(
self,
qkv: torch.FloatTensor, # dim: (batch_size, seqlen, 3, n_heads, d_head)
key_padding_mask: torch.BoolTensor | None,
) -> torch.FloatTensor:
qkv = cast(
torch.FloatTensor,
torch.cat(
[
self.rotary_emb(qkv[:, :, :2, :, :]), # qk
                    qkv[:, :, 2:, :, :],  # v (kept 5-D so it can be concatenated with the rotated qk along dim=2)
],
dim=2,
)
)
if self.using_flash_attn and unpad_input and pad_input: # not touching flash attention code
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
cu_seqlens, max_seqlen, indices = None, None, None
# unpad input and retrieve `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
            if key_padding_mask is not None:
qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)
if self.checkpointing:
attn_output = torch.utils.checkpoint.checkpoint(
self.inner_self_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
)
else:
attn_output = self.inner_self_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)
# repad output
            if key_padding_mask is not None:
return pad_input(attn_output, indices, batch_size, seqlen)
else:
return attn_output
if self.checkpointing:
return torch.utils.checkpoint.checkpoint(self.inner_self_attn, qkv, key_padding_mask=key_padding_mask)
else:
return self.inner_self_attn(qkv, key_padding_mask=key_padding_mask)
def _update_kv_cache(
self,
kv: torch.FloatTensor, # dim: (batch_size, seqlen, 2, n_heads, d_head)
kv_cache: KVCache,
block_n: int,
    ) -> torch.FloatTensor:
if block_n not in kv_cache.kv_block_map:
kv_cache.kv_block_map[block_n] = torch.empty(
kv_cache.max_batch_size,
kv_cache.max_seqlen,
2,
kv.shape[-2], # n_heads
kv.shape[-1], # d_head
dtype=kv.dtype,
device=kv.device,
)
batch_start = kv_cache.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = kv_cache.seqlen_offset
sequence_end = sequence_start + kv.shape[1]
        # if the incoming tokens would overflow the cache, grow it along the sequence dimension so the write below still fits
if sequence_end >= kv_cache.max_seqlen:
kv_cache.kv_block_map[block_n] = torch.concatenate(
(kv_cache.kv_block_map[block_n], kv),
dim=1,
)
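        # write the new keys/values into their slot, then hand back everything cached so far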
kv_cache.kv_block_map[block_n][
batch_start:batch_end,
sequence_start:sequence_end,
...
] = kv
kv = kv_cache.kv_block_map[block_n][
batch_start:batch_end,
:sequence_end,
...
]
return kv
def _forward_cross_attn(
self,
qkv: torch.FloatTensor, # dim: (batch_size, seqlen, 3, n_heads, d_head)
kv_cache: KVCache,
key_padding_mask: torch.BoolTensor | None,
) -> torch.FloatTensor:
qk = qkv[:, :, :2, :, :]
qk = self.rotary_emb(
qk,
            seqlen_offset=0 if kv_cache is None else kv_cache.seqlen_offset,
)
v = cast(torch.FloatTensor, qkv[:, :, 2, :, :])
q = qk[:, :, 0, :, :]
kv = torch.cat(
[
qk[:, :, 1, :, :].unsqueeze(2),
v.unsqueeze(2),
],
dim=2,
)
kv = self._update_kv_cache(kv, kv_cache, self.block_n)
causal = (kv_cache.seqlen_offset == 0)
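        # full causal masking is only needed during prefill; once tokens are cached, each new
        # query sits at the end of the sequence and may attend to every cached key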
if self.using_flash_attn and unpad_input and pad_input: # not touching flash attention code
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = kv.shape[1]
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, indices_q = (
None,
None,
None,
None,
None,
)
# unpad input and retrieve `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
            if key_padding_mask is not None:
kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)
if seqlen_q == 1:
key_padding_mask = cast(torch.BoolTensor, torch.ones(batch_size, 1, device=q.device))
elif seqlen_q != seqlen_k:
key_padding_mask = cast(torch.BoolTensor, key_padding_mask[:, -seqlen_q:])
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)
if self.checkpointing:
attn_output = torch.utils.checkpoint.checkpoint(
self.inner_cross_attn,
q,
kv,
causal=causal,
cu_seqlens=cu_seqlens_q,
max_seqlen=max_seqlen_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_k=max_seqlen_k,
)
else:
attn_output = self.inner_cross_attn(
q,
kv,
causal=causal,
cu_seqlens=cu_seqlens_q,
max_seqlen=max_seqlen_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_k=max_seqlen_k,
)
            if key_padding_mask is not None:
return pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
else:
return attn_output
if self.checkpointing:
return torch.utils.checkpoint.checkpoint(
self.inner_cross_attn,
q,
kv,
key_padding_mask=key_padding_mask,
causal=causal,
)
else:
return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)
def forward(
self,
x: torch.FloatTensor, # dim: (batch_size, seqlen, d_embedding)
kv_cache: KVCache | None = None,
key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
if key_padding_mask is not None:
key_padding_mask = cast(torch.BoolTensor, key_padding_mask.bool()) # make sure it's bool and not int
qkv = self.Wqkv(x) # dim: (batch_size, seqlen, 3*n_heads*d_head)
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.d_head) # dim: (batch_size, seqlen, 3, n_heads, d_head)
if kv_cache is None:
attn_output = self._forward_self_attn(qkv, key_padding_mask)
else:
attn_output = self._forward_cross_attn(qkv, kv_cache, key_padding_mask)
output = rearrange(attn_output, "... h d -> ... (h d)")
output = self.fc_out(output)
return output
class ParallelAttentionBlock(nn.Module):
"""Calculates attention and MLP in parallel."""
def __init__(
self,
        resid_pdrop: float,  # applied to the summed attention + MLP output, so "residual dropout" is a slight misnomer
layer_norm_epsilon: float,
d_embedding: int,
n_attn_heads: int,
block_n: int,
initial_cos_sin_cache_len: int, # length of cache for rotary embedding
attn_pdrop: float,
use_flash_rotary: bool = True, # use flash rotary embedding if possible
use_flash_attn: bool = True, # use flash attention if possible
use_fused_dense: bool = True, # use fused dense layer if possible
checkpointing: bool = False, # torch.utils.checkpoint
) -> None:
super().__init__()
self.layer_norm = nn.LayerNorm(d_embedding, eps=layer_norm_epsilon)
self.block_n = block_n
self.multi_head_attention = MHA(
d_embedding=d_embedding,
n_attn_heads=n_attn_heads,
block_n=block_n,
initial_cos_sin_cache_len=initial_cos_sin_cache_len,
attn_pdrop=attn_pdrop,
use_flash_rotary=use_flash_rotary,
use_flash_attn=use_flash_attn,
use_fused_dense=use_fused_dense,
checkpointing=checkpointing,
)
self.mlp = MLP(d_embedding)
self.dropout = nn.Dropout(resid_pdrop)
def forward(
self,
x: torch.FloatTensor, # dim: (batch_size, seq_len, d_embedding)
kv_cache: KVCache | None = None,
key_padding_mask: torch.BoolTensor | None = None,
) -> torch.FloatTensor:
residual = x
x = self.layer_norm(x) # each token (dim: d_embedding) is normalized individually
attn_outputs = self.multi_head_attention(
x,
kv_cache=kv_cache,
key_padding_mask=key_padding_mask,
)
mlp_outputs = self.mlp(x)
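        # parallel residual: attention and MLP both read the same normalized input, and their
        # outputs are summed before being added back to the residual stream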
x = self.dropout(attn_outputs + mlp_outputs) + residual
return x
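

# ---------------------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). The hyperparameters below
# are assumptions loosely based on Phi-2's published config (2560-dim embeddings, 32 heads,
# 2048-token context); swap in your own values as needed. Flash/fused paths are disabled so
# the example runs with plain PyTorch on any machine.
if __name__ == "__main__":
    torch.manual_seed(0)
    block = ParallelAttentionBlock(
        resid_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        d_embedding=2560,
        n_attn_heads=32,
        block_n=0,
        initial_cos_sin_cache_len=2048,
        attn_pdrop=0.1,
        use_flash_rotary=False,
        use_flash_attn=False,
        use_fused_dense=False,
    )
    block.eval()  # disable dropout so the forward pass is deterministic
    x = torch.randn(2, 16, 2560)  # (batch_size, seqlen, d_embedding)
    with torch.no_grad():
        y = block(x)
    print(y.shape)  # expected: torch.Size([2, 16, 2560])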