upstream-replit-updates

#43
by winglian - opened
Files changed (4) hide show
  1. attention.py +11 -18
  2. blocks.py +2 -2
  3. configuration_mpt.py +1 -1
  4. modeling_mpt.py +1 -0
attention.py CHANGED
@@ -5,7 +5,6 @@ from typing import Optional
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
8
- from packaging import version
9
  from torch import nn
10
  from .norm import LPLayerNorm
11
 
@@ -88,17 +87,9 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
88
 
89
  def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
90
  try:
91
- from .flash_attn_triton import flash_attn_func
92
  except:
93
- _installed = False
94
- if version.parse(torch.__version__) < version.parse('2.0.0'):
95
- _installed = True
96
- try:
97
- from flash_attn.flash_attn_triton import flash_attn_func
98
- except:
99
- _installed = False
100
- if not _installed:
101
- raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
102
  check_valid_inputs(query, key, value)
103
  if dropout_p:
104
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
@@ -117,7 +108,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
117
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
118
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
119
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
120
- attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
121
  output = attn_output.view(*attn_output.shape[:2], -1)
122
  return (output, None)
123
 
@@ -128,7 +119,7 @@ class MultiheadAttention(nn.Module):
128
  additive bias.
129
  """
130
 
131
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
132
  super().__init__()
133
  self.attn_impl = attn_impl
134
  self.clip_qkv = clip_qkv
@@ -150,10 +141,11 @@ class MultiheadAttention(nn.Module):
150
  self.attn_fn = flash_attn_fn
151
  elif self.attn_impl == 'triton':
152
  self.attn_fn = triton_flash_attn_fn
153
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
154
  elif self.attn_impl == 'torch':
155
  self.attn_fn = scaled_multihead_dot_product_attention
156
- if torch.cuda.is_available():
157
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
158
  else:
159
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -187,7 +179,7 @@ class MultiQueryAttention(nn.Module):
187
  additive bias.
188
  """
189
 
190
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
191
  super().__init__()
192
  self.attn_impl = attn_impl
193
  self.clip_qkv = clip_qkv
@@ -210,10 +202,11 @@ class MultiQueryAttention(nn.Module):
210
  self.attn_fn = flash_attn_fn
211
  elif self.attn_impl == 'triton':
212
  self.attn_fn = triton_flash_attn_fn
213
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
214
  elif self.attn_impl == 'torch':
215
  self.attn_fn = scaled_multihead_dot_product_attention
216
- if torch.cuda.is_available():
217
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
218
  else:
219
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
 
8
  from torch import nn
9
  from .norm import LPLayerNorm
10
 
 
87
 
88
  def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
89
  try:
90
+ from flash_attn import flash_attn_triton
91
  except:
92
+ raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
 
 
 
 
 
 
 
 
93
  check_valid_inputs(query, key, value)
94
  if dropout_p:
95
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
 
108
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
109
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
110
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
111
+ attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
112
  output = attn_output.view(*attn_output.shape[:2], -1)
113
  return (output, None)
114
 
 
119
  additive bias.
120
  """
121
 
122
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
123
  super().__init__()
124
  self.attn_impl = attn_impl
125
  self.clip_qkv = clip_qkv
 
141
  self.attn_fn = flash_attn_fn
142
  elif self.attn_impl == 'triton':
143
  self.attn_fn = triton_flash_attn_fn
144
+ if verbose:
145
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
146
  elif self.attn_impl == 'torch':
147
  self.attn_fn = scaled_multihead_dot_product_attention
148
+ if torch.cuda.is_available() and verbose:
149
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
150
  else:
151
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
179
  additive bias.
180
  """
181
 
182
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
183
  super().__init__()
184
  self.attn_impl = attn_impl
185
  self.clip_qkv = clip_qkv
 
202
  self.attn_fn = flash_attn_fn
203
  elif self.attn_impl == 'triton':
204
  self.attn_fn = triton_flash_attn_fn
205
+ if verbose:
206
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
207
  elif self.attn_impl == 'torch':
208
  self.attn_fn = scaled_multihead_dot_product_attention
209
+ if torch.cuda.is_available() and verbose:
210
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
211
  else:
212
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
blocks.py CHANGED
@@ -19,13 +19,13 @@ class MPTMLP(nn.Module):
19
 
20
  class MPTBlock(nn.Module):
21
 
22
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
29
  self.norm_2 = norm_class(d_model, device=device)
30
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
 
19
 
20
  class MPTBlock(nn.Module):
21
 
22
+ def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
+ self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29
  self.norm_2 = norm_class(d_model, device=device)
30
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
configuration_mpt.py CHANGED
@@ -2,7 +2,7 @@
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
- init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
 
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
+ init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
modeling_mpt.py CHANGED
@@ -46,6 +46,7 @@ class MPTModel(MPTPreTrainedModel):
46
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
47
  self.norm_f = norm_class(config.d_model, device=config.init_device)
48
  if config.init_device != 'meta':
 
49
  self.apply(self.param_init_fn)
50
  self.is_causal = not self.prefix_lm
51
  self._attn_bias_initialized = False
 
46
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
47
  self.norm_f = norm_class(config.d_model, device=config.init_device)
48
  if config.init_device != 'meta':
49
+ print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
50
  self.apply(self.param_init_fn)
51
  self.is_causal = not self.prefix_lm
52
  self._attn_bias_initialized = False