Nanobit committed
Commit 19a600a
Parent: 5e5296a

Feat: Add support for upstream FA2 (#626)


* Feat: Add support for upstream FA2

* chore: add is_falcon_derived_model: true to examples

* chore: add config to readme for documentation

* feat: add extra model types

* fix: remove old falcon flash patch

* chore: pin transformers and accelerate

README.md CHANGED
@@ -408,6 +408,10 @@ tokenizer_legacy:
 # this is reported to improve training speed on some models
 resize_token_embeddings_to_32x:
 
+# used to identify if the model is falcon/llama based
+is_falcon_derived_model:
+is_llama_derived_model:
+
 # whether you are training a 4-bit GPTQ quantized model
 gptq: true
 gptq_groupsize: 128 # group size
examples/falcon/config-7b-lora.yml CHANGED
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: true
 load_in_4bit: false
 gptq: false
examples/falcon/config-7b-qlora.yml CHANGED
@@ -6,6 +6,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: false
 # enable 4bit for QLoRA
 load_in_4bit: true
examples/falcon/config-7b.yml CHANGED
@@ -3,6 +3,7 @@ base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
+is_falcon_derived_model: true
 load_in_8bit: false
 load_in_4bit: false
 gptq: false
requirements.txt CHANGED
@@ -4,9 +4,9 @@ torch==2.0.1
 auto-gptq
 packaging
 peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git
+transformers @ git+https://github.com/huggingface/transformers.git@0ac3875011d32dc85e0e83970507e3afe8f0febb
 bitsandbytes>=0.41.1
-accelerate @ git+https://github.com/huggingface/accelerate
+accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
 deepspeed
 addict
 evaluate
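The transformers and accelerate dependencies are now pinned to specific upstream commits, presumably so the native FlashAttention-2 integration this PR relies on is available ahead of the next stable releases. A minimal, standard-library sanity check (not part of this commit) for confirming which builds actually landed in the environment:

# Minimal sanity check (not part of this commit): report the installed builds
# of the pinned packages so you can verify the git pins took effect.
from importlib.metadata import version

for pkg in ("transformers", "accelerate", "bitsandbytes"):
    print(pkg, version(pkg))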
src/axolotl/monkeypatch/falcon_attn_hijack_flash.py DELETED
@@ -1,101 +0,0 @@
-"""
-Flash Attention monkey patch for Falcon
-
-copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
-"""
-
-from typing import Optional, Tuple
-
-import torch
-import transformers
-from flash_attn import flash_attn_func
-
-
-def forward(
-    self,
-    hidden_states: torch.Tensor,
-    alibi: Optional[torch.Tensor],
-    attention_mask: torch.Tensor,  # pylint: disable=unused-argument
-    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    head_mask: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
-    use_cache: bool = False,
-    output_attentions: bool = False,  # pylint: disable=unused-argument
-):
-    fused_qkv = self.query_key_value(
-        hidden_states
-    )  # [batch_size, seq_length, 3 x hidden_size]
-    num_kv_heads = (
-        self.num_heads if self.new_decoder_architecture else self.num_kv_heads
-    )
-    # 3 x [batch_size, seq_length, num_heads, head_dim]
-    (
-        query_layer,
-        key_layer,
-        value_layer,
-    ) = self._split_heads(  # pylint: disable=protected-access
-        fused_qkv
-    )
-
-    batch_size, query_length, _, _ = query_layer.shape
-
-    query_layer = query_layer.transpose(1, 2).reshape(
-        batch_size * self.num_heads, query_length, self.head_dim
-    )
-    key_layer = key_layer.transpose(1, 2).reshape(
-        batch_size * num_kv_heads,
-        query_length,
-        self.head_dim,
-    )
-    value_layer = value_layer.transpose(1, 2).reshape(
-        batch_size * num_kv_heads, query_length, self.head_dim
-    )
-
-    past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
-    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
-
-    if layer_past is not None:
-        past_key, past_value = layer_past
-        # concatenate along seq_length dimension:
-        #  - key: [batch_size * self.num_heads, kv_length, head_dim]
-        #  - value: [batch_size * self.num_heads, kv_length, head_dim]
-        key_layer = torch.cat((past_key, key_layer), dim=1)
-        value_layer = torch.cat((past_value, value_layer), dim=1)
-
-    # unused
-    # _, kv_length, _ = key_layer.shape
-    if use_cache:
-        present = (key_layer, value_layer)
-    else:
-        present = None
-    # unused
-    # attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
-    query_layer_ = (
-        query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-        .transpose(1, 2)
-        .to(torch.bfloat16)
-    )
-    key_layer_ = (
-        key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
-        .transpose(1, 2)
-        .to(torch.bfloat16)
-    )
-    value_layer_ = (
-        value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
-        .transpose(1, 2)
-        .to(torch.bfloat16)
-    )
-
-    if alibi is not None:
-        raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
-
-    # below output will have shape (batch_size, seqlen, nheads, headdim)
-    attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
-    attn_output = attn_output.reshape(
-        batch_size, query_length, self.num_heads * self.head_dim
-    )
-    output_tensor = self.dense(attn_output)
-    return output_tensor, present
-
-
-def replace_falcon_attn_with_flash_attn():
-    transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
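For reference, the core of the removed patch was a direct call to flash_attn_func from the flash-attn package, which takes query/key/value tensors in (batch, seqlen, nheads, headdim) layout. A standalone sketch of that call with illustrative shapes (assumes flash-attn is installed and a bf16-capable CUDA GPU):

# Standalone sketch of the flash_attn_func call the removed monkeypatch relied on.
# Assumes flash-attn is installed and a CUDA device with bf16 support; the
# tensor sizes below are illustrative, not Falcon's real dimensions.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn_func(q, k, v, causal=True)  # -> (batch, seqlen, nheads, headdim)

With FA2 supported upstream in transformers, this reshaping and bf16 casting happens inside the library's own Falcon attention, so the standalone patch is no longer needed.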
src/axolotl/utils/config.py CHANGED
@@ -86,6 +86,22 @@ def normalize_config(cfg):
         or (cfg.model_type and "llama" in cfg.model_type.lower())
     )
 
+    # figure out if the model is falcon
+    cfg.is_falcon_derived_model = (
+        (
+            hasattr(model_config, "model_type")
+            and model_config.model_type
+            in [
+                "falcon",
+                "RefinedWebModel",
+                "RefinedWeb",
+            ]
+        )
+        or cfg.is_falcon_derived_model
+        or "falcon" in cfg.base_model
+        or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
+    )
+
     log_gpu_memory_usage(LOG, "baseline", cfg.device)
 
 
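Boiled down, the new check trusts an explicit is_falcon_derived_model from the user's YAML and otherwise infers Falcon lineage from the Hugging Face model_type, the base model name, or the legacy RWForCausalLM class name. A hypothetical, simplified restatement of that logic (the names are made up for illustration and are not axolotl's actual API; the real code writes the result back onto cfg.is_falcon_derived_model):

# Hypothetical, simplified restatement of the detection added to
# normalize_config above; function and constant names are illustrative only.
from typing import Optional

FALCON_MODEL_TYPES = ("falcon", "RefinedWebModel", "RefinedWeb")

def looks_falcon_derived(model_config, base_model: str,
                         model_type: Optional[str],
                         explicit_flag: Optional[bool]) -> bool:
    return bool(
        getattr(model_config, "model_type", None) in FALCON_MODEL_TYPES
        or explicit_flag
        or "falcon" in base_model
        or (model_type and "rwforcausallm" in model_type.lower())
    )

# e.g. looks_falcon_derived(model_config, "tiiuae/falcon-7b", "AutoModelForCausalLM", None)
# returns True, matched via the base model name.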
src/axolotl/utils/models.py CHANGED
@@ -114,25 +114,13 @@ def load_model(
 
     replace_btlm_attn_with_flash_attn(cfg.base_model)
 
-    if hasattr(model_config, "model_type") and model_config.model_type in [
-        "falcon",
-        "RefinedWebModel",
-        "RefinedWeb",
-    ]:
-        if cfg.flash_attention:
-            from axolotl.monkeypatch.falcon_attn_hijack_flash import (
-                replace_falcon_attn_with_flash_attn,
-            )
-
-            replace_falcon_attn_with_flash_attn()
-
-    if cfg.is_llama_derived_model and cfg.flash_attention:
+    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
         if cfg.device not in ["mps", "cpu"] and not inference:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
             )
 
-            LOG.info("patching with flash attention")
+            LOG.info("patching with flash attention for sample packing")
             replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
     elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
@@ -213,6 +201,10 @@ def load_model(
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type="nf4",
         )
+    # sample packing uses custom FA2 patch
+    if cfg.flash_attention and not cfg.sample_packing:
+        if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
+            model_kwargs["use_flash_attention_2"] = True
     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
            from transformers import LlamaForCausalLM
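With sample packing turned off, the loader now defers to transformers' built-in FlashAttention-2 support by adding use_flash_attention_2=True to model_kwargs, which load_model passes through to from_pretrained; with sample packing on, the custom llama monkeypatch above is kept, since the upstream integration does not cover axolotl's packed-sequence handling. A minimal sketch of what that new branch amounts to for a Falcon checkpoint (assumes flash-attn is installed, a bf16-capable CUDA GPU, and that use_flash_attention_2 is the from_pretrained flag exposed by the pinned transformers build; later releases prefer attn_implementation="flash_attention_2"):

# Minimal sketch (not axolotl's full load_model) of loading a model through
# transformers' built-in FlashAttention-2 path, which the new
# model_kwargs["use_flash_attention_2"] = True branch enables.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",          # assumes the in-library Falcon implementation
    torch_dtype=torch.bfloat16,  # FA2 requires fp16/bf16 weights
    use_flash_attention_2=True,  # the flag this PR sets via model_kwargs
)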