Text Generation
Transformers
PyTorch
mpt
Composer
MosaicML
llm-foundry
custom_code
text-generation-inference
mpt-7b-chat / blocks.py
sam-mosaic's picture
Upload folder using huggingface_hub
36b0251
raw history blame
No virus
2.49 kB
"""GPT Blocks used for the GPT Model."""
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from .attention import ATTN_CLASS_REGISTRY
from .norm import NORM_CLASS_REGISTRY
class MPTMLP(nn.Module):
def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
super().__init__()
self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
self.act = nn.GELU(approximate='none')
self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
self.down_proj._is_residual = True
def forward(self, x):
return self.down_proj(self.act(self.up_proj(x)))
class MPTBlock(nn.Module):
def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
del kwargs
super().__init__()
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
self.norm_1 = norm_class(d_model, device=device)
self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
self.norm_2 = norm_class(d_model, device=device)
self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
a = self.norm_1(x)
(b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
x = x + self.resid_attn_dropout(b)
m = self.norm_2(x)
n = self.ffn(m)
x = x + self.resid_ffn_dropout(n)
return (x, past_key_value)