from dataclasses import dataclass, field
from einops import rearrange, repeat
import math
import torch
import torch.utils.checkpoint
from torch.amp.autocast_mode import autocast
import torch.nn as nn
from transformers.activations import ACT2FN
from typing import cast


try:
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
    from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    print("flash_attn not found, using default implementations")
    pad_input = unpad_input = FlashRotaryEmbedding = FlashCrossAttention = FlashSelfAttention = FusedDense = None


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE). See https://www.youtube.com/watch?v=C6rV8BsrrCc"""

    def __init__(
        self,
        d_rotary: int,
        rotary_base: float = 10000.0,
        initial_cos_sin_cache_len: int = 2048,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self.d_rotary = d_rotary
        self.rotary_base = rotary_base
        self.device = device
        self.dtype = torch.float32
        self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len)

    def _update_cos_sin_cache(
        self,
        seqlen: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        self._max_seqlen = seqlen

        # theta_i = base^(-2i / d_rotary) for i in [0, d_rotary / 2); m_theta_i[m, i] = m * theta_i.
        m = torch.arange(
            seqlen,
            device=device,
            dtype=torch.float32,
        )
        theta_i = 1.0 / (
            self.rotary_base ** (
                torch.arange(
                    start=0,
                    end=self.d_rotary,
                    step=2,
                    device=device,
                    dtype=torch.float32,
                ) / self.d_rotary
            )
        )
        m_theta_i = torch.outer(m, theta_i)
        self._cos_cached = torch.cos(m_theta_i).to(dtype)
        self._sin_cached = torch.sin(m_theta_i).to(dtype)

        # Scaled-rotary variant kept for reference but unused; note that `scale_base`
        # is not defined in this module.
        """
        if scale_base is not None:
            scale = (
                torch.arange(
                    start=0,
                    end=self.d_rotary,
                    step=2,
                    device=self.device,
                    dtype=torch.float32,
                ) + 0.4 * self.d_rotary
            ) / (1.4 * self.d_rotary)
            power = (
                torch.arange(seqlen, dtype=scale.dtype, device=scale.device) - seqlen // 2
            ) / scale_base
            scale = scale.to(device=power.device) ** rearrange(power, "s -> s 1")
            self._cos_cached = (torch.cos(m_theta_i) * scale).to(dtype)
            self._sin_cached = (torch.sin(m_theta_i) * scale).to(dtype)
        """

    def _apply_rotary_emb_qkv(
        self,
        x: torch.FloatTensor,
        cos: torch.FloatTensor,
        sin: torch.FloatTensor,
    ) -> torch.FloatTensor:
        seqlen = x.shape[1]
        x_to_rotate = x[..., :self.d_rotary]
        x_to_keep_unrotated = x[..., self.d_rotary:]
        x1, x2 = x_to_rotate.chunk(2, dim=-1)
        broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d"
        c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange)
        x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
        x_rotated = cast(
            torch.FloatTensor,
            torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype)
        )
        return cast(torch.FloatTensor, torch.cat([x_rotated, x_to_keep_unrotated], dim=-1))

    def forward(
        self,
        x: torch.FloatTensor,
        seqlen_offset: int = 0,
    ) -> torch.FloatTensor:
        if (
            not self._max_seqlen
            or self._max_seqlen < x.shape[1] + seqlen_offset
            or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset, device=x.device, dtype=x.dtype)
        return self._apply_rotary_emb_qkv(
            x,
            cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]),
            cast(torch.FloatTensor, self._sin_cached[seqlen_offset:]),
        )
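
# Usage sketch (illustrative only, not part of the model): the embedding is applied to the
# query/key slice of a packed qkv tensor shaped (batch, seqlen, n, n_heads, d_head), rotating
# only the first `d_rotary` channels of each head; the shapes below are assumptions.
#
#   rope = RotaryEmbedding(d_rotary=32)
#   qk = torch.randn(2, 16, 2, 8, 64)   # (batch, seqlen, q/k, n_heads, d_head)
#   qk_rotated = rope(qk)               # same shape; channels 32..63 pass through unrotated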


class SelfAttention(nn.Module):
    def __init__(
        self,
        qk_scale: float | None = None,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.qk_scale = qk_scale
        self.dropout = nn.Dropout(attention_dropout)

    @autocast("cpu", enabled=False)
    @autocast("cuda", enabled=False)
    def forward(
        self,
        qkv: torch.FloatTensor,
        causal: bool = True,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        q, k, v = qkv.unbind(dim=2)
        q = q.to(torch.float32)
        k = k.to(torch.float32)
        qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1])

        scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale)

        if key_padding_mask is not None:
            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

        if causal:
            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
            scores = scores + causal_mask.to(dtype=scores.dtype)

        attention = torch.softmax(scores, dim=-1).to(v.dtype)
        attention = self.dropout(attention)

        output = torch.einsum("bhts,bshd->bthd", attention, v)
        return cast(torch.FloatTensor, output)
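
# Note (illustrative only): modulo the float32 upcast, dropout placement, and the additive
# -10000.0 masking, this computes the standard softmax(q @ k^T * scale) @ v. A rough
# equivalent using PyTorch's fused kernel, with heads moved to dim 1, would be:
#
#   q, k, v = qkv.unbind(dim=2)                                   # each (b, s, h, d)
#   out = torch.nn.functional.scaled_dot_product_attention(
#       q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True,
#   ).transpose(1, 2)                                             # back to (b, s, h, d)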


class CrossAttention(nn.Module):
    def __init__(
        self,
        qk_scale: float | None = None,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.qk_scale = qk_scale
        self.dropout = nn.Dropout(attention_dropout)

    @autocast("cpu", enabled=False)
    @autocast("cuda", enabled=False)
    def forward(
        self,
        q: torch.FloatTensor,
        kv: torch.FloatTensor,
        causal: bool = True,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = kv.shape[1]
        if kv.shape[3] != q.shape[2]:
            # Grouped-query attention: repeat each kv head to match the number of query heads.
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)
        q = cast(torch.FloatTensor, q.to(torch.float32))
        k = k.to(torch.float32)
        qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1])

        scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale)

        if key_padding_mask is not None:
            padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

        if causal:
            # Causal mask aligned to the bottom-right so earlier cached keys stay visible.
            rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
            cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
            causal_mask = cols > rows + seqlen_k - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)

        attention = torch.softmax(scores, dim=-1).to(v.dtype)
        attention = self.dropout(attention)

        output = torch.einsum("bhts,bshd->bthd", attention, v)
        return cast(torch.FloatTensor, output)


class MLP(nn.Module):
    def __init__(
        self,
        d_embedding: int,
        act_fn: str = "gelu_new",
    ) -> None:
        super().__init__()
        n_inner = 4 * d_embedding
        self.fc1 = nn.Linear(d_embedding, n_inner)
        self.act = ACT2FN[act_fn]
        self.fc2 = nn.Linear(n_inner, d_embedding)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


@dataclass
class KVCache:
    """Per-block key/value cache and bookkeeping offsets used during inference."""
    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int
    batch_size_offset: int
    kv_block_map: dict[int, torch.Tensor] = field(default_factory=dict)
    lengths_per_sample: torch.Tensor | None = None
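
# Usage sketch (illustrative only): a single KVCache is shared by all blocks during generation;
# `kv_block_map` is keyed by block index, and the caller is assumed to advance `seqlen_offset`
# between decoding steps. `model` and `prompt_ids` below are hypothetical names.
#
#   kv_cache = KVCache(max_seqlen=2048, max_batch_size=1, seqlen_offset=0, batch_size_offset=0)
#   logits = model(prompt_ids, kv_cache=kv_cache)      # prefill
#   kv_cache.seqlen_offset = prompt_ids.shape[1]       # then advance by 1 per generated token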


class MHA(nn.Module):
    """Multi-head attention block."""

    def __init__(
        self,
        d_embedding: int,
        n_attn_heads: int,
        block_n: int,
        initial_cos_sin_cache_len: int,
        attn_pdrop: float,
        use_flash_rotary: bool,
        use_flash_attn: bool,
        use_fused_dense: bool,
        checkpointing: bool,
    ) -> None:
        super().__init__()

        # Prefer the flash-attn implementations when available and requested;
        # otherwise fall back to the pure-PyTorch modules defined above.
        rotary_cls = (
            FlashRotaryEmbedding
            if use_flash_rotary and FlashRotaryEmbedding is not None
            else RotaryEmbedding
        )
        self.rotary_emb = rotary_cls(
            d_rotary=32,  # rotate only the first 32 channels of each head
            initial_cos_sin_cache_len=initial_cos_sin_cache_len,
        )

        self_attn_cls = (
            FlashSelfAttention
            if use_flash_attn and FlashSelfAttention is not None
            else SelfAttention
        )
        self.inner_self_attn = self_attn_cls(attention_dropout=attn_pdrop)

        cross_attn_cls = (
            FlashCrossAttention
            if use_flash_attn and FlashCrossAttention is not None
            else CrossAttention
        )
        self.inner_cross_attn = cross_attn_cls(attention_dropout=attn_pdrop)

        self.n_attn_heads = n_attn_heads
        self.d_head = d_embedding // n_attn_heads
        linear_cls = (
            FusedDense
            if use_fused_dense and FusedDense is not None
            else nn.Linear
        )
        self.Wqkv = linear_cls(
            d_embedding,
            self.d_head * (3 * self.n_attn_heads),
        )
        self.fc_out = linear_cls(d_embedding, d_embedding)

        self.using_flash_attn = self_attn_cls is FlashSelfAttention
        self.block_n = block_n
        self.checkpointing = checkpointing

    def _forward_self_attn(
        self,
        qkv: torch.FloatTensor,
        key_padding_mask: torch.BoolTensor | None,
    ) -> torch.FloatTensor:
        # Apply rotary embeddings to q and k only; v passes through unchanged. The slice
        # `2:` keeps the singleton axis so the concat restores shape (b, s, 3, h, d).
        qkv = cast(
            torch.FloatTensor,
            torch.cat(
                [
                    self.rotary_emb(qkv[:, :, :2, :, :]),
                    qkv[:, :, 2:, :, :],
                ],
                dim=2,
            )
        )

        if self.using_flash_attn and unpad_input and pad_input:
            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
            cu_seqlens, max_seqlen, indices = None, None, None

            if key_padding_mask is not None:
                qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)

            if self.checkpointing:
                attn_output = torch.utils.checkpoint.checkpoint(
                    self.inner_self_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
                )
            else:
                attn_output = self.inner_self_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)

            if key_padding_mask is not None:
                return pad_input(attn_output, indices, batch_size, seqlen)
            else:
                return attn_output

        if self.checkpointing:
            return torch.utils.checkpoint.checkpoint(self.inner_self_attn, qkv, key_padding_mask=key_padding_mask)
        else:
            return self.inner_self_attn(qkv, key_padding_mask=key_padding_mask)

    def _update_kv_cache(
        self,
        kv: torch.FloatTensor,
        kv_cache: KVCache,
        block_n: int,
    ) -> torch.FloatTensor:
        # Lazily allocate this block's cache buffer, write the new kv entries into it,
        # and return the cached kv up to and including the current position.
        if block_n not in kv_cache.kv_block_map:
            kv_cache.kv_block_map[block_n] = torch.empty(
                kv_cache.max_batch_size,
                kv_cache.max_seqlen,
                2,
                kv.shape[-2],
                kv.shape[-1],
                dtype=kv.dtype,
                device=kv.device,
            )
        batch_start = kv_cache.batch_size_offset
        batch_end = batch_start + kv.shape[0]
        seqlen_start = kv_cache.seqlen_offset
        seqlen_end = seqlen_start + kv.shape[1]
        kv_cache.kv_block_map[block_n][batch_start:batch_end, seqlen_start:seqlen_end, ...] = kv
        return cast(torch.FloatTensor, kv_cache.kv_block_map[block_n][batch_start:batch_end, :seqlen_end, ...])

    def _forward_cross_attn(
        self,
        qkv: torch.FloatTensor,
        kv_cache: KVCache,
        key_padding_mask: torch.BoolTensor | None,
    ) -> torch.FloatTensor:
        # Rotate q and k with the cache offset, then attend against the full cached kv.
        qk = qkv[:, :, :2, :, :]
        qk = self.rotary_emb(
            qk,
            seqlen_offset=0 if kv_cache is None else kv_cache.seqlen_offset,
        )
        v = cast(torch.FloatTensor, qkv[:, :, 2, :, :])
        q = qk[:, :, 0, :, :]
        kv = torch.cat(
            [
                qk[:, :, 1, :, :].unsqueeze(2),
                v.unsqueeze(2),
            ],
            dim=2,
        )
        kv = self._update_kv_cache(kv, kv_cache, self.block_n)
        # During prefill (offset 0) the new tokens still need a causal mask; during
        # incremental decoding each new token may attend to every cached position.
        causal = kv_cache.seqlen_offset == 0

        if self.using_flash_attn and unpad_input and pad_input:
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, indices_q = (
                None,
                None,
                None,
                None,
                None,
            )

            if key_padding_mask is not None:
                kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)

                if seqlen_q == 1:
                    key_padding_mask = cast(torch.BoolTensor, torch.ones(batch_size, 1, device=q.device))
                elif seqlen_q != seqlen_k:
                    key_padding_mask = cast(torch.BoolTensor, key_padding_mask[:, -seqlen_q:])

                q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)

            if self.checkpointing:
                attn_output = torch.utils.checkpoint.checkpoint(
                    self.inner_cross_attn,
                    q,
                    kv,
                    causal=causal,
                    cu_seqlens=cu_seqlens_q,
                    max_seqlen=max_seqlen_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_k=max_seqlen_k,
                )
            else:
                attn_output = self.inner_cross_attn(
                    q,
                    kv,
                    causal=causal,
                    cu_seqlens=cu_seqlens_q,
                    max_seqlen=max_seqlen_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_k=max_seqlen_k,
                )

            if key_padding_mask is not None:
                return pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
            else:
                return attn_output

        if self.checkpointing:
            return torch.utils.checkpoint.checkpoint(
                self.inner_cross_attn,
                q,
                kv,
                key_padding_mask=key_padding_mask,
                causal=causal,
            )
        else:
            return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)

    def forward(
        self,
        x: torch.FloatTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        if key_padding_mask is not None:
            key_padding_mask = cast(torch.BoolTensor, key_padding_mask.bool())

        # Project to packed qkv of shape (batch, seqlen, 3, n_heads, d_head), then route to the
        # self-attention path (no cache) or the kv-cached cross-attention path (inference).
        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.d_head)
        if kv_cache is None:
            attn_output = self._forward_self_attn(qkv, key_padding_mask)
        else:
            attn_output = self._forward_cross_attn(qkv, kv_cache, key_padding_mask)

        output = rearrange(attn_output, "... h d -> ... (h d)")
        output = self.fc_out(output)
        return output


class ParallelAttentionBlock(nn.Module):
    """Calculates attention and MLP in parallel."""

    def __init__(
        self,
        resid_pdrop: float,
        layer_norm_epsilon: float,
        d_embedding: int,
        n_attn_heads: int,
        block_n: int,
        initial_cos_sin_cache_len: int,
        attn_pdrop: float,
        use_flash_rotary: bool = True,
        use_flash_attn: bool = True,
        use_fused_dense: bool = True,
        checkpointing: bool = False,
    ) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_embedding, eps=layer_norm_epsilon)
        self.block_n = block_n
        self.multi_head_attention = MHA(
            d_embedding=d_embedding,
            n_attn_heads=n_attn_heads,
            block_n=block_n,
            initial_cos_sin_cache_len=initial_cos_sin_cache_len,
            attn_pdrop=attn_pdrop,
            use_flash_rotary=use_flash_rotary,
            use_flash_attn=use_flash_attn,
            use_fused_dense=use_fused_dense,
            checkpointing=checkpointing,
        )
        self.mlp = MLP(d_embedding)
        self.dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.FloatTensor,
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        residual = x
        x = self.layer_norm(x)
        attn_outputs = self.multi_head_attention(
            x,
            kv_cache=kv_cache,
            key_padding_mask=key_padding_mask,
        )
        mlp_outputs = self.mlp(x)
        return self.dropout(attn_outputs + mlp_outputs) + residual
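

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): exercise the pure-PyTorch fallback path so
    # flash-attn is not required. The hyperparameters below are arbitrary assumptions,
    # not the values used by any particular model.
    block = ParallelAttentionBlock(
        resid_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        d_embedding=256,
        n_attn_heads=4,
        block_n=0,
        initial_cos_sin_cache_len=128,
        attn_pdrop=0.1,
        use_flash_rotary=False,
        use_flash_attn=False,
        use_fused_dense=False,
    )
    x = torch.randn(2, 16, 256)  # (batch, seqlen, d_embedding)
    out = block(x)
    print(out.shape)  # expected: torch.Size([2, 16, 256])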