import torch import torch.nn as nn from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import Seq2SeqLMOutput from transformers.activations import ACT2FN from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from transformers import PreTrainedModel, PretrainedConfig # Copyright (c) 2023, Tri Dao. import math import os from functools import partial import torch import torch.nn as nn from einops import rearrange, repeat from flash_attn.utils.distributed import get_dim_for_local_rank try: from flash_attn import ( flash_attn_kvpacked_func, flash_attn_qkvpacked_func, flash_attn_varlen_kvpacked_func, flash_attn_varlen_qkvpacked_func, flash_attn_with_kvcache, ) except ImportError: flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None flash_attn_with_kvcache = None try: from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear except ImportError: FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None try: from flash_attn.layers.rotary import RotaryEmbedding except ImportError: RotaryEmbedding = None # From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742 def get_alibi_slopes(nheads): def get_slopes_power_of_2(nheads): start = 2 ** (-(2 ** -(math.log2(nheads) - 3))) ratio = start return [start * ratio**i for i in range(nheads)] if math.log2(nheads).is_integer(): return get_slopes_power_of_2(nheads) else: closest_power_of_2 = 2 ** math.floor(math.log2(nheads)) return ( get_slopes_power_of_2(closest_power_of_2) + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2] ) class FlashSelfAttention(nn.Module): """Implement the scaled dot product attention with softmax. Arguments --------- softmax_scale: The temperature to use for the softmax attention. (default: 1/sqrt(d_keys) where d_keys is computed at runtime) attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ def __init__( self, causal=False, softmax_scale=None, attention_dropout=0.0, window_size=(-1, -1), alibi_slopes=None, deterministic=False, ): super().__init__() assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed" assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed" self.causal = causal self.softmax_scale = softmax_scale self.drop = nn.Dropout(attention_dropout) self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) self.window_size = window_size self.deterministic = deterministic def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): """Implements the multihead softmax attention. Arguments --------- qkv: The tensor containing the query, key, and value. If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D). If cu_seqlens is not None and max_seqlen is not None, then qkv has shape (total, 3, H, D), where total is the sum of the sequence lengths in the batch. causal: if passed, will override self.causal cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into qkv. max_seqlen: int. Maximum sequence length in the batch. Returns: -------- out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None, else (B, S, H, D). """ assert qkv.dtype in [torch.float16, torch.bfloat16] assert qkv.is_cuda causal = self.causal if causal is None else causal unpadded = cu_seqlens is not None if self.alibi_slopes is not None: self.alibi_slopes = self.alibi_slopes.to(torch.float32) if unpadded: assert cu_seqlens.dtype == torch.int32 assert max_seqlen is not None assert isinstance(max_seqlen, int) return flash_attn_varlen_qkvpacked_func( qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal, alibi_slopes=self.alibi_slopes, window_size=self.window_size, deterministic=self.deterministic, ) else: return flash_attn_qkvpacked_func( qkv, self.drop.p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal, alibi_slopes=self.alibi_slopes, window_size=self.window_size, deterministic=self.deterministic, ) class FlashCrossAttention(nn.Module): """Implement the scaled dot product attention with softmax. Arguments --------- softmax_scale: The temperature to use for the softmax attention. (default: 1/sqrt(d_keys) where d_keys is computed at runtime) attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ def __init__( self, causal=False, softmax_scale=None, attention_dropout=0.0, alibi_slopes=None, window_size=(-1, -1), deterministic=False, ): super().__init__() assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed" assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed" self.causal = causal self.softmax_scale = softmax_scale self.drop = nn.Dropout(attention_dropout) self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) self.window_size = window_size self.deterministic = deterministic def forward( self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None, ): """Implements the multihead softmax attention. Arguments --------- q: The tensor containing the query. (B, Sq, H, D) kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) causal: if passed, will override self.causal cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into q. max_seqlen: int. Maximum sequence length in the batch of q. cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into kv. max_seqlen_k: int. Maximum sequence length in the batch of k and v. """ assert q.dtype in [torch.float16, torch.bfloat16] assert q.is_cuda and kv.is_cuda causal = self.causal if causal is None else causal unpadded = cu_seqlens is not None if self.alibi_slopes is not None: self.alibi_slopes = self.alibi_slopes.to(torch.float32) if unpadded: assert cu_seqlens.dtype == torch.int32 assert max_seqlen is not None assert isinstance(max_seqlen, int) assert cu_seqlens_k is not None assert cu_seqlens_k.dtype == torch.int32 assert max_seqlen_k is not None assert isinstance(max_seqlen_k, int) return flash_attn_varlen_kvpacked_func( q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k, self.drop.p if self.training else 0.0, softmax_scale=self.softmax_scale, causal=causal, alibi_slopes=self.alibi_slopes, window_size=self.window_size, deterministic=self.deterministic, ) else: batch_size, seqlen_q = q.shape[0], q.shape[1] seqlen_k = kv.shape[1] assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] return flash_attn_kvpacked_func( q, kv, self.drop.p if self.training else 0.0, causal=causal, softmax_scale=self.softmax_scale, alibi_slopes=self.alibi_slopes, window_size=self.window_size, deterministic=self.deterministic, ) class SelfAttention(nn.Module): """Implement the scaled dot product attention with softmax. Arguments --------- softmax_scale: The temperature to use for the softmax attention. (default: 1/sqrt(d_keys) where d_keys is computed at runtime) attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): super().__init__() self.causal = causal self.softmax_scale = softmax_scale self.drop = nn.Dropout(attention_dropout) def forward(self, qkv, causal=None, key_padding_mask=None): """Implements the multihead softmax attention. Arguments --------- qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) causal: if passed, will override self.causal key_padding_mask: boolean mask to apply to the attention weights. True means to keep, False means to mask out. (B, S) """ batch_size, seqlen = qkv.shape[0], qkv.shape[1] causal = self.causal if causal is None else causal q, k, v = qkv.unbind(dim=2) softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) if key_padding_mask is not None: padding_mask = torch.full( (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device ) padding_mask.masked_fill_(key_padding_mask, 0.0) # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") if causal: # "triu_tril_cuda_template" not implemented for 'BFloat16' # So we have to construct the mask in float causal_mask = torch.triu( torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1 ) # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) scores = scores + causal_mask.to(dtype=scores.dtype) attention = torch.softmax(scores, dim=-1, dtype=v.dtype) attention_drop = self.drop(attention) output = torch.einsum("bhts,bshd->bthd", attention_drop, v) return output class CrossAttention(nn.Module): """Implement the scaled dot product attention with softmax. Arguments --------- softmax_scale: The temperature to use for the softmax attention. (default: 1/sqrt(d_keys) where d_keys is computed at runtime) attention_dropout: The dropout rate to apply to the attention (default: 0.0) """ def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): super().__init__() self.causal = causal self.softmax_scale = softmax_scale self.drop = nn.Dropout(attention_dropout) def forward(self, q, kv, causal=None, key_padding_mask=None): """Implements the multihead softmax attention. Arguments --------- q: The tensor containing the query. (B, Sq, H, D) kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) causal: if passed, will override self.causal key_padding_mask: boolean mask to apply to the attention weights. True means to keep, False means to mask out. (B, Sk) """ batch_size, seqlen_q = q.shape[0], q.shape[1] causal = self.causal if causal is None else causal seqlen_k = kv.shape[1] assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] if kv.shape[3] != q.shape[2]: # MQA/GQA kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3]) k, v = kv.unbind(dim=2) softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) if key_padding_mask is not None: padding_mask = torch.full( (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device ) padding_mask.masked_fill_(key_padding_mask, 0.0) # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") if causal: # causal mask needs to take into account the difference between seqlen_q and seqlen_k row_idx = rearrange( torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1" ) col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long) sk = ( seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") ) causal_mask = col_idx > row_idx + sk - seqlen_q scores = scores.masked_fill(causal_mask, -10000.0) attention = torch.softmax(scores, dim=-1, dtype=v.dtype) attention_drop = self.drop(attention) output = torch.einsum("bhts,bshd->bthd", attention_drop, v) return output class LinearResidual(nn.Linear): """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.""" def forward(self, input: torch.Tensor) -> torch.Tensor: return super().forward(input), input def _update_kv_cache(kv, inference_params, layer_idx): """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" # Pre-allocate memory for key-values for inference. num_heads, head_dim = kv.shape[-2:] if layer_idx not in inference_params.key_value_memory_dict: kv_cache = torch.empty( inference_params.max_batch_size, inference_params.max_seqlen, 2, num_heads, head_dim, dtype=kv.dtype, device=kv.device, ) inference_params.key_value_memory_dict[layer_idx] = kv_cache else: kv_cache = inference_params.key_value_memory_dict[layer_idx] # Adjust key and value for inference batch_start = inference_params.batch_size_offset batch_end = batch_start + kv.shape[0] sequence_start = inference_params.seqlen_offset sequence_end = sequence_start + kv.shape[1] assert batch_end <= kv_cache.shape[0] assert sequence_end <= kv_cache.shape[1] assert kv_cache is not None kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv return kv_cache[batch_start:batch_end, :sequence_end, ...] class MHA(nn.Module): """Multi-head self-attention and cross-attention""" def __init__( self, embed_dim, num_heads, num_heads_kv=None, cross_attn=False, qkv_proj_bias=True, out_proj_bias=True, dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False, rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None, rotary_emb_interleaved=False, use_alibi=False, window_size=(-1, -1), fused_bias_fc=False, use_flash_attn=False, return_residual=False, checkpointing=False, device=None, dtype=None, ) -> None: """ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. return_residual: whether to return the input x along with the output. This is for performance reason: for post-norm architecture, returning the input allows us to fuse the backward of nn.Linear with the residual connection. """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.embed_dim = embed_dim self.cross_attn = cross_attn self.causal = causal self.layer_idx = layer_idx self.dwconv = dwconv self.rotary_emb_dim = rotary_emb_dim self.use_flash_attn = use_flash_attn self.return_residual = return_residual self.checkpointing = checkpointing if use_alibi: assert use_flash_attn, "ALiBi code path requires flash_attn" alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device) else: alibi_slopes = None if window_size != (-1, -1): assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" self.num_heads = num_heads self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads assert ( self.num_heads % self.num_heads_kv == 0 ), "num_heads must be divisible by num_heads_kv" assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" self.head_dim = self.embed_dim // num_heads qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) kv_dim = 2 * self.head_dim * self.num_heads_kv if self.rotary_emb_dim > 0: assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet" assert RotaryEmbedding is not None, "rotary_emb is not installed" self.rotary_emb = RotaryEmbedding( self.rotary_emb_dim, base=rotary_emb_base, scale_base=rotary_emb_scale_base, interleaved=rotary_emb_interleaved, device=device, ) if fused_bias_fc and FusedDense is None: raise ImportError("fused_dense is not installed") linear_cls = nn.Linear if not fused_bias_fc else FusedDense linear_resid_cls = ( LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True) ) wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls inner_attn_cls = ( partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) if use_flash_attn else SelfAttention ) inner_cross_attn_cls = ( partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size) if use_flash_attn else CrossAttention ) if not self.cross_attn: self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) else: self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) if self.dwconv: if self.num_heads_kv == self.num_heads: self.dwconv_qkv = nn.Conv1d( qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim ) else: self.dwconv_q = nn.Conv1d( embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim ) self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim) self.inner_attn = inner_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, ) self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): dtype = self.out_proj.weight.dtype if dtype is None else dtype device = self.out_proj.weight.device return torch.empty( batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, dtype=dtype, device=device, ) def _update_kv_cache(self, kv, inference_params): """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" assert not self.dwconv, "Generation does not support dwconv yet" assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" return _update_kv_cache(kv, inference_params, self.layer_idx) def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): """ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention. q: (batch_size, seqlen_q, nheads, head_dim) kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) """ assert inference_params is not None and inference_params.seqlen_offset > 0 assert self.use_flash_attn if self.rotary_emb_dim > 0: assert self.rotary_emb.scale is None, "This code path does not support xPos" self.rotary_emb._update_cos_sin_cache( inference_params.max_seqlen, device=q.device, dtype=q.dtype ) rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached else: rotary_cos, rotary_sin = None, None batch = q.shape[0] kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) context = flash_attn_with_kvcache( q, kv_cache[:, :, 0], kv_cache[:, :, 1], kv[:, :, 0], kv[:, :, 1], rotary_cos=rotary_cos, rotary_sin=rotary_sin, cache_seqlens=cache_seqlens, softmax_scale=self.inner_cross_attn.softmax_scale, causal=self.inner_cross_attn.causal, rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, alibi_slopes=alibi_slopes, ) return context def _update_kvcache_attention(self, q, kv, inference_params): """Write kv to inference_params, then do attention""" if ( inference_params.seqlen_offset == 0 or flash_attn_with_kvcache is None or not self.use_flash_attn ): # TODO: this only uses seqlen_offset and not lengths_per_sample. kv = self._update_kv_cache(kv, inference_params) return self.inner_cross_attn(q, kv) else: batch = q.shape[0] kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) return flash_attn_with_kvcache( q, kv_cache[:, :, 0], kv_cache[:, :, 1], kv[:, :, 0], kv[:, :, 1], cache_seqlens=cache_seqlens, softmax_scale=self.inner_cross_attn.softmax_scale, causal=self.inner_cross_attn.causal, alibi_slopes=alibi_slopes, ) def forward( self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None, mixer_subset=None, inference_params=None, **kwargs, ): """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total is the is the sum of the sequence lengths in the batch. x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into x. Only applicable when using FlashAttention. max_seqlen: int. Maximum sequence length in the batch. key_padding_mask: boolean mask, True means to keep, False means to mask out. (batch, seqlen). Only applicable when not using FlashAttention. mixer_subset: for cross-attention only. If not None, will take a subset of x before applying the query projection. Useful for e.g., ViT where we only care about the CLS token in the last layer. inference_params: for generation. Adapted from Megatron-LM (and Apex) https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 """ if cu_seqlens is not None: assert max_seqlen is not None assert key_padding_mask is None assert self.use_flash_attn assert not self.dwconv assert self.rotary_emb_dim == 0 if key_padding_mask is not None: assert cu_seqlens is None assert max_seqlen is None assert not self.use_flash_attn if inference_params is not None: assert key_padding_mask is None assert cu_seqlens is None and max_seqlen is None assert not self.dwconv kwargs = ( {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs} if self.use_flash_attn else {"key_padding_mask": key_padding_mask, **kwargs} ) seqlen_offset = ( 0 if inference_params is None else ( inference_params.lengths_per_sample if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) ) rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None batch, seqlen = x.shape[:2] if not self.cross_attn and self.num_heads_kv == self.num_heads: assert x_kv is None and mixer_subset is None if not self.return_residual: qkv = self.Wqkv(x) else: qkv, x = self.Wqkv(x) if self.dwconv: qkv = rearrange( self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" ).contiguous() qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim) if ( inference_params is None or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): if self.rotary_emb_dim > 0: qkv = self.rotary_emb( qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen ) if inference_params is None: if not self.checkpointing: context = self.inner_attn(qkv, **kwargs) else: context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) else: context = self._update_kvcache_attention( qkv[:, :, 0], qkv[:, :, 1:], inference_params ) else: context = self._apply_rotary_update_kvcache_attention( qkv[:, :, 0], qkv[:, :, 1:], inference_params ) else: if self.cross_attn: if not self.return_residual: q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) kv = self.Wkv(x_kv if x_kv is not None else x) else: if x_kv is not None: kv, x_kv = self.Wkv(x_kv) else: kv, x = self.Wkv(x) q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) else: assert self.num_heads_kv != self.num_heads if not self.return_residual: qkv = self.Wqkv(x) else: qkv, x = self.Wqkv(x) q = qkv[..., : self.num_heads * self.head_dim] kv = qkv[..., self.num_heads * self.head_dim :] q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) if self.dwconv: q = rearrange( self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d" ).contiguous() kv = rearrange( self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" ).contiguous() if ( inference_params is None or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): if self.rotary_emb_dim > 0: q, kv = self.rotary_emb( q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen ) if inference_params is None: if not self.checkpointing: context = self.inner_cross_attn(q, kv, **kwargs) else: context = torch.utils.checkpoint.checkpoint( self.inner_cross_attn, q, kv, **kwargs ) else: context = self._update_kvcache_attention(q, kv, inference_params) else: context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) out = self.out_proj(rearrange(context, "... h d -> ... (h d)")) return out if not self.return_residual else (out, x) class ParallelMHA(nn.Module): """Multi-head self-attention and cross-attention""" def __init__( self, embed_dim, num_heads, process_group, num_heads_kv=None, qkv_proj_bias=True, out_proj_bias=True, dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None, rotary_emb_interleaved=False, use_alibi=False, window_size=(-1, -1), use_flash_attn=False, checkpointing=False, sequence_parallel=True, device=None, dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.embed_dim = embed_dim self.causal = causal self.layer_idx = layer_idx self.rotary_emb_dim = rotary_emb_dim self.use_flash_attn = use_flash_attn self.checkpointing = checkpointing self.process_group = process_group self.world_size = process_group.size() self.local_rank = torch.distributed.get_rank(process_group) self.num_heads = num_heads assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads assert ( self.num_heads % self.num_heads_kv == 0 ), "num_heads must be divisible by num_heads_kv" self.num_heads_per_rank = get_dim_for_local_rank( self.num_heads, self.world_size, self.local_rank ) self.num_heads_kv_per_rank = get_dim_for_local_rank( self.num_heads_kv, self.world_size, self.local_rank ) self.head_dim = self.embed_dim // num_heads qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) if use_alibi: assert use_flash_attn, "ALiBi code path requires flash_attn" num_heads_local = math.ceil(self.num_heads / self.world_size) alibi_slopes = torch.tensor( get_alibi_slopes(num_heads)[ self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local ], device=device, ) else: alibi_slopes = None if window_size != (-1, -1): assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" if self.rotary_emb_dim > 0: assert RotaryEmbedding is not None, "rotary_emb is not installed" self.rotary_emb = RotaryEmbedding( self.rotary_emb_dim, base=rotary_emb_base, scale_base=rotary_emb_scale_base, interleaved=rotary_emb_interleaved, device=device, ) if ColumnParallelLinear is None or RowParallelLinear is None: raise ImportError("fused_dense is not installed") self.Wqkv = ColumnParallelLinear( embed_dim, qkv_dim, process_group, bias=qkv_proj_bias, sequence_parallel=sequence_parallel, multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2), **factory_kwargs, ) inner_attn_cls = ( partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) if use_flash_attn else SelfAttention ) inner_cross_attn_cls = ( partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size) if use_flash_attn else CrossAttention ) self.inner_attn = inner_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout ) self.out_proj = RowParallelLinear( embed_dim, embed_dim, process_group, bias=out_proj_bias, sequence_parallel=sequence_parallel, multiple_of=self.head_dim, **factory_kwargs, ) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): dtype = self.out_proj.weight.dtype if dtype is None else dtype device = self.out_proj.weight.device return torch.empty( batch_size, max_seqlen, 2, self.num_heads_kv_per_rank, self.head_dim, dtype=dtype, device=device, ) def _update_kv_cache(self, kv, inference_params): """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" return _update_kv_cache(kv, inference_params, self.layer_idx) def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): """ Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention. q: (batch_size, seqlen_q, nheads, head_dim) kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) """ assert inference_params is not None and inference_params.seqlen_offset > 0 assert self.use_flash_attn if self.rotary_emb_dim > 0: assert self.rotary_emb.scale is None, "This code path does not support xPos" self.rotary_emb._update_cos_sin_cache( inference_params.max_seqlen, device=q.device, dtype=q.dtype ) rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached else: rotary_cos, rotary_sin = None, None batch = q.shape[0] kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) context = flash_attn_with_kvcache( q, kv_cache[:, :, 0], kv_cache[:, :, 1], kv[:, :, 0], kv[:, :, 1], rotary_cos=rotary_cos, rotary_sin=rotary_sin, cache_seqlens=cache_seqlens, softmax_scale=self.inner_cross_attn.softmax_scale, causal=self.inner_cross_attn.causal, rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, alibi_slopes=alibi_slopes, ) return context def _update_kvcache_attention(self, q, kv, inference_params): """Write kv to inference_params, then do attention""" if inference_params.seqlen_offset == 0 or not self.use_flash_attn: # TODO: this only uses seqlen_offset and not lengths_per_sample. kv = self._update_kv_cache(kv, inference_params) return self.inner_cross_attn(q, kv) else: batch = q.shape[0] kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] cache_seqlens = ( inference_params.lengths_per_sample[:batch] if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) context = flash_attn_with_kvcache( q, kv_cache[:, :, 0], kv_cache[:, :, 1], kv[:, :, 0], kv[:, :, 1], cache_seqlens=cache_seqlens, softmax_scale=self.inner_cross_attn.softmax_scale, causal=self.inner_cross_attn.causal, alibi_slopes=alibi_slopes, ) return context def forward(self, x, seqlen=None, inference_params=None, **kwargs): """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we split x during sequence parallel, we split the batch * seqlen dimension (in case batch is small). """ qkv = self.Wqkv(x) if seqlen is not None: qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen) seqlen_offset = ( 0 if inference_params is None else ( inference_params.lengths_per_sample if inference_params.lengths_per_sample is not None else inference_params.seqlen_offset ) ) rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None if self.num_heads_kv == self.num_heads: qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) if ( inference_params is None or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): if self.rotary_emb_dim > 0: qkv = self.rotary_emb( qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen ) if inference_params is None: if not self.checkpointing: context = self.inner_attn(qkv, **kwargs) else: context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) else: context = self._update_kvcache_attention( qkv[:, :, 0], qkv[:, :, 1:], inference_params ) else: context = self._apply_rotary_update_kvcache_attention( qkv[:, :, 0], qkv[:, :, 1:], inference_params ) else: q = rearrange( qkv[..., : self.num_heads_per_rank * self.head_dim], "... (h d) -> ... h d", d=self.head_dim, ) kv = rearrange( qkv[..., self.num_heads_per_rank * self.head_dim :], "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim, ) if ( inference_params is None or inference_params.seqlen_offset == 0 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) or not self.use_flash_attn ): if self.rotary_emb_dim > 0: q, kv = self.rotary_emb( q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen ) if inference_params is None: if not self.checkpointing: context = self.inner_cross_attn(q, kv, **kwargs) else: context = torch.utils.checkpoint.checkpoint( self.inner_cross_attn, q, kv, **kwargs ) else: context = self._update_kvcache_attention(q, kv, inference_params) else: context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) context = rearrange(context, "b s h d -> b s (h d)") if seqlen is not None: context = rearrange(context, "b s d -> (b s) d") out = self.out_proj(context) return out class RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, hidden_states): variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) return self.weight * hidden_states.to(self.weight.dtype) class Seq2SeqConfig(PretrainedConfig): def __init__( self, vocab_size=30522, hidden_size=768, num_encoder_layers=6, num_decoder_layers=12, num_attention_heads=12, num_key_value_heads=4, intermediate_size=3072, hidden_act="silu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, max_position_embeddings=512, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, bos_token_id=1, eos_token_id=2, use_cache=True, rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None, rotary_emb_interleaved=False, **kwargs ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_encoder_layers = num_encoder_layers self.num_decoder_layers = num_decoder_layers self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.use_cache = use_cache self.rotary_emb_base = rotary_emb_base self.rotary_emb_scale_base = rotary_emb_scale_base self.rotary_emb_interleaved = rotary_emb_interleaved # Calculate head_dim and set rotary_emb_dim self.head_dim = self.hidden_size // self.num_attention_heads self.rotary_emb_dim = kwargs.get('rotary_emb_dim', self.head_dim // 2) # Ensure rotary_emb_dim is not larger than head_dim if self.rotary_emb_dim > self.head_dim: print(f"Warning: rotary_emb_dim ({self.rotary_emb_dim}) is larger than head_dim ({self.head_dim}). Setting rotary_emb_dim to head_dim.") self.rotary_emb_dim = self.head_dim class CustomSeq2SeqLLM(PreTrainedModel): config_class = Seq2SeqConfig base_model_prefix = "custom_seq2seq_llm" def __init__(self, config): super().__init__(config) self.config = config self.shared = nn.Embedding(config.vocab_size, config.hidden_size) self.encoder = CustomEncoder(config) self.decoder = CustomDecoder(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.loss_fn = LigerCrossEntropyLoss() self.init_weights() def get_encoder(self): return self.encoder def get_decoder(self): return self.decoder def get_input_embeddings(self): return self.shared def set_input_embeddings(self, value): self.shared = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def prepare_inputs_for_generation( self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): # Cut decoder_input_ids if past is used if past is not None: input_ids = input_ids[:, -1:] return { "decoder_input_ids": input_ids, "past_key_values": past, "encoder_outputs": encoder_outputs, "attention_mask": attention_mask, "use_cache": use_cache, } def forward( self, input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, encoder_outputs=None, past_key_values=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, position_ids=None, ): if position_ids is None and input_ids is not None: position_ids = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(input_ids) if encoder_outputs is None and input_ids is not None: encoder_outputs = self.encoder( self.shared(input_ids), attention_mask=attention_mask, position_ids=position_ids, ) if decoder_input_ids is None: if labels is not None: decoder_input_ids = self._shift_right(labels) elif input_ids is not None: decoder_input_ids = input_ids else: raise ValueError("Either decoder_input_ids, labels, or input_ids must be provided.") decoder_outputs = self.decoder( self.shared(decoder_input_ids), encoder_outputs, attention_mask=decoder_attention_mask, position_ids=position_ids, ) lm_logits = self.lm_head(decoder_outputs) loss = None if labels is not None: loss = self.loss_fn(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) return Seq2SeqLMOutput( loss=loss, logits=lm_logits, encoder_last_hidden_state=encoder_outputs, decoder_hidden_states=decoder_outputs, ) def _shift_right(self, input_ids): shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() shifted_input_ids[..., 0] = self.config.pad_token_id return shifted_input_ids def save_pretrained(self, save_directory, safe_serialization=True): # Save the config self.config.save_pretrained(save_directory) # Prepare state dict state_dict = self.state_dict() # Handle shared weights if self.config.tie_word_embeddings: state_dict["lm_head.weight"] = state_dict["shared.weight"] # Convert state_dict to CPU tensors cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()} if safe_serialization: # Save using safetensors safe_filepath = os.path.join(save_directory, "model.safetensors") save_file(cpu_state_dict, safe_filepath) else: # Save using PyTorch torch_filepath = os.path.join(save_directory, "pytorch_model.bin") torch.save(cpu_state_dict, torch_filepath) # @classmethod # def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): # config = kwargs.pop("config", None) # state_dict = kwargs.pop("state_dict", None) # if config is None: # config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) # model = cls(config) # if state_dict is None: # # Try loading safetensors first # safe_filepath = os.path.join(pretrained_model_name_or_path, "model.safetensors") # if os.path.exists(safe_filepath): # from safetensors.torch import load_file # state_dict = load_file(safe_filepath) # else: # # Fall back to PyTorch format # torch_filepath = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") # state_dict = torch.load(torch_filepath, map_location="cpu") # # Handle shared weights # if config.tie_word_embeddings and "lm_head.weight" not in state_dict: # state_dict["lm_head.weight"] = state_dict["shared.weight"] # model.load_state_dict(state_dict) # return model class CustomEncoder(nn.Module): def __init__(self, config): super().__init__() self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_encoder_layers)]) self.layer_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states, attention_mask=None, position_ids=None): for layer in self.layers: hidden_states = layer(hidden_states, attention_mask, position_ids) return self.layer_norm(hidden_states) class EncoderLayer(nn.Module): def __init__(self, config): super().__init__() self.self_attn = MHA(config.hidden_size, config.num_attention_heads, num_heads_kv=config.num_key_value_heads, dropout=config.attention_probs_dropout_prob, causal=False, rotary_emb_dim=config.rotary_emb_dim, rotary_emb_base=config.rotary_emb_base, rotary_emb_scale_base=config.rotary_emb_scale_base, rotary_emb_interleaved=config.rotary_emb_interleaved) self.feed_forward = LigerSwiGLUMLP(config) self.layer_norm1 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) self.layer_norm2 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states, attention_mask=None, position_ids=None): normed_hidden_states = self.layer_norm1(hidden_states) attention_output = self.self_attn(normed_hidden_states, key_padding_mask=attention_mask) hidden_states = hidden_states + attention_output normed_hidden_states = self.layer_norm2(hidden_states) feed_forward_output = self.feed_forward(normed_hidden_states) hidden_states = hidden_states + feed_forward_output return hidden_states class CustomDecoder(nn.Module): def __init__(self, config): super().__init__() self.layers = nn.ModuleList([ DecoderLayer(config, use_cross_attention=self._should_use_cross_attention(i, config.num_decoder_layers)) for i in range(config.num_decoder_layers) ]) self.layer_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) def _should_use_cross_attention(self, layer_idx, total_layers): return layer_idx == 0 or layer_idx == total_layers - 1 or layer_idx % 2 == 0 def forward(self, hidden_states, encoder_hidden_states, attention_mask=None, position_ids=None): for layer in self.layers: hidden_states = layer(hidden_states, encoder_hidden_states, attention_mask, position_ids) return self.layer_norm(hidden_states) class DecoderLayer(nn.Module): def __init__(self, config, use_cross_attention=True): super().__init__() self.use_cross_attention = use_cross_attention self.self_attn = MHA(config.hidden_size, config.num_attention_heads, num_heads_kv=config.num_key_value_heads, dropout=config.attention_probs_dropout_prob, causal=True, rotary_emb_dim=config.rotary_emb_dim, rotary_emb_base=config.rotary_emb_base, rotary_emb_scale_base=config.rotary_emb_scale_base, rotary_emb_interleaved=config.rotary_emb_interleaved) if use_cross_attention: self.cross_attn = MHA(config.hidden_size, config.num_attention_heads, num_heads_kv=config.num_key_value_heads, dropout=config.attention_probs_dropout_prob, causal=False, cross_attn=True) self.layer_norm2 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) self.feed_forward = LigerSwiGLUMLP(config) self.layer_norm1 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) self.layer_norm3 = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states, encoder_hidden_states, attention_mask=None, position_ids=None): normed_hidden_states = self.layer_norm1(hidden_states) self_attention_output = self.self_attn(normed_hidden_states, key_padding_mask=attention_mask) hidden_states = hidden_states + self_attention_output if self.use_cross_attention: normed_hidden_states = self.layer_norm2(hidden_states) cross_attention_output = self.cross_attn(normed_hidden_states, x_kv=encoder_hidden_states, key_padding_mask=attention_mask) hidden_states = hidden_states + cross_attention_output normed_hidden_states = self.layer_norm3(hidden_states) feed_forward_output = self.feed_forward(normed_hidden_states) hidden_states = hidden_states + feed_forward_output return hidden_states