Pklett's picture
upload custom code
c2d160f
raw
history blame
No virus
4.29 kB
# Adapted from https://github.com/mosaicml/llm-foundry
# Classes changed: MPTBlock
# SPDX-License-Identifier: Apache-2.0
"""GPT Blocks used for the GPT Model."""
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY

from .attention import ATTN_CLASS_REGISTRY
class MPTMLP(nn.Module):
    """Feed-forward sub-layer: Linear -> exact (erf-based) GELU -> Linear.

    Args:
        d_model: Size of the input and output feature dimension.
        expansion_ratio: Hidden-size multiplier; the hidden layer has
            ``expansion_ratio * d_model`` features.
        device: Device on which the linear layers' parameters are created.
    """

    def __init__(self,
                 d_model: int,
                 expansion_ratio: int,
                 device: Optional[str] = None):
        super().__init__()
        hidden_dim = expansion_ratio * d_model
        # Construction order (up_proj, act, down_proj) is kept so parameter
        # initialisation draws from the RNG in the same sequence.
        self.up_proj = nn.Linear(d_model, hidden_dim, device=device)
        self.act = nn.GELU(approximate='none')
        self.down_proj = nn.Linear(hidden_dim, d_model, device=device)
        # Marker attribute read elsewhere (presumably by the model's weight
        # init code) to identify layers that write into the residual stream.
        self.down_proj._is_residual = True  # type: ignore

    def forward(self, x):
        """Project ``x`` up, apply GELU, and project back to ``d_model``."""
        hidden = self.up_proj(x)
        hidden = self.act(hidden)
        return self.down_proj(hidden)
class MPTBlock(nn.Module):
    """Pre-norm transformer block: an attention sub-layer followed by an MLP.

    Each sub-layer reads a normalised copy of the residual stream, and its
    output (after residual dropout) is added back onto the stream.
    """

    # Default attention settings. Treated as read-only: ``__init__`` builds a
    # fresh merged dict on every call and never mutates this one.
    _ATTN_CONFIG_DEFAULTS: Dict = {
        'attn_type': 'multihead_attention',
        'attn_pdrop': 0.0,
        'attn_impl': 'triton',
        'qk_ln': False,
        'clip_qkv': None,
        'softmax_scale': None,
        'prefix_lm': False,
        'attn_uses_sequence_id': False,
        'alibi': False,
        'alibi_bias_max': 8,
    }

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        expansion_ratio: int,
        attn_config: Optional[Dict] = None,
        resid_pdrop: float = 0.0,
        norm_type: str = 'low_precision_layernorm',
        verbose: int = 0,
        device: Optional[str] = None,
        **kwargs,
    ):
        """Build the norms, attention layer, MLP and residual dropouts.

        Args:
            d_model: Model (residual stream) width.
            n_heads: Number of attention heads, forwarded to the attention class.
            expansion_ratio: Hidden-size multiplier for the MLP.
            attn_config: Attention settings. Keys given here override
                ``_ATTN_CONFIG_DEFAULTS``, and missing keys fall back to those
                defaults. ``None`` means "use all defaults". This block reads
                ``attn_type``, ``attn_impl``, ``clip_qkv``, ``qk_ln``,
                ``softmax_scale`` and ``attn_pdrop``. Any other keys are
                carried along but unused here.
            resid_pdrop: Dropout probability applied to each sub-layer's
                output before it is added to the residual stream.
            norm_type: Key (case-insensitive) into ``NORM_CLASS_REGISTRY``.
            verbose: Forwarded to the attention class.
            device: Forwarded to the norm, attention and MLP constructors.
            **kwargs: Ignored, so that a full model config can be splatted in.

        Raises:
            Whatever the registries raise for an unknown ``norm_type`` or
            ``attn_type`` (presumably ``KeyError`` if they are plain dicts).
        """
        del kwargs  # unused, just to capture any extra args from the config
        super().__init__()

        # Overlay the caller's settings on the defaults. A partial config no
        # longer raises KeyError, and no dict is shared between calls.
        config = dict(self._ATTN_CONFIG_DEFAULTS)
        if attn_config is not None:
            config.update(attn_config)

        norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
        attn_class = ATTN_CLASS_REGISTRY[config['attn_type']]

        # Submodules are created in the original order so that parameter
        # initialisation and state_dict key order are unchanged.
        self.norm_1 = norm_class(d_model, device=device)
        self.attn = attn_class(
            attn_impl=config['attn_impl'],
            clip_qkv=config['clip_qkv'],
            qk_ln=config['qk_ln'],
            softmax_scale=config['softmax_scale'],
            attn_pdrop=config['attn_pdrop'],
            d_model=d_model,
            n_heads=n_heads,
            verbose=verbose,
            device=device,
        )
        self.norm_2 = norm_class(d_model, device=device)
        self.ffn = MPTMLP(
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            device=device,
        )
        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        long_range_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attn_bias: Optional[torch.Tensor] = None,
        attn_bias_ae: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.ByteTensor] = None,
        is_causal: bool = True,
        topk: Optional[int] = None,
        needs_weights: Optional[bool] = None,
        faiss_indexes: Optional[Tuple] = None,
        n_layers: Optional[int] = None,
        current_layer: Optional[int] = None,
        mask_by_sim: bool = False,
        sim_threshold: Optional[float] = None,
    ) -> Tuple[torch.Tensor, Any, Any, Any]:
        """Run attention and MLP sub-layers with pre-norm residual connections.

        Every argument except ``x`` is forwarded unchanged, by keyword, to
        ``self.attn``. Its meaning is defined by the attention class selected
        through ``attn_config['attn_type']``, which is not visible here.

        Args:
            x: Residual-stream input. Presumably (batch, seq, d_model);
                confirm against the attention implementation.

        Returns:
            A 4-tuple ``(x, attn_weights, past_key_value, reshaped_idx)``:
            the updated residual stream, followed by the last three values
            returned by ``self.attn``, passed through as-is.
        """
        # Attention sub-layer: x + dropout(attn(norm_1(x))).
        attn_in = self.norm_1(x)
        attn_out, attn_weights, past_key_value, reshaped_idx = self.attn(
            attn_in,
            past_key_value=past_key_value,
            long_range_past_key_value=long_range_past_key_value,
            attn_bias=attn_bias,
            attn_bias_ae=attn_bias_ae,
            attention_mask=attention_mask,
            is_causal=is_causal,
            topk=topk,
            needs_weights=needs_weights,
            faiss_indexes=faiss_indexes,
            n_layers=n_layers,
            current_layer=current_layer,
            mask_by_sim=mask_by_sim,
            sim_threshold=sim_threshold,
        )
        x = x + self.resid_attn_dropout(attn_out)

        # Feed-forward sub-layer: x + dropout(ffn(norm_2(x))).
        ffn_out = self.ffn(self.norm_2(x))
        x = x + self.resid_ffn_dropout(ffn_out)
        return x, attn_weights, past_key_value, reshaped_idx