# Adapted from https://github.com/mosaicml/llm-foundry
# Classes changed: MPTBlock
# SPDX-License-Identifier: Apache-2.0

"""GPT Blocks used for the GPT Model."""

from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

from .attention import ATTN_CLASS_REGISTRY
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY


class MPTMLP(nn.Module):
    """Position-wise feed-forward network used inside :class:`MPTBlock`.

    Computes ``down_proj(GELU(up_proj(x)))`` where ``up_proj`` maps
    ``d_model -> expansion_ratio * d_model`` and ``down_proj`` maps back to
    ``d_model``. GELU uses the exact form (``approximate='none'``).

    Args:
        d_model: Size of the last dimension of the input and output.
        expansion_ratio: Multiplier applied to ``d_model`` for the hidden width.
        device: Device on which the linear layers' parameters are created.
    """

    def __init__(self,
                 d_model: int,
                 expansion_ratio: int,
                 device: Optional[str] = None):
        super().__init__()
        self.up_proj = nn.Linear(d_model,
                                 expansion_ratio * d_model,
                                 device=device)
        self.act = nn.GELU(approximate='none')
        self.down_proj = nn.Linear(expansion_ratio * d_model,
                                   d_model,
                                   device=device)
        # Marker attribute; nothing in this file reads it. Presumably consumed
        # by weight-initialization code elsewhere to treat projections that
        # write into the residual stream specially -- confirm against the
        # init function.
        self.down_proj._is_residual = True  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, GELU, then down-projection to ``x``."""
        return self.down_proj(self.act(self.up_proj(x)))


# Attention settings used when MPTBlock is built without an ``attn_config``.
# MPTBlock reads only 'attn_type', 'attn_impl', 'clip_qkv', 'qk_ln',
# 'softmax_scale' and 'attn_pdrop'; the remaining keys are carried over
# unchanged from the original default and are ignored by this module.
# Never hand this dict out directly -- MPTBlock copies it per call.
_ATTN_CONFIG_DEFAULTS: Dict[str, Any] = {
    'attn_type': 'multihead_attention',
    'attn_pdrop': 0.0,
    'attn_impl': 'triton',
    'qk_ln': False,
    'clip_qkv': None,
    'softmax_scale': None,
    'prefix_lm': False,
    'attn_uses_sequence_id': False,
    'alibi': False,
    'alibi_bias_max': 8,
}


class MPTBlock(nn.Module):
    """Pre-norm transformer block: attention sublayer then MLP sublayer.

    Each sublayer normalizes its input, applies the sublayer, applies residual
    dropout, and adds the result back onto the residual stream.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads, forwarded to the attention class.
        expansion_ratio: Hidden-width multiplier for :class:`MPTMLP`.
        attn_config: Attention settings. ``attn_type`` selects the class from
            ``ATTN_CLASS_REGISTRY``; ``attn_impl``, ``clip_qkv``, ``qk_ln``,
            ``softmax_scale`` and ``attn_pdrop`` are passed to its constructor.
            All six keys must be present. ``None`` (the default) uses a fresh
            copy of ``_ATTN_CONFIG_DEFAULTS``.
        resid_pdrop: Dropout probability applied to each sublayer output
            before the residual add.
        norm_type: Key (lower-cased before lookup) into
            ``NORM_CLASS_REGISTRY`` selecting the normalization class.
        verbose: Forwarded to the attention class.
        device: Device on which submodule parameters are created.
        **kwargs: Accepted and discarded so a full model config can be splatted
            into the constructor.

    Raises:
        KeyError: If a required ``attn_config`` key is missing. An unregistered
            ``norm_type`` or ``attn_type`` fails at the ``[]`` registry lookup
            (KeyError if the registries are plain dicts).
    """

    def __init__(self,
                 d_model: int,
                 n_heads: int,
                 expansion_ratio: int,
                 attn_config: Optional[Dict[str, Any]] = None,
                 resid_pdrop: float = 0.0,
                 norm_type: str = 'low_precision_layernorm',
                 verbose: int = 0,
                 device: Optional[str] = None,
                 **kwargs):
        del kwargs  # unused, just to capture any extra args from the config
        super().__init__()

        if attn_config is None:
            # Fresh copy per call: a dict literal as a default argument would
            # be a single object shared by every MPTBlock ever constructed.
            attn_config = dict(_ATTN_CONFIG_DEFAULTS)

        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]

        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            attn_impl=attn_config['attn_impl'],
            clip_qkv=attn_config['clip_qkv'],
            qk_ln=attn_config['qk_ln'],
            softmax_scale=attn_config['softmax_scale'],
            attn_pdrop=attn_config['attn_pdrop'],
            d_model=d_model,
            n_heads=n_heads,
            verbose=verbose,
            device=device,
        )
        self.norm_2 = norm_class(d_model, device=device)
        self.ffn = MPTMLP(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            device=device,
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, ...]] = None,
        long_range_past_key_value: Optional[Tuple[torch.Tensor, ...]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attn_bias_ae: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
        topk: Optional[int] = None,
        needs_weights: Optional[bool] = None,
        faiss_indexes: Optional[Tuple] = None,
        n_layers: Optional[int] = None,
        current_layer: Optional[int] = None,
        mask_by_sim: bool = False,
        sim_threshold: Optional[float] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[Tuple[torch.Tensor, ...]], Any]:
        """Run the attention and MLP sublayers with pre-norm residuals.

        Computes::

            x = x + resid_attn_dropout(attn(norm_1(x), ...))
            x = x + resid_ffn_dropout(ffn(norm_2(x)))

        Every argument other than ``x`` is forwarded unchanged to
        ``self.attn``; its meaning is defined by the attention class selected
        through ``attn_config['attn_type']``. NOTE(review): ``topk``,
        ``faiss_indexes``, ``mask_by_sim`` and ``sim_threshold`` look like
        controls for a retrieval-style long-range attention -- confirm against
        the attention implementation.

        Returns:
            A 4-tuple ``(x, attn_weights, past_key_value, reshaped_idx)``: the
            block output, followed by the three auxiliary values returned by
            ``self.attn``, passed through untouched.
        """
        a = self.norm_1(x)
        b, attn_weights, past_key_value, reshaped_idx = self.attn(
            a,
            past_key_value=past_key_value,
            long_range_past_key_value=long_range_past_key_value,
            attn_bias=attn_bias,
            attn_bias_ae=attn_bias_ae,
            attention_mask=attention_mask,
            is_causal=is_causal,
            topk=topk,
            needs_weights=needs_weights,
            faiss_indexes=faiss_indexes,
            n_layers=n_layers,
            current_layer=current_layer,
            mask_by_sim=mask_by_sim,
            sim_threshold=sim_threshold,
        )
        x = x + self.resid_attn_dropout(b)
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return x, attn_weights, past_key_value, reshaped_idx