efederici commited on
Commit
dcbb52c
1 Parent(s): c59eae9

Update blocks.py

Browse files
Files changed (1) hide show
  1. blocks.py +20 -20
blocks.py CHANGED
@@ -1,41 +1,41 @@
1
  """GPT Blocks used for the GPT Model."""
2
- from typing import Dict, Optional, Tuple
3
  import torch
4
  import torch.nn as nn
5
  from .attention import ATTN_CLASS_REGISTRY
 
6
  from .norm import NORM_CLASS_REGISTRY
7
 
8
- class MPTMLP(nn.Module):
9
-
10
- def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
11
- super().__init__()
12
- self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13
- self.act = nn.GELU(approximate='none')
14
- self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
15
- self.down_proj._is_residual = True
16
-
17
- def forward(self, x):
18
- return self.down_proj(self.act(self.up_proj(x)))
19
-
20
  class MPTBlock(nn.Module):
21
 
22
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
 
 
 
 
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
 
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
 
 
27
  self.norm_1 = norm_class(d_model, device=device)
28
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29
- self.norm_2 = norm_class(d_model, device=device)
30
- self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
 
 
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33
 
34
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
- (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
- m = self.norm_2(x)
 
 
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
  return (x, attn_weights, past_key_value)
 
1
  """GPT Blocks used for the GPT Model."""
2
+ from typing import Any, Dict, Optional, Tuple
3
  import torch
4
  import torch.nn as nn
5
  from .attention import ATTN_CLASS_REGISTRY
6
+ from .ffn import FFN_CLASS_REGISTRY, build_ffn
7
  from .norm import NORM_CLASS_REGISTRY
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class MPTBlock(nn.Module):
10
 
11
+ def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Optional[Dict]=None, ffn_config: Optional[Dict]=None, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', fc_type: str='torch', device: Optional[str]=None, no_bias: bool=False, **kwargs: Any):
12
+ if attn_config is None:
13
+ attn_config = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
14
+ if ffn_config is None:
15
+ ffn_config = {'ffn_type': 'mptmlp'}
16
  del kwargs
17
  super().__init__()
18
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
19
+ assert isinstance(attn_config['attn_type'], str)
20
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
21
+ args_to_exclude_in_attn_class = {'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max'}
22
+ attn_config_subset_for_attn_class = {k: v for (k, v) in attn_config.items() if k not in args_to_exclude_in_attn_class}
23
  self.norm_1 = norm_class(d_model, device=device)
24
+ self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, device=device, **attn_config_subset_for_attn_class, bias=not no_bias)
25
+ self.norm_2 = None
26
+ if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm', False):
27
+ self.norm_2 = norm_class(d_model, device=device)
28
+ self.ffn = build_ffn(d_model=d_model, expansion_ratio=expansion_ratio, device=device, bias=not no_bias, **ffn_config)
29
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
30
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
31
 
32
+ def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, output_attentions: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
33
  a = self.norm_1(x)
34
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal, needs_weights=output_attentions)
35
  x = x + self.resid_attn_dropout(b)
36
+ m = x
37
+ if self.norm_2 is not None:
38
+ m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
  return (x, attn_weights, past_key_value)