winglian commited on
Commit
6086be8
β€’
1 Parent(s): 4a92a3b

qwen2_moe support w multipack (#1455)

Browse files
examples/qwen/README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Qwen
2
+
3
+ TODO
4
+
5
+ # Qwen2 MoE
6
+
7
+ βœ… multipack
8
+ βœ… qwen2_moe 4-bit QLoRA
9
+ βœ… qwen2_moe 16-bit LoRA
10
+ ❓ qwen2_moe 8-bit LoRA
examples/qwen/qwen2-moe-lora.yaml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: Qwen/Qwen1.5-MoE-A2.7B
2
+ trust_remote_code: true
3
+
4
+ load_in_8bit: false
5
+ load_in_4bit: false
6
+ strict: false
7
+
8
+ datasets:
9
+ - path: mhenrichsen/alpaca_2k_test
10
+ type: alpaca
11
+ dataset_prepared_path:
12
+ val_set_size: 0.05
13
+ output_dir: ./out
14
+
15
+ sequence_len: 1024 # supports up to 32k
16
+ sample_packing: false
17
+ pad_to_sequence_len: false
18
+
19
+ adapter: lora
20
+ lora_model_dir:
21
+ lora_r: 32
22
+ lora_alpha: 16
23
+ lora_dropout: 0.05
24
+ lora_target_linear: true
25
+ lora_fan_in_fan_out:
26
+
27
+ wandb_project:
28
+ wandb_entity:
29
+ wandb_watch:
30
+ wandb_name:
31
+ wandb_log_model:
32
+
33
+ gradient_accumulation_steps: 4
34
+ micro_batch_size: 1
35
+ num_epochs: 4
36
+ optimizer: paged_adamw_8bit
37
+ lr_scheduler: cosine
38
+ learning_rate: 0.0002
39
+
40
+ train_on_inputs: false
41
+ group_by_length: false
42
+ bf16: auto
43
+ fp16:
44
+ tf32: true
45
+
46
+ gradient_checkpointing: true
47
+ gradient_checkpointing_kwargs:
48
+ use_reentrant: false
49
+ early_stopping_patience:
50
+ resume_from_checkpoint:
51
+ local_rank:
52
+ logging_steps: 1
53
+ xformers_attention:
54
+ flash_attention: true
55
+
56
+ warmup_steps: 10
57
+ evals_per_epoch: 4
58
+ saves_per_epoch: 1
59
+ debug:
60
+ deepspeed:
61
+ weight_decay: 0.0
62
+ fsdp:
63
+ fsdp_config:
64
+ special_tokens:
examples/qwen/qwen2-moe-qlora.yaml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: Qwen/Qwen1.5-MoE-A2.7B
2
+ trust_remote_code: true
3
+
4
+ load_in_8bit: false
5
+ load_in_4bit: true
6
+ strict: false
7
+
8
+ datasets:
9
+ - path: mhenrichsen/alpaca_2k_test
10
+ type: alpaca
11
+ dataset_prepared_path:
12
+ val_set_size: 0.05
13
+ output_dir: ./out
14
+
15
+ sequence_len: 1024 # supports up to 32k
16
+ sample_packing: false
17
+ pad_to_sequence_len: false
18
+
19
+ adapter: lora
20
+ lora_model_dir:
21
+ lora_r: 32
22
+ lora_alpha: 16
23
+ lora_dropout: 0.05
24
+ lora_target_linear: true
25
+ lora_fan_in_fan_out:
26
+
27
+ wandb_project:
28
+ wandb_entity:
29
+ wandb_watch:
30
+ wandb_name:
31
+ wandb_log_model:
32
+
33
+ gradient_accumulation_steps: 4
34
+ micro_batch_size: 1
35
+ num_epochs: 4
36
+ optimizer: paged_adamw_8bit
37
+ lr_scheduler: cosine
38
+ learning_rate: 0.0002
39
+
40
+ train_on_inputs: false
41
+ group_by_length: false
42
+ bf16: auto
43
+ fp16:
44
+ tf32: true
45
+
46
+ gradient_checkpointing: true
47
+ gradient_checkpointing_kwargs:
48
+ use_reentrant: false
49
+ early_stopping_patience:
50
+ resume_from_checkpoint:
51
+ local_rank:
52
+ logging_steps: 1
53
+ xformers_attention:
54
+ flash_attention: true
55
+
56
+ warmup_steps: 10
57
+ evals_per_epoch: 4
58
+ saves_per_epoch: 1
59
+ debug:
60
+ deepspeed:
61
+ weight_decay: 0.0
62
+ fsdp:
63
+ fsdp_config:
64
+ special_tokens:
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
2
  packaging==23.2
3
- peft==0.9.0
4
- transformers @ git+https://github.com/huggingface/transformers.git@73a73b415e36f41481369f6129cb4b62bb127a78
5
  tokenizers==0.15.0
6
  bitsandbytes==0.43.0
7
  accelerate==0.28.0
@@ -39,4 +39,4 @@ s3fs
39
  gcsfs
40
  # adlfs
41
 
42
- trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
 
1
  --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
2
  packaging==23.2
3
+ peft==0.10.0
4
+ transformers @ git+https://github.com/huggingface/transformers.git@43d17c18360ac9c3d3491389328e2fe55fe8f9ce
5
  tokenizers==0.15.0
6
  bitsandbytes==0.43.0
7
  accelerate==0.28.0
 
39
  gcsfs
40
  # adlfs
41
 
42
+ trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
src/axolotl/monkeypatch/multipack.py CHANGED
@@ -12,6 +12,7 @@ from axolotl.monkeypatch.utils import get_unpad_data
12
  SUPPORTED_MULTIPACK_MODEL_TYPES = [
13
  "mixtral",
14
  "qwen2",
 
15
  "falcon",
16
  "phi",
17
  "gemma",
@@ -31,6 +32,10 @@ def patch_for_multipack(model_type, model_name=None):
31
  transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
32
  get_unpad_data
33
  )
 
 
 
 
34
  elif model_type == "falcon":
35
  transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
36
  get_unpad_data
 
12
  SUPPORTED_MULTIPACK_MODEL_TYPES = [
13
  "mixtral",
14
  "qwen2",
15
+ "qwen2_moe",
16
  "falcon",
17
  "phi",
18
  "gemma",
 
32
  transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
33
  get_unpad_data
34
  )
35
+ elif model_type == "qwen2_moe":
36
+ transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
37
+ get_unpad_data
38
+ )
39
  elif model_type == "falcon":
40
  transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
41
  get_unpad_data
src/axolotl/utils/models.py CHANGED
@@ -456,7 +456,7 @@ def load_model(
456
  "bnb_4bit_quant_type": "nf4",
457
  "bnb_4bit_quant_storage": torch.bfloat16,
458
  }
459
- if cfg.model_config_type == "jamba" and not cfg.deepspeed:
460
  # for some reason, this causes the loss to be off by an order of magnitude
461
  # but deepspeed needs this still in bfloat16
462
  bnb_config["bnb_4bit_quant_storage"] = torch.float32
 
456
  "bnb_4bit_quant_type": "nf4",
457
  "bnb_4bit_quant_storage": torch.bfloat16,
458
  }
459
+ if not cfg.deepspeed:
460
  # for some reason, this causes the loss to be off by an order of magnitude
461
  # but deepspeed needs this still in bfloat16
462
  bnb_config["bnb_4bit_quant_storage"] = torch.float32