from dataclasses import dataclass, field
from einops import rearrange, repeat
import math
import torch
import torch.utils.checkpoint
from torch.amp.autocast_mode import autocast
import torch.nn as nn
from transformers.activations import ACT2FN
from typing import cast

# if flash_attn exists
try:
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
    from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    print("flash_attn not found, using default implementations")
    pad_input = unpad_input = FlashRotaryEmbedding = FlashCrossAttention = FlashSelfAttention = FusedDense = None


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) from Phi2.

    See https://www.youtube.com/watch?v=C6rV8BsrrCc
    """

    def __init__(
        self,
        d_rotary: int,
        rotary_base: float = 10000.0,
        initial_cos_sin_cache_len: int = 2048,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self.d_rotary = d_rotary
        self.rotary_base = rotary_base
        self.device = device
        self.dtype = torch.float32
        self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len)

    def _update_cos_sin_cache(
        self,
        seqlen: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        # only call this function when seqlen is larger than _max_seqlen
        self._max_seqlen = seqlen

        # m * theta_i = m * base^(-2i/d) = m * (1 / base^(2i/d)), where i in [1, d/2]
        m = torch.arange(
            seqlen,
            device=device,
            dtype=torch.float32,
        )
        theta_i = 1.0 / (
            self.rotary_base ** (
                torch.arange(
                    start=0,
                    end=self.d_rotary,
                    step=2,
                    device=device,
                    dtype=torch.float32,
                )
                / self.d_rotary
            )
        )

        # torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
        # TODO: does this matter if I'm disabling torch.autocast?
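        # outer product builds the full angle table: m_theta_i[p, i] = p * theta_i[i],
        # with shape (seqlen, d_rotary/2); its cos/sin are cached and sliced per forward pass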
        m_theta_i = torch.outer(m, theta_i)
        self._cos_cached = torch.cos(m_theta_i).to(dtype)
        self._sin_cached = torch.sin(m_theta_i).to(dtype)

        # TODO: scale_base caching is labelled as not yet done in Phi2
        """
        if scale_base is not None:
            scale = (
                torch.arange(
                    start=0,
                    end=self.d_rotary,
                    step=2,
                    device=self.device,
                    dtype=torch.float32,
                )
                + 0.4 * self.d_rotary
            ) / (1.4 * self.d_rotary)
            power = (
                torch.arange(seqlen, dtype=scale.dtype, device=scale.device)
                - seqlen // 2
            ) / scale_base
            scale = scale.to(device=power.device) ** rearrange(power, "s -> s 1")
            self._cos_cached = (torch.cos(m_theta_i) * scale).to(dtype)
            self._sin_cached = (torch.sin(m_theta_i) * scale).to(dtype)
        """

    def _apply_rotary_emb_qkv(
        self,
        x: torch.FloatTensor,  # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head)
        cos: torch.FloatTensor,  # dim: (_max_seqlen, d_rotary/2)
        sin: torch.FloatTensor,  # dim: (_max_seqlen, d_rotary/2)
    ) -> torch.FloatTensor:
        seqlen = x.shape[1]
        x_to_rotate = x[..., :self.d_rotary]
        x_to_keep_unrotated = x[..., self.d_rotary:]
        x1, x2 = x_to_rotate.chunk(2, dim=-1)  # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_rotary/2)
        broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d"
        c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange)
        x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]  # make sure rotary embedding is in float32
        x_rotated = cast(
            torch.FloatTensor,
            torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype)
        )
        return torch.cat([x_rotated, x_to_keep_unrotated], dim=-1)

    def forward(
        self,
        x: torch.FloatTensor,  # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head)
        seqlen_offset: int = 0,  # each sequence is shifted by this amount - used in inference with KV cache
    ) -> torch.FloatTensor:
        if (
            not self._max_seqlen
            or self._max_seqlen < x.shape[1] + seqlen_offset
            or self._cos_cached.device != x.device
            or self._cos_cached.dtype != x.dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset, device=x.device, dtype=x.dtype)
        return self._apply_rotary_emb_qkv(
            x,
            cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]),
            cast(torch.FloatTensor, self._sin_cached[seqlen_offset:]),
        )


class SelfAttention(nn.Module):
    """Self-attention layer, taken from Phi2 model."""

    def __init__(
        self,
        qk_scale: float | None = None,  # will use 1/sqrt(d) if set to None
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.qk_scale = qk_scale
        self.dropout = nn.Dropout(attention_dropout)

    # autocast is manually disabled to avoid `torch.einsum` using float16, which might lead to overflow
    @autocast("cpu", enabled=False)
    @autocast("cuda", enabled=False)
    def forward(
        self,
        qkv: torch.FloatTensor,  # dim: (batch_size, seqlen, 3, n_heads, d_head)
        causal: bool = True,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        q, k, v = qkv.unbind(dim=2)
        q = q.to(torch.float32)
        k = k.to(torch.float32)

        qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale)

        if key_padding_mask is not None:
            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

        if causal:
            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
            scores = scores + causal_mask.to(dtype=scores.dtype)

        attention = torch.softmax(scores, dim=-1).to(v.dtype)
        attention = self.dropout(attention)

        output = torch.einsum("bhts,bshd->bthd", attention, v)  # dim: (batch_size, seqlen, n_heads, d_head)
        return cast(torch.FloatTensor, output)


class CrossAttention(nn.Module):
    """Cross-attention layer, taken from Phi2 model."""

    def __init__(
        self,
        qk_scale: float | None = None,  # will use 1/sqrt(d) if set to None
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.qk_scale = qk_scale
        self.dropout = nn.Dropout(attention_dropout)

    # autocast is manually disabled to avoid `torch.einsum` using float16, which might lead to overflow
    @autocast("cpu", enabled=False)
    @autocast("cuda", enabled=False)
    def forward(
        self,
        q: torch.FloatTensor,  # dim: (batch_size, seqlen_q, n_heads, d_head)
        kv: torch.FloatTensor,  # dim: (batch_size, seqlen_kv, 2, n_heads, d_head)
        causal: bool = True,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = kv.shape[1]

        if kv.shape[3] != q.shape[2]:  # repeat kv n_heads dim to match q n_heads
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
        k, v = kv.unbind(dim=2)

        q = cast(torch.FloatTensor, q.to(torch.float32))
        k = k.to(torch.float32)

        qk_scale = self.qk_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * qk_scale)

        if key_padding_mask is not None:
            padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device)
            padding_mask.masked_fill_(key_padding_mask, 0.0)
            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")

        if causal:
            rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
            cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long)
            causal_mask = cols > rows + seqlen_k - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)

        attention = torch.softmax(scores, dim=-1).to(v.dtype)
        attention = self.dropout(attention)

        output = torch.einsum("bhts,bshd->bthd", attention, v)  # dim: (batch_size, seqlen_q, n_heads, d_head)
        return cast(torch.FloatTensor, output)


class MLP(nn.Module):
    """Taken from Phi2 as well."""

    def __init__(
        self,
        d_embedding: int,
        act_fn: str = "gelu_new",
    ) -> None:
        super().__init__()
        n_inner = 4 * d_embedding
        self.fc1 = nn.Linear(d_embedding, n_inner)
        self.act = ACT2FN[act_fn]
        self.fc2 = nn.Linear(n_inner, d_embedding)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


@dataclass
class KVCache:
    """Options for model to calculate and store context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int
    batch_size_offset: int
    kv_block_map: dict[int, torch.Tensor] = field(default_factory=dict)
    lengths_per_sample: torch.Tensor | None = None


class MHA(nn.Module):
    """Multi-head attention block."""

    def __init__(
        self,
        d_embedding: int,
        n_attn_heads: int,
        block_n: int,
        initial_cos_sin_cache_len: int,  # length of cache for rotary embedding
        attn_pdrop: float,
        use_flash_rotary: bool,  # use flash rotary embedding if possible
        use_flash_attn: bool,  # use flash attention if possible
        use_fused_dense: bool,  # use fused dense layer if possible
        checkpointing: bool,  # torch.utils.checkpoint
    ) -> None:
        super().__init__()

        # rotary embedding
        rotary_cls = (
            FlashRotaryEmbedding
            if use_flash_rotary and FlashRotaryEmbedding is not None
            else RotaryEmbedding
        )
        self.rotary_emb = rotary_cls(
            # d_rotary=math.ceil((d_embedding // n_attn_heads) / 2),  # d_rotary is half of d_head
            d_rotary=32,  # TODO: figure out why Phi2 uses this
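            # added note: Phi2 rotates only part of each head (partial rotary); with its
            # head dim of 80 and partial rotary factor of 0.4, the rotary dim works out to 32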
            initial_cos_sin_cache_len=initial_cos_sin_cache_len,
        )

        # self attention
        self_attn_cls = (
            FlashSelfAttention
            if use_flash_attn and FlashSelfAttention is not None
            else SelfAttention
        )
        self.inner_self_attn = self_attn_cls(attention_dropout=attn_pdrop)

        # cross attention
        cross_attn_cls = (
            FlashCrossAttention
            if use_flash_attn and FlashCrossAttention is not None
            else CrossAttention
        )
        self.inner_cross_attn = cross_attn_cls(attention_dropout=attn_pdrop)

        # qkv and output projections
        self.n_attn_heads = n_attn_heads
        self.d_head = d_embedding // n_attn_heads
        linear_cls = (
            FusedDense
            if use_fused_dense and FusedDense is not None
            else nn.Linear
        )
        self.Wqkv = linear_cls(
            d_embedding,
            self.d_head * (3 * self.n_attn_heads),  # calculating q, k, v for all heads in block simultaneously
        )
        self.fc_out = linear_cls(d_embedding, d_embedding)

        # settings
        self.using_flash_attn = self_attn_cls is FlashSelfAttention
        self.block_n = block_n
        self.checkpointing = checkpointing

    def _forward_self_attn(
        self,
        qkv: torch.FloatTensor,  # dim: (batch_size, seqlen, 3, n_heads, d_head)
        key_padding_mask: torch.BoolTensor | None,
    ) -> torch.FloatTensor:
        qkv = cast(
            torch.FloatTensor,
            torch.cat(
                [
                    self.rotary_emb(qkv[:, :, :2, :, :]),  # qk
                    qkv[:, :, 2:, :, :],  # v (the slice keeps the qkv dim so shapes match for cat)
                ],
                dim=2,
            )
        )

        if self.using_flash_attn and unpad_input and pad_input:  # not touching flash attention code
            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
            cu_seqlens, max_seqlen, indices = None, None, None

            # unpad input and retrieve `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
            if key_padding_mask is not None:
                qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask)

            if self.checkpointing:
                attn_output = torch.utils.checkpoint.checkpoint(
                    self.inner_self_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
                )
            else:
                attn_output = self.inner_self_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device)

            # repad output
            if key_padding_mask is not None:
                return pad_input(attn_output, indices, batch_size, seqlen)
            else:
                return attn_output

        if self.checkpointing:
            return torch.utils.checkpoint.checkpoint(self.inner_self_attn, qkv, key_padding_mask=key_padding_mask)
        else:
            return self.inner_self_attn(qkv, key_padding_mask=key_padding_mask)

    def _update_kv_cache(
        self,
        kv: torch.FloatTensor,  # dim: (batch_size, seqlen, 2, n_heads, d_head)
        kv_cache: KVCache,
        block_n: int,
    ) -> torch.FloatTensor:
        if block_n not in kv_cache.kv_block_map:
            kv_cache.kv_block_map[block_n] = torch.empty(
                kv_cache.max_batch_size,
                kv_cache.max_seqlen,
                2,
                kv.shape[-2],  # n_heads
                kv.shape[-1],  # d_head
                dtype=kv.dtype,
                device=kv.device,
            )
        kv_cache.kv_block_map[block_n][
            kv_cache.batch_size_offset: kv_cache.batch_size_offset + kv.shape[0],
            kv_cache.seqlen_offset: kv_cache.seqlen_offset + kv.shape[1],
            ...
        ] = kv
        # return the cached keys/values up to the current position so attention also covers past context
        return cast(
            torch.FloatTensor,
            kv_cache.kv_block_map[block_n][
                kv_cache.batch_size_offset: kv_cache.batch_size_offset + kv.shape[0],
                : kv_cache.seqlen_offset + kv.shape[1],
                ...
            ],
        )

    def _forward_cross_attn(
        self,
        qkv: torch.FloatTensor,  # dim: (batch_size, seqlen, 3, n_heads, d_head)
        kv_cache: KVCache,
        key_padding_mask: torch.BoolTensor | None,
    ) -> torch.FloatTensor:
        qk = qkv[:, :, :2, :, :]
        qk = self.rotary_emb(
            qk,
            seqlen_offset=0 if kv_cache is None else kv_cache.seqlen_offset,
        )
        v = cast(torch.FloatTensor, qkv[:, :, 2, :, :])

        q = qk[:, :, 0, :, :]
        kv = torch.cat(
            [
                qk[:, :, 1, :, :].unsqueeze(2),
                v.unsqueeze(2),
            ],
            dim=2,
        )
        kv = self._update_kv_cache(kv, kv_cache, self.block_n)  # kv now includes the cached context

        causal = False  # turning off causal mask for cross attention

        if self.using_flash_attn and unpad_input and pad_input:  # not touching flash attention code
            batch_size, seqlen_q = q.shape[0], q.shape[1]
            seqlen_k = kv.shape[1]
            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, indices_q = (
                None,
                None,
                None,
                None,
                None,
            )

            # unpad input and retrieve `cu_seqlens` and `max_seqlen` to be used by `flash-attn`
            if key_padding_mask is not None:
                kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask)

                if seqlen_q == 1:
                    key_padding_mask = cast(torch.BoolTensor, torch.ones(batch_size, 1, device=q.device))
                elif seqlen_q != seqlen_k:
                    key_padding_mask = cast(torch.BoolTensor, key_padding_mask[:, -seqlen_q:])

                q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask)

            if self.checkpointing:
                attn_output = torch.utils.checkpoint.checkpoint(
                    self.inner_cross_attn,
                    q,
                    kv,
                    causal=causal,
                    cu_seqlens=cu_seqlens_q,
                    max_seqlen=max_seqlen_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_k=max_seqlen_k,
                )
            else:
                attn_output = self.inner_cross_attn(
                    q,
                    kv,
                    causal=causal,
                    cu_seqlens=cu_seqlens_q,
                    max_seqlen=max_seqlen_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_k=max_seqlen_k,
                )

            # repad output
            if key_padding_mask is not None:
                return pad_input(attn_output, indices_q, batch_size, max_seqlen_q)
            else:
                return attn_output

        if self.checkpointing:
            return torch.utils.checkpoint.checkpoint(
                self.inner_cross_attn,
                q,
                kv,
                key_padding_mask=key_padding_mask,
                causal=causal,
            )
        else:
            return self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal)

    def forward(
        self,
        x: torch.FloatTensor,  # dim: (batch_size, seqlen, d_embedding)
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.LongTensor | torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        if key_padding_mask is not None:
            key_padding_mask = cast(torch.BoolTensor, key_padding_mask.bool())  # make sure it's bool and not int

        qkv = self.Wqkv(x)  # dim: (batch_size, seqlen, 3*n_heads*d_head)
        qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.d_head)  # dim: (batch_size, seqlen, 3, n_heads, d_head)

        if kv_cache is None:
            attn_output = self._forward_self_attn(qkv, key_padding_mask)
        else:
            attn_output = self._forward_cross_attn(qkv, kv_cache, key_padding_mask)

        output = rearrange(attn_output, "... h d -> ... (h d)")
        output = self.fc_out(output)
        return output


class ParallelAttentionBlock(nn.Module):
    """From Phi2. Calculates attention and MLP in parallel.

    See 'Simplifying Transformer Blocks', Fig. 1 'Parallel'.
    """

    def __init__(
        self,
        resid_pdrop: float,  # a bit of a misnomer, right?
        layer_norm_epsilon: float,
        d_embedding: int,
        n_attn_heads: int,
        block_n: int,
        initial_cos_sin_cache_len: int,  # length of cache for rotary embedding
        attn_pdrop: float,
        use_flash_rotary: bool = True,  # use flash rotary embedding if possible
        use_flash_attn: bool = True,  # use flash attention if possible
        use_fused_dense: bool = True,  # use fused dense layer if possible
        checkpointing: bool = False,  # torch.utils.checkpoint
    ) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_embedding, eps=layer_norm_epsilon)
        self.block_n = block_n
        self.multi_head_attention = MHA(
            d_embedding=d_embedding,
            n_attn_heads=n_attn_heads,
            block_n=block_n,
            initial_cos_sin_cache_len=initial_cos_sin_cache_len,
            attn_pdrop=attn_pdrop,
            use_flash_rotary=use_flash_rotary,
            use_flash_attn=use_flash_attn,
            use_fused_dense=use_fused_dense,
            checkpointing=checkpointing,
        )
        self.mlp = MLP(d_embedding)
        self.dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.FloatTensor,  # dim: (batch_size, seq_len, d_embedding)
        kv_cache: KVCache | None = None,
        key_padding_mask: torch.BoolTensor | None = None,
    ) -> torch.FloatTensor:
        residual = x
        x = self.layer_norm(x)  # each token (dim: d_embedding) is normalized individually
        attn_outputs = self.multi_head_attention(
            x,
            kv_cache=kv_cache,
            key_padding_mask=key_padding_mask,
        )
        mlp_outputs = self.mlp(x)
        return self.dropout(attn_outputs + mlp_outputs) + residual
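

# A minimal smoke-test sketch (not part of the original module): the sizes below are
# illustrative assumptions, chosen only to exercise ParallelAttentionBlock's forward pass
# with and without a KVCache. flash_attn and fused kernels are disabled so the fallback
# SelfAttention/CrossAttention paths run on CPU as well.
if __name__ == "__main__":
    torch.manual_seed(0)
    block = ParallelAttentionBlock(
        resid_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        d_embedding=256,  # d_head = 256 / 8 = 32, matching the hard-coded d_rotary
        n_attn_heads=8,
        block_n=0,
        initial_cos_sin_cache_len=128,
        attn_pdrop=0.1,
        use_flash_rotary=False,
        use_flash_attn=False,
        use_fused_dense=False,
    )

    # training-style forward: causal self attention over the full sequence
    x = torch.randn(2, 16, 256)
    y = block(x)
    print(y.shape)  # torch.Size([2, 16, 256])

    # inference-style forward: cross attention against a KV cache
    block.eval()
    with torch.no_grad():
        kv_cache = KVCache(max_seqlen=128, max_batch_size=2, seqlen_offset=0, batch_size_offset=0)
        prompt = torch.randn(2, 16, 256)
        out = block(prompt, kv_cache=kv_cache)  # prefill: caches k/v for the prompt
        kv_cache.seqlen_offset += prompt.shape[1]
        next_token = torch.randn(2, 1, 256)
        out = block(next_token, kv_cache=kv_cache)  # decode step attends to the cached context
        print(out.shape)  # torch.Size([2, 1, 256])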