"""GPT Blocks used for the GPT Model."""
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

from .attention import ATTN_CLASS_REGISTRY
from .ffn import FFN_CLASS_REGISTRY, build_ffn
from .norm import NORM_CLASS_REGISTRY

try:
    from flash_attn.bert_padding import unpad_input, pad_input
except Exception:
    # flash_attn is optional; it is only needed when a block is built with
    # use_pad_tok_in_ffn=False. ``Exception`` (not a bare ``except:``) keeps
    # this best-effort for failures other than ImportError that a broken
    # install may raise, without swallowing KeyboardInterrupt / SystemExit.
    (unpad_input, pad_input) = (None, None)

# Attention configuration used by MPTBlock when no ``attn_config`` is given.
attn_config_defaults: Dict = {
    "attn_type": "multihead_attention",
    "attn_pdrop": 0.0,
    "attn_impl": "flash",
    "qk_ln": True,
    "qk_gn": False,
    "clip_qkv": None,
    "softmax_scale": None,
    "prefix_lm": False,
    "attn_uses_sequence_id": False,
    "sliding_window_size": -1,
    "alibi": False,
    "alibi_bias_max": 8,
    "rope": False,
    "rope_theta": 10000,
    "rope_impl": "dail",
    "rope_dail_config": {
        "type": "original",
        "pos_idx_in_fp32": True,
        "xpos_scale_base": 512,
    },
    "rope_hf_config": {"type": "no_scaling", "factor": 1.0},
}


class MPTBlock(nn.Module):
    """Pre-norm transformer block: attention sub-layer then feed-forward sub-layer.

    ``forward`` computes::

        x = x + resid_attn_dropout(attn(norm_1(x)))
        x = x + resid_ffn_dropout(ffn(norm_2(x)))

    ``norm_2`` is omitted (identity) when the selected FFN class declares
    ``_has_norm = True``.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        ffn_config: Optional[Dict] = None,
        resid_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        fc_type: str = "torch",
        device: Optional[str] = None,
        no_bias: bool = False,
        use_pad_tok_in_ffn: bool = True,
        **kwargs: Any
    ):
        """Build the norms, attention module, FFN and residual dropouts.

        Args:
            d_model: Hidden size passed to the norms, attention and FFN.
            n_heads: Number of attention heads, passed to the attention class.
            expansion_ratio: Passed to ``build_ffn``.
            attn_config: Attention settings. ``"attn_type"`` selects the class
                from ``ATTN_CLASS_REGISTRY``. The remaining keys, except the
                model-level ones listed in ``args_to_exclude_in_attn_class``,
                are forwarded as keyword arguments. Defaults to
                ``attn_config_defaults``.
            ffn_config: Keyword arguments for ``build_ffn``. ``"ffn_type"`` is
                also used to check whether the FFN brings its own norm.
                Defaults to ``{"ffn_type": "mptmlp"}``.
            resid_pdrop: Dropout probability applied to both residual branches.
            norm_type: Case-insensitive key into ``NORM_CLASS_REGISTRY``.
            fc_type: Forwarded to the attention class.
            device: Forwarded to the norms, attention and FFN constructors.
            no_bias: If True, attention and FFN are built with ``bias=False``.
            use_pad_tok_in_ffn: If False, padded positions (per
                ``attention_mask``) are stripped before the FFN and restored
                afterwards. This requires ``flash_attn``.
            **kwargs: Accepted for config compatibility and ignored.
        """
        if attn_config is None:
            attn_config = attn_config_defaults
        if ffn_config is None:
            ffn_config = {"ffn_type": "mptmlp"}
        del kwargs
        super().__init__()
        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        assert isinstance(attn_config["attn_type"], str)
        attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
        # Model-level settings (positional encodings, prefix-LM, ...) that the
        # attention class constructor does not accept.
        args_to_exclude_in_attn_class = {
            "attn_type",
            "prefix_lm",
            "alibi",
            "attn_uses_sequence_id",
            "alibi_bias_max",
            "rope",
            "rope_theta",
            "rope_impl",
            "rope_dail_config",
            "rope_hf_config",
        }
        attn_config_subset_for_attn_class = {
            k: v
            for (k, v) in attn_config.items()
            if k not in args_to_exclude_in_attn_class
        }
        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            d_model=d_model,
            n_heads=n_heads,
            fc_type=fc_type,
            device=device,
            **attn_config_subset_for_attn_class,
            bias=not no_bias
        )
        # Skip the pre-FFN norm when the FFN implementation normalizes its input itself.
        self.norm_2 = None
        if not getattr(FFN_CLASS_REGISTRY[ffn_config["ffn_type"]], "_has_norm", False):
            self.norm_2 = norm_class(d_model, device=device)
        self.ffn = build_ffn(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            device=device,
            bias=not no_bias,
            **ffn_config
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        rotary_emb_w_meta_info: Optional[Dict] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
        output_attentions: bool = False,
        alibi_slopes: Optional[torch.Tensor] = None,
        flash_attn_padding_info: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[Tuple[torch.Tensor, torch.Tensor]],
    ]:
        """Apply the attention and FFN sub-layers with residual connections.

        Most arguments are forwarded unchanged to ``self.attn``.
        ``output_attentions`` is forwarded as ``needs_weights``.
        ``attention_mask`` is also used to unpad the FFN input when
        ``use_pad_tok_in_ffn`` is False.

        Returns:
            ``(x, attn_weights, past_key_value)``. ``x`` is the block output.
            The other two values are exactly as returned by ``self.attn``.

        Raises:
            RuntimeError: ``use_pad_tok_in_ffn`` is False but ``flash_attn``
                could not be imported.
            ValueError: ``use_pad_tok_in_ffn`` is False and
                ``attention_mask`` is None.
        """
        a = self.norm_1(x)
        (b, attn_weights, past_key_value) = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            rotary_emb_w_meta_info=rotary_emb_w_meta_info,
            attention_mask=attention_mask,
            is_causal=is_causal,
            needs_weights=output_attentions,
            alibi_slopes=alibi_slopes,
            flash_attn_padding_info=flash_attn_padding_info,
        )
        x = x + self.resid_attn_dropout(b)
        m = x
        if self.norm_2 is not None:
            m = self.norm_2(x)
        # Original leading dims, needed to re-pad the FFN output.
        (batch_size, seq_len) = m.size()[:2]
        indices = None
        if not self.use_pad_tok_in_ffn:
            # Explicit checks (not asserts) so they survive `python -O`.
            if unpad_input is None or pad_input is None:
                raise RuntimeError(
                    "use_pad_tok_in_ffn=False requires flash_attn "
                    "(flash_attn.bert_padding) to be installed."
                )
            if attention_mask is None:
                raise ValueError(
                    "attention_mask is required when use_pad_tok_in_ffn=False."
                )
            # NOTE: newer flash_attn releases may return more than four values
            # from unpad_input; only the unpadded tensor and indices are used.
            (m, indices, *_) = unpad_input(m, attention_mask)
        n = self.ffn(m)
        if not self.use_pad_tok_in_ffn:
            n = pad_input(n, indices, batch_size, seq_len)
        x = x + self.resid_ffn_dropout(n)
        return (x, attn_weights, past_key_value)