|
"""GPT Blocks used for the GPT Model.""" |
|
from typing import Dict, Optional, Tuple |
|
import torch |
|
import torch.nn as nn |
|
from .attention import ATTN_CLASS_REGISTRY |
|
from .norm import NORM_CLASS_REGISTRY |
|
|
|
|
|
class MPTMLP(nn.Module):
    """Feed-forward sub-layer of an MPT block.

    Projects ``d_model`` up to ``expansion_ratio * d_model``, applies an exact
    (``approximate="none"``) GELU, then projects back down to ``d_model``.
    """

    def __init__(
        self, d_model: int, expansion_ratio: int, device: Optional[str] = None
    ):
        """Create the up/down projections and the activation.

        Args:
            d_model: Input and output feature width.
            expansion_ratio: Multiplier giving the hidden feature width.
            device: Device on which the linear layers' parameters are created.
        """
        super().__init__()
        hidden_width = expansion_ratio * d_model
        # Creation order matters: it fixes both the state-dict key order and
        # the order in which parameters consume the RNG during initialization.
        self.up_proj = nn.Linear(d_model, hidden_width, device=device)
        self.act = nn.GELU(approximate="none")
        self.down_proj = nn.Linear(hidden_width, d_model, device=device)
        # Marker attribute not read in this block; presumably consumed by a
        # weight-init routine to treat residual-output projections specially
        # — confirm at the call site.
        self.down_proj._is_residual = True

    def forward(self, x):
        """Return ``down_proj(gelu(up_proj(x)))``."""
        expanded = self.up_proj(x)
        activated = self.act(expanded)
        return self.down_proj(activated)
|
|
|
|
|
class MPTBlock(nn.Module):
    """One pre-norm MPT transformer block: attention, then a feed-forward MLP.

    Each sub-layer reads a normalized copy of the residual stream and adds its
    dropout-regularized output back onto it.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        resid_pdrop: float = 0.0,
        norm_type: str = "low_precision_layernorm",
        device: Optional[str] = None,
        **kwargs
    ):
        """Build the norms, attention layer, MLP and residual dropouts.

        Args:
            d_model: Width of the residual stream.
            n_heads: Number of attention heads, forwarded to the attention class.
            expansion_ratio: Hidden-width multiplier for the feed-forward MLP.
            attn_config: Attention settings. ``None`` selects the defaults
                defined below. A dict is merged over those defaults, so any key
                it omits falls back to its default value. Keys read here:
                ``attn_type``, ``attn_impl``, ``clip_qkv``, ``qk_ln``,
                ``softmax_scale`` and ``attn_pdrop``.
            resid_pdrop: Dropout probability applied to each sub-layer's output
                before it is added back to the residual stream.
            norm_type: Key into ``NORM_CLASS_REGISTRY``, matched after
                lower-casing.
            device: Device on which parameters are created.
            **kwargs: Accepted and discarded, so extra config entries can be
                passed through without error.

        Raises:
            KeyError: Presumably raised if ``norm_type`` or
                ``attn_config["attn_type"]`` is not registered. This assumes
                the registries are plain dicts; confirm.
        """
        del kwargs
        super().__init__()

        # Built fresh on every call. The previous dict-literal parameter
        # default was one object shared by every MPTBlock instance.
        default_attn_config = {
            "attn_type": "multihead_attention",
            "attn_pdrop": 0.0,
            "attn_impl": "triton",
            "qk_ln": False,
            "clip_qkv": None,
            "softmax_scale": None,
            "prefix_lm": False,
            "attn_uses_sequence_id": False,
            "alibi": False,
            "alibi_bias_max": 8,
        }
        if attn_config is None:
            attn_config = default_attn_config
        else:
            # Caller-supplied values win; missing keys use the defaults
            # instead of raising KeyError below.
            attn_config = {**default_attn_config, **attn_config}

        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]

        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            attn_impl=attn_config["attn_impl"],
            clip_qkv=attn_config["clip_qkv"],
            qk_ln=attn_config["qk_ln"],
            softmax_scale=attn_config["softmax_scale"],
            attn_pdrop=attn_config["attn_pdrop"],
            d_model=d_model,
            n_heads=n_heads,
            device=device,
        )
        self.norm_2 = norm_class(d_model, device=device)
        self.ffn = MPTMLP(
            d_model=d_model, expansion_ratio=expansion_ratio, device=device
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        """Apply attention and then the MLP, each as a pre-norm residual update.

        Args:
            x: Input hidden states. Presumably shaped (batch, seq, d_model);
                confirm against the attention implementation.
            past_key_value: Cached key/value state forwarded to the attention
                layer, or ``None``.
            attn_bias: Optional bias tensor forwarded to the attention layer.
            attention_mask: Optional mask forwarded to the attention layer.
            is_causal: Forwarded to the attention layer.

        Returns:
            ``(x, past_key_value)``: the updated hidden states and the
            key/value state returned by the attention layer.
        """
        a = self.norm_1(x)
        # The attention layer returns a 3-tuple. The middle element is not
        # used by this block.
        (b, _, past_key_value) = self.attn(
            a,
            past_key_value=past_key_value,
            attn_bias=attn_bias,
            attention_mask=attention_mask,
            is_causal=is_causal,
        )
        x = x + self.resid_attn_dropout(b)
        m = self.norm_2(x)
        n = self.ffn(m)
        x = x + self.resid_ffn_dropout(n)
        return (x, past_key_value)
|
|