Update blocks.py
Browse files
blocks.py
CHANGED
@@ -1,41 +1,41 @@
|
|
1 |
"""GPT Blocks used for the GPT Model."""
|
2 |
-
from typing import Dict, Optional, Tuple
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
from .attention import ATTN_CLASS_REGISTRY
|
|
|
6 |
from .norm import NORM_CLASS_REGISTRY
|
7 |
|
8 |
-
class MPTMLP(nn.Module):
|
9 |
-
|
10 |
-
def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
|
11 |
-
super().__init__()
|
12 |
-
self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
|
13 |
-
self.act = nn.GELU(approximate='none')
|
14 |
-
self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
|
15 |
-
self.down_proj._is_residual = True
|
16 |
-
|
17 |
-
def forward(self, x):
|
18 |
-
return self.down_proj(self.act(self.up_proj(x)))
|
19 |
-
|
20 |
class MPTBlock(nn.Module):
|
21 |
|
22 |
-
def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict=
|
|
|
|
|
|
|
|
|
23 |
del kwargs
|
24 |
super().__init__()
|
25 |
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
|
|
26 |
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
|
|
|
|
|
27 |
self.norm_1 = norm_class(d_model, device=device)
|
28 |
-
self.attn = attn_class(
|
29 |
-
self.norm_2 =
|
30 |
-
|
|
|
|
|
31 |
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
|
32 |
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
|
33 |
|
34 |
-
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
|
35 |
a = self.norm_1(x)
|
36 |
-
(b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
|
37 |
x = x + self.resid_attn_dropout(b)
|
38 |
-
m =
|
|
|
|
|
39 |
n = self.ffn(m)
|
40 |
x = x + self.resid_ffn_dropout(n)
|
41 |
return (x, attn_weights, past_key_value)
|
|
|
1 |
"""GPT Blocks used for the GPT Model."""
|
2 |
+
from typing import Any, Dict, Optional, Tuple
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
from .attention import ATTN_CLASS_REGISTRY
|
6 |
+
from .ffn import FFN_CLASS_REGISTRY, build_ffn
|
7 |
from .norm import NORM_CLASS_REGISTRY
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
class MPTBlock(nn.Module):
|
10 |
|
11 |
+
def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Optional[Dict]=None, ffn_config: Optional[Dict]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, no_bias: bool=False, **kwargs: Any):
|
12 |
+
if attn_config is None:
|
13 |
+
attn_config = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
|
14 |
+
if ffn_config is None:
|
15 |
+
ffn_config = {'ffn_type': 'mptmlp'}
|
16 |
del kwargs
|
17 |
super().__init__()
|
18 |
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
|
19 |
+
assert isinstance(attn_config['attn_type'], str)
|
20 |
attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
|
21 |
+
args_to_exclude_in_attn_class = {'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max'}
|
22 |
+
attn_config_subset_for_attn_class = {k: v for (k, v) in attn_config.items() if k not in args_to_exclude_in_attn_class}
|
23 |
self.norm_1 = norm_class(d_model, device=device)
|
24 |
+
self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class, bias=not no_bias)
|
25 |
+
self.norm_2 = None
|
26 |
+
if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False):
|
27 |
+
self.norm_2 = norm_class(d_model, device=device)
|
28 |
+
self.ffn = build_ffn(d_model=d_model, expansion_ratio=expansion_ratio, device=device, bias=not no_bias, **ffn_config)
|
29 |
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
|
30 |
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
|
31 |
|
32 |
+
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
33 |
a = self.norm_1(x)
|
34 |
+
(b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions)
|
35 |
x = x + self.resid_attn_dropout(b)
|
36 |
+
m = x
|
37 |
+
if self.norm_2 is not None:
|
38 |
+
m = self.norm_2(x)
|
39 |
n = self.ffn(m)
|
40 |
x = x + self.resid_ffn_dropout(n)
|
41 |
return (x, attn_weights, past_key_value)
|